From 72e20b652b8d697e5dc0605db284e3b637f11bac Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 28 Jan 2026 07:45:10 -0800 Subject: [PATCH 01/63] feat: Adding GlobalInstructionPlugin PiperOrigin-RevId: 862232259 --- .../adk/plugins/GlobalInstructionPlugin.java | 118 +++++++++++++ .../plugins/GlobalInstructionPluginTest.java | 155 ++++++++++++++++++ 2 files changed, 273 insertions(+) create mode 100644 core/src/main/java/com/google/adk/plugins/GlobalInstructionPlugin.java create mode 100644 core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java diff --git a/core/src/main/java/com/google/adk/plugins/GlobalInstructionPlugin.java b/core/src/main/java/com/google/adk/plugins/GlobalInstructionPlugin.java new file mode 100644 index 000000000..1773bf701 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/GlobalInstructionPlugin.java @@ -0,0 +1,118 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.plugins; + +import com.google.adk.agents.CallbackContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.utils.InstructionUtils; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +/** + * Plugin that provides global instructions functionality at the App level. + * + *

Global instructions are applied to all agents in the application, providing a consistent way + * to set application-wide instructions, identity, or personality. Global instructions can be + * provided as a static string, or as a function that resolves the instruction based on the {@link + * CallbackContext}. + * + *

The plugin operates through the before_model_callback, allowing it to modify LLM requests + * before they are sent to the model by prepending the global instruction to any existing system + * instructions provided by the agent. + */ +public class GlobalInstructionPlugin extends BasePlugin { + + private final Function> instructionProvider; + + private static Function> createInstructionProvider( + String globalInstruction) { + return callbackContext -> { + if (globalInstruction == null) { + return Maybe.empty(); + } + return InstructionUtils.injectSessionState( + callbackContext.invocationContext(), globalInstruction) + .toMaybe(); + }; + } + + public GlobalInstructionPlugin(String globalInstruction) { + this(globalInstruction, "global_instruction"); + } + + public GlobalInstructionPlugin(String globalInstruction, String name) { + this(createInstructionProvider(globalInstruction), name); + } + + public GlobalInstructionPlugin(Function> instructionProvider) { + this(instructionProvider, "global_instruction"); + } + + public GlobalInstructionPlugin( + Function> instructionProvider, String name) { + super(name); + this.instructionProvider = instructionProvider; + } + + @Override + public Maybe beforeModelCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { + return instructionProvider + .apply(callbackContext) + .filter(instruction -> !instruction.isEmpty()) + .flatMap( + instruction -> { + // Get mutable config, or create one if it doesn't exist. + GenerateContentConfig config = + llmRequest.config().orElseGet(GenerateContentConfig.builder()::build); + + // Get existing system instruction parts, if any. + Optional systemInstruction = config.systemInstruction(); + List existingParts = + systemInstruction.flatMap(Content::parts).orElse(ImmutableList.of()); + + // Prepend the global instruction to the existing system instruction parts. + // If there are existing instructions, add two newlines between the global + // instruction and the existing instructions. + ImmutableList.Builder newPartsBuilder = ImmutableList.builder(); + if (existingParts.isEmpty()) { + newPartsBuilder.add(Part.fromText(instruction)); + } else { + newPartsBuilder.add(Part.fromText(instruction + "\n\n")); + newPartsBuilder.addAll(existingParts); + } + + // Build the new system instruction content. + Content.Builder newSystemInstructionBuilder = Content.builder(); + systemInstruction.flatMap(Content::role).ifPresent(newSystemInstructionBuilder::role); + newSystemInstructionBuilder.parts(newPartsBuilder.build()); + + // Update llmRequest with new config. + llmRequest.config( + config.toBuilder() + .systemInstruction(newSystemInstructionBuilder.build()) + .build()); + return Maybe.empty(); + }); + } +} diff --git a/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java b/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java new file mode 100644 index 000000000..345314256 --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java @@ -0,0 +1,155 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.plugins; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.artifacts.BaseArtifactService; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.Session; +import com.google.adk.sessions.State; +import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class GlobalInstructionPluginTest { + @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + @Mock private CallbackContext mockCallbackContext; + @Mock private InvocationContext mockInvocationContext; + private final State state = new State(new ConcurrentHashMap<>()); + private final Session session = Session.builder("session_id").state(state).build(); + @Mock private BaseArtifactService mockArtifactService; + + @Before + public void setUp() { + state.clear(); + when(mockCallbackContext.invocationId()).thenReturn("invocation_id"); + when(mockCallbackContext.agentName()).thenReturn("agent_name"); + when(mockCallbackContext.invocationContext()).thenReturn(mockInvocationContext); + when(mockInvocationContext.session()).thenReturn(session); + when(mockInvocationContext.artifactService()).thenReturn(mockArtifactService); + } + + @Test + public void beforeModelCallback_noExistingInstruction() { + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin("global instruction"); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + Content systemInstruction = llmRequestBuilder.build().config().get().systemInstruction().get(); + List parts = systemInstruction.parts().get(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue("global instruction"); + assertThat(systemInstruction.role()).isEmpty(); + } + + @Test + public void beforeModelCallback_withExistingInstruction() { + LlmRequest.Builder llmRequestBuilder = + LlmRequest.builder() + .config( + GenerateContentConfig.builder() + .systemInstruction( + Content.builder().parts(Part.fromText("existing instruction")).build()) + .build()); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin("global instruction"); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + Content systemInstruction = llmRequestBuilder.build().config().get().systemInstruction().get(); + List parts = systemInstruction.parts().get(); + assertThat(parts).hasSize(2); + assertThat(parts.get(0).text()).hasValue("global instruction\n\n"); + assertThat(parts.get(1).text()).hasValue("existing instruction"); + assertThat(systemInstruction.role()).isEmpty(); + } + + @Test + public void beforeModelCallback_withInstructionProvider() { + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + GlobalInstructionPlugin plugin = + new GlobalInstructionPlugin(unusedContext -> Maybe.just("instruction from provider")); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + Content systemInstruction = llmRequestBuilder.build().config().get().systemInstruction().get(); + List parts = systemInstruction.parts().get(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue("instruction from provider"); + assertThat(systemInstruction.role()).isEmpty(); + } + + @Test + public void beforeModelCallback_withStringInstruction_injectsState() { + state.put("name", "Alice"); + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin("Hello {name}"); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + Content systemInstruction = llmRequestBuilder.build().config().get().systemInstruction().get(); + List parts = systemInstruction.parts().get(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue("Hello Alice"); + assertThat(systemInstruction.role()).isEmpty(); + } + + @Test + public void beforeModelCallback_nullInstruction() { + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin((String) null); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + assertThat(llmRequestBuilder.build().config()).isEmpty(); + } + + @Test + public void beforeModelCallback_emptyInstruction() { + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin(""); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + assertThat(llmRequestBuilder.build().config()).isEmpty(); + } + + @Test + public void beforeModelCallback_withExistingInstructionAndRole_preservesRole() { + LlmRequest.Builder llmRequestBuilder = + LlmRequest.builder() + .config( + GenerateContentConfig.builder() + .systemInstruction( + Content.builder() + .parts(Part.fromText("existing instruction")) + .role("system") + .build()) + .build()); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin("global instruction"); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + Content systemInstruction = llmRequestBuilder.build().config().get().systemInstruction().get(); + List parts = systemInstruction.parts().get(); + assertThat(parts).hasSize(2); + assertThat(parts.get(0).text()).hasValue("global instruction\n\n"); + assertThat(parts.get(1).text()).hasValue("existing instruction"); + assertThat(systemInstruction.role()).hasValue("system"); + } +} From 32a6b625d96e5658be77d5017f10014d8d4036c1 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 28 Jan 2026 07:45:37 -0800 Subject: [PATCH 02/63] feat: EventAction.stateDelta() now has a remove by key variant This is going to be a useful feature for rewind PiperOrigin-RevId: 862232441 --- .../com/google/adk/events/EventActions.java | 11 +++ .../adk/sessions/InMemorySessionService.java | 70 ++++++++++------ .../adk/sessions/SessionJsonConverter.java | 51 ++++++++---- .../java/com/google/adk/tools/AgentTool.java | 23 +++++- .../google/adk/events/EventActionsTest.java | 10 +++ .../sessions/InMemorySessionServiceTest.java | 79 ++++++++++++++++--- .../google/adk/sessions/MockApiAnswer.java | 39 +++++++++ .../sessions/SessionJsonConverterTest.java | 47 +++++++++++ .../sessions/VertexAiSessionServiceTest.java | 28 +++++++ 9 files changed, 304 insertions(+), 54 deletions(-) diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 63909ee1a..493fa4b27 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -19,6 +19,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.google.adk.agents.BaseAgentState; +import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Part; import java.util.Objects; @@ -98,10 +99,20 @@ public ConcurrentMap stateDelta() { return stateDelta; } + @Deprecated // Use stateDelta(), addState() and removeStateByKey() instead. public void setStateDelta(ConcurrentMap stateDelta) { this.stateDelta = stateDelta; } + /** + * Removes a key from the state delta. + * + * @param key The key to remove. + */ + public void removeStateByKey(String key) { + stateDelta.put(key, State.REMOVED); + } + @JsonProperty("artifactDelta") public ConcurrentMap artifactDelta() { return artifactDelta; diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index 80c277fce..b658f6767 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -96,8 +96,8 @@ public Single createSession( .build(); sessions - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, k -> new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .put(resolvedSessionId, newSession); // Create a mutable copy for the return value @@ -116,8 +116,8 @@ public Maybe getSession( Session storedSession = sessions - .getOrDefault(appName, new ConcurrentHashMap<>()) - .getOrDefault(userId, new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .get(sessionId); if (storedSession == null) { @@ -166,7 +166,7 @@ public Single listSessions(String appName, String userId) Objects.requireNonNull(userId, "userId cannot be null"); Map userSessionsMap = - sessions.getOrDefault(appName, new ConcurrentHashMap<>()).get(userId); + sessions.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()).get(userId); if (userSessionsMap == null || userSessionsMap.isEmpty()) { return Single.just(ListSessionsResponse.builder().build()); @@ -185,11 +185,12 @@ public Completable deleteSession(String appName, String userId, String sessionId Objects.requireNonNull(userId, "userId cannot be null"); Objects.requireNonNull(sessionId, "sessionId cannot be null"); - ConcurrentMap userSessionsMap = - sessions.getOrDefault(appName, new ConcurrentHashMap<>()).get(userId); - - if (userSessionsMap != null) { - userSessionsMap.remove(sessionId); + ConcurrentMap> appSessionsMap = sessions.get(appName); + if (appSessionsMap != null) { + ConcurrentMap userSessionsMap = appSessionsMap.get(userId); + if (userSessionsMap != null) { + userSessionsMap.remove(sessionId); + } } return Completable.complete(); } @@ -202,8 +203,8 @@ public Single listEvents(String appName, String userId, Stri Session storedSession = sessions - .getOrDefault(appName, new ConcurrentHashMap<>()) - .getOrDefault(userId, new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .get(sessionId); if (storedSession == null) { @@ -236,17 +237,34 @@ public Single appendEvent(Session session, Event event) { (key, value) -> { if (key.startsWith(State.APP_PREFIX)) { String appStateKey = key.substring(State.APP_PREFIX.length()); - appState - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .put(appStateKey, value); + if (value == State.REMOVED) { + appState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .remove(appStateKey); + } else { + appState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .put(appStateKey, value); + } } else if (key.startsWith(State.USER_PREFIX)) { String userStateKey = key.substring(State.USER_PREFIX.length()); - userState - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, k -> new ConcurrentHashMap<>()) - .put(userStateKey, value); - } else { - session.state().put(key, value); + if (value == State.REMOVED) { + userState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) + .remove(userStateKey); + } else { + userState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) + .put(userStateKey, value); + } + } else if (!key.startsWith(State.TEMP_PREFIX)) { + if (value == State.REMOVED) { + session.state().remove(key); + } else { + session.state().put(key, value); + } } }); } @@ -257,8 +275,8 @@ public Single appendEvent(Session session, Event event) { // --- Update the session stored in this service --- sessions - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, k -> new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .put(sessionId, session); mergeWithGlobalState(appName, userId, session); @@ -307,12 +325,12 @@ private Session mergeWithGlobalState(String appName, String userId, Session sess // Merge App State directly into the session's state map appState - .getOrDefault(appName, new ConcurrentHashMap()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) .forEach((key, value) -> sessionState.put(State.APP_PREFIX + key, value)); userState - .getOrDefault(appName, new ConcurrentHashMap<>()) - .getOrDefault(userId, new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .forEach((key, value) -> sessionState.put(State.USER_PREFIX + key, value)); return session; diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index 5dbbe76c7..d1a661a91 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -91,7 +91,7 @@ static String convertEventToJson(Event event) { if (event.actions() != null) { Map actionsJson = new HashMap<>(); actionsJson.put("skipSummarization", event.actions().skipSummarization()); - actionsJson.put("stateDelta", event.actions().stateDelta()); + actionsJson.put("stateDelta", stateDeltaToJson(event.actions().stateDelta())); actionsJson.put("artifactDelta", event.actions().artifactDelta()); actionsJson.put("transferAgent", event.actions().transferToAgent()); actionsJson.put("escalate", event.actions().escalate()); @@ -126,8 +126,7 @@ static String convertEventToJson(Event event) { * @return parsed {@link Content}, or {@code null} if conversion fails. */ @Nullable - // Safe because we check instanceof Map before casting. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Safe because we check instanceof Map before casting. private static Content convertMapToContent(Object rawContentValue) { if (rawContentValue == null) { return null; @@ -153,8 +152,7 @@ private static Content convertMapToContent(Object rawContentValue) { * * @return parsed {@link Event}. */ - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. static Event fromApiEvent(Map apiEvent) { EventActions.Builder eventActionsBuilder = EventActions.builder(); if (apiEvent.get("actions") != null) { @@ -162,10 +160,7 @@ static Event fromApiEvent(Map apiEvent) { if (actionsMap.get("skipSummarization") != null) { eventActionsBuilder.skipSummarization((Boolean) actionsMap.get("skipSummarization")); } - eventActionsBuilder.stateDelta( - actionsMap.get("stateDelta") != null - ? new ConcurrentHashMap<>((Map) actionsMap.get("stateDelta")) - : new ConcurrentHashMap<>()); + eventActionsBuilder.stateDelta(stateDeltaFromJson(actionsMap.get("stateDelta"))); eventActionsBuilder.artifactDelta( actionsMap.get("artifactDelta") != null ? convertToArtifactDeltaMap(actionsMap.get("artifactDelta")) @@ -238,6 +233,32 @@ static Event fromApiEvent(Map apiEvent) { return event; } + @SuppressWarnings("unchecked") // stateDeltaFromMap is a Map from JSON. + private static ConcurrentMap stateDeltaFromJson(Object stateDeltaFromMap) { + if (stateDeltaFromMap == null) { + return new ConcurrentHashMap<>(); + } + return ((Map) stateDeltaFromMap) + .entrySet().stream() + .collect( + ConcurrentHashMap::new, + (map, entry) -> + map.put( + entry.getKey(), + entry.getValue() == null ? State.REMOVED : entry.getValue()), + ConcurrentHashMap::putAll); + } + + private static Map stateDeltaToJson(Map stateDelta) { + return stateDelta.entrySet().stream() + .collect( + HashMap::new, + (map, entry) -> + map.put( + entry.getKey(), entry.getValue() == State.REMOVED ? null : entry.getValue()), + HashMap::putAll); + } + /** * Converts a timestamp from a Map or String into an {@link Instant}. * @@ -263,8 +284,7 @@ private static Instant convertToInstant(Object timestampObj) { * @param artifactDeltaObj The raw object from which to parse the artifact delta. * @return A {@link ConcurrentMap} representing the artifact delta. */ - // Safe because we check instanceof Map before casting. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Safe because we check instanceof Map before casting. private static ConcurrentMap convertToArtifactDeltaMap(Object artifactDeltaObj) { if (!(artifactDeltaObj instanceof Map)) { return new ConcurrentHashMap<>(); @@ -287,8 +307,7 @@ private static ConcurrentMap convertToArtifactDeltaMap(Object arti * * @return thread-safe nested map. */ - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. private static ConcurrentMap> asConcurrentMapOfConcurrentMaps(Object value) { return ((Map>) value) @@ -299,8 +318,7 @@ private static ConcurrentMap convertToArtifactDeltaMap(Object arti ConcurrentHashMap::putAll); } - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. private static ConcurrentMap asConcurrentMapOfAgentState(Object value) { return ((Map) value) .entrySet().stream() @@ -313,8 +331,7 @@ private static ConcurrentMap asConcurrentMapOfAgentState ConcurrentHashMap::putAll); } - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. private static ConcurrentMap asConcurrentMapOfToolConfirmations( Object value) { return ((Map) value) diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index a531361f2..1a8dbc527 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -28,6 +28,7 @@ import com.google.adk.events.Event; import com.google.adk.runner.InMemoryRunner; import com.google.adk.runner.Runner; +import com.google.adk.sessions.State; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -154,7 +155,7 @@ public Single> runAsync(Map args, ToolContex if (lastEvent.actions() != null && lastEvent.actions().stateDelta() != null && !lastEvent.actions().stateDelta().isEmpty()) { - toolContext.state().putAll(lastEvent.actions().stateDelta()); + updateState(lastEvent.actions().stateDelta(), toolContext.state()); } if (outputText.isEmpty()) { @@ -174,4 +175,24 @@ public Single> runAsync(Map args, ToolContex } }); } + + /** + * Updates the given state map with the state delta. + * + *

If a value in the delta is {@link State#REMOVED}, the key is removed from the state map. + * Otherwise, the key-value pair is put into the state map. This method does not distinguish + * between session, app, and user state based on key prefixes. + * + * @param state The state map to update. + */ + private void updateState(Map stateDelta, Map state) { + stateDelta.forEach( + (key, value) -> { + if (value == State.REMOVED) { + state.remove(key); + } else { + state.put(key, value); + } + }); + } } diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index a9e3693d5..18870ad44 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.adk.sessions.State; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.Part; @@ -97,4 +98,13 @@ public void merge_mergesAllFields() { assertThat(merged.endInvocation()).hasValue(true); assertThat(merged.compaction()).hasValue(COMPACTION); } + + @Test + public void removeStateByKey_marksKeyAsRemoved() { + EventActions eventActions = new EventActions(); + eventActions.stateDelta().put("key1", "value1"); + eventActions.removeStateByKey("key1"); + + assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED); + } } diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 4c35f5b90..97b182249 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -89,8 +89,9 @@ public void lifecycle_listSessions() { ConcurrentMap stateDelta = new ConcurrentHashMap<>(); stateDelta.put("sessionKey", "sessionValue"); - stateDelta.put("_app_appKey", "appValue"); - stateDelta.put("_user_userKey", "userValue"); + stateDelta.put("app:appKey", "appValue"); + stateDelta.put("user:userKey", "userValue"); + stateDelta.put("temp:tempKey", "tempValue"); Event event = Event.builder().actions(EventActions.builder().stateDelta(stateDelta).build()).build(); @@ -105,8 +106,9 @@ public void lifecycle_listSessions() { assertThat(listedSession.id()).isEqualTo(session.id()); assertThat(listedSession.events()).isEmpty(); assertThat(listedSession.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(listedSession.state()).containsEntry("_app_appKey", "appValue"); - assertThat(listedSession.state()).containsEntry("_user_userKey", "userValue"); + assertThat(listedSession.state()).containsEntry("app:appKey", "appValue"); + assertThat(listedSession.state()).containsEntry("user:userKey", "userValue"); + assertThat(listedSession.state()).doesNotContainKey("temp:tempKey"); } @Test @@ -134,8 +136,9 @@ public void appendEvent_updatesSessionState() { ConcurrentMap stateDelta = new ConcurrentHashMap<>(); stateDelta.put("sessionKey", "sessionValue"); - stateDelta.put("_app_appKey", "appValue"); - stateDelta.put("_user_userKey", "userValue"); + stateDelta.put("app:appKey", "appValue"); + stateDelta.put("user:userKey", "userValue"); + stateDelta.put("temp:tempKey", "tempValue"); Event event = Event.builder().actions(EventActions.builder().stateDelta(stateDelta).build()).build(); @@ -145,8 +148,9 @@ public void appendEvent_updatesSessionState() { // After appendEvent, session state in memory should contain session-specific state from delta // and merged global state. assertThat(session.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(session.state()).containsEntry("_app_appKey", "appValue"); - assertThat(session.state()).containsEntry("_user_userKey", "userValue"); + assertThat(session.state()).containsEntry("app:appKey", "appValue"); + assertThat(session.state()).containsEntry("user:userKey", "userValue"); + assertThat(session.state()).doesNotContainKey("temp:tempKey"); // getSession should return session with merged state. Session retrievedSession = @@ -154,7 +158,62 @@ public void appendEvent_updatesSessionState() { .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) .blockingGet(); assertThat(retrievedSession.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(retrievedSession.state()).containsEntry("_app_appKey", "appValue"); - assertThat(retrievedSession.state()).containsEntry("_user_userKey", "userValue"); + assertThat(retrievedSession.state()).containsEntry("app:appKey", "appValue"); + assertThat(retrievedSession.state()).containsEntry("user:userKey", "userValue"); + assertThat(retrievedSession.state()).doesNotContainKey("temp:tempKey"); + } + + @Test + public void appendEvent_removesState() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService + .createSession("app", "user", new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + ConcurrentMap stateDeltaAdd = new ConcurrentHashMap<>(); + stateDeltaAdd.put("sessionKey", "sessionValue"); + stateDeltaAdd.put("app:appKey", "appValue"); + stateDeltaAdd.put("user:userKey", "userValue"); + stateDeltaAdd.put("temp:tempKey", "tempValue"); + + Event eventAdd = + Event.builder().actions(EventActions.builder().stateDelta(stateDeltaAdd).build()).build(); + + var unused = sessionService.appendEvent(session, eventAdd).blockingGet(); + + // Verify state is added + Session retrievedSessionAdd = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSessionAdd.state()).containsEntry("sessionKey", "sessionValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("app:appKey", "appValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("user:userKey", "userValue"); + assertThat(retrievedSessionAdd.state()).doesNotContainKey("temp:tempKey"); + + // Prepare and append event to remove state + ConcurrentMap stateDeltaRemove = new ConcurrentHashMap<>(); + stateDeltaRemove.put("sessionKey", State.REMOVED); + stateDeltaRemove.put("app:appKey", State.REMOVED); + stateDeltaRemove.put("user:userKey", State.REMOVED); + stateDeltaRemove.put("temp:tempKey", State.REMOVED); + + Event eventRemove = + Event.builder() + .actions(EventActions.builder().stateDelta(stateDeltaRemove).build()) + .build(); + + unused = sessionService.appendEvent(session, eventRemove).blockingGet(); + + // Verify state is removed + Session retrievedSessionRemove = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("sessionKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("app:appKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("user:userKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey"); } } diff --git a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java index 5e8f3d992..111b1dce3 100644 --- a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java +++ b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java @@ -8,6 +8,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ConcurrentMap; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -164,6 +165,18 @@ private ApiResponse handleAppendEvent(String path, InvocationOnMock invocation) eventsData.add(newEventData); eventMap.put(sessionId, mapper.writeValueAsString(eventsData)); + + // Apply stateDelta to session state + extractObjectMap(newEventData, "actions") + .flatMap(actions -> extractObjectMap(actions, "stateDelta")) + .ifPresent( + stateDelta -> { + try { + applyStateDelta(sessionId, stateDelta); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); } catch (Exception e) { throw new RuntimeException(e); } @@ -213,4 +226,30 @@ private ApiResponse handleDeleteSession(String path) { sessionMap.remove(sessionIdToDelete); return responseWithBody(""); } + + private void applyStateDelta(String sessionId, Map stateDelta) throws Exception { + String sessionDataString = sessionMap.get(sessionId); + if (sessionDataString == null) { + return; + } + Map sessionData = + mapper.readValue(sessionDataString, new TypeReference>() {}); + Map sessionState = + extractObjectMap(sessionData, "sessionState").map(HashMap::new).orElseGet(HashMap::new); + + for (Map.Entry entry : stateDelta.entrySet()) { + if (entry.getValue() == null) { + sessionState.remove(entry.getKey()); + } else { + sessionState.put(entry.getKey(), entry.getValue()); + } + } + sessionData.put("sessionState", sessionState); + sessionMap.put(sessionId, mapper.writeValueAsString(sessionData)); + } + + @SuppressWarnings("unchecked") // Safe because map values are Maps read from JSON. + private Optional> extractObjectMap(Map map, String key) { + return Optional.ofNullable((Map) map.get(key)); + } } diff --git a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java index b77d6f267..827e810aa 100644 --- a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java +++ b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java @@ -221,4 +221,51 @@ public void fromApiEvent_missingMetadataFields_success() { assertThat(event.turnComplete().get()).isFalse(); assertThat(event.interrupted().get()).isFalse(); } + + @Test + public void convertEventToJson_withStateRemoved_success() throws JsonProcessingException { + EventActions actions = + EventActions.builder() + .stateDelta( + new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", State.REMOVED))) + .build(); + + Event event = + Event.builder() + .author("user") + .invocationId("inv-123") + .timestamp(Instant.parse("2023-01-01T00:00:00Z").toEpochMilli()) + .actions(actions) + .build(); + + String json = SessionJsonConverter.convertEventToJson(event); + JsonNode jsonNode = objectMapper.readTree(json); + + JsonNode actionsNode = jsonNode.get("actions"); + assertThat(actionsNode.get("stateDelta").get("key1").asText()).isEqualTo("value1"); + assertThat(actionsNode.get("stateDelta").get("key2").isNull()).isTrue(); + } + + @Test + public void fromApiEvent_withNullStateDeltaValue_success() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + apiEvent.put("timestamp", "2023-01-01T00:00:00Z"); + + Map stateDelta = new HashMap<>(); + stateDelta.put("key1", "value1"); + stateDelta.put("key2", null); + + Map actions = new HashMap<>(); + actions.put("stateDelta", stateDelta); + apiEvent.put("actions", actions); + + Event event = SessionJsonConverter.fromApiEvent(apiEvent); + + EventActions eventActions = event.actions(); + assertThat(eventActions.stateDelta()).containsEntry("key1", "value1"); + assertThat(eventActions.stateDelta()).containsEntry("key2", State.REMOVED); + } } diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index 775b465ff..36eab1d16 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -337,4 +337,32 @@ public void listEmptySession_success() { .events()) .isEmpty(); } + + @Test + public void appendEvent_withStateRemoved_updatesSessionState() { + String userId = "userB"; + ConcurrentMap initialState = + new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2")); + Session session = + vertexAiSessionService.createSession("987", userId, initialState, null).blockingGet(); + + ConcurrentMap stateDelta = + new ConcurrentHashMap<>(ImmutableMap.of("key2", State.REMOVED)); + Event event = + Event.builder() + .invocationId("456") + .author(userId) + .timestamp(Instant.parse("2024-12-12T12:12:12.123456Z").toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + var unused = vertexAiSessionService.appendEvent(session, event).blockingGet(); + + Session updatedSession = + vertexAiSessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + + assertThat(updatedSession.state()).containsExactly("key1", "value1"); + assertThat(updatedSession.state()).doesNotContainKey("key2"); + } } From 0c6e61cd1a424aaacc44d12b435510b322293b09 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 28 Jan 2026 07:53:40 -0800 Subject: [PATCH 03/63] refactor: Using computeIfAbsent to not create unnecessary instances PiperOrigin-RevId: 862235526 --- .../artifacts/InMemoryArtifactService.java | 42 +++++++------------ .../codeexecutors/CodeExecutorContext.java | 20 ++++----- .../google/adk/flows/llmflows/Functions.java | 10 +---- .../adk/tutorials/LiveAudioSingleAgent.java | 32 +++++++------- 4 files changed, 43 insertions(+), 61 deletions(-) diff --git a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java index 890820196..27b85136d 100644 --- a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java @@ -48,11 +48,8 @@ public InMemoryArtifactService() { public Single saveArtifact( String appName, String userId, String sessionId, String filename, Part artifact) { List versions = - artifacts - .computeIfAbsent(appName, k -> new HashMap<>()) - .computeIfAbsent(userId, k -> new HashMap<>()) - .computeIfAbsent(sessionId, k -> new HashMap<>()) - .computeIfAbsent(filename, k -> new ArrayList<>()); + getArtifactsMap(appName, userId, sessionId) + .computeIfAbsent(filename, unused -> new ArrayList<>()); versions.add(artifact); return Single.just(versions.size() - 1); } @@ -66,11 +63,8 @@ public Single saveArtifact( public Maybe loadArtifact( String appName, String userId, String sessionId, String filename, Optional version) { List versions = - artifacts - .getOrDefault(appName, new HashMap<>()) - .getOrDefault(userId, new HashMap<>()) - .getOrDefault(sessionId, new HashMap<>()) - .getOrDefault(filename, new ArrayList<>()); + getArtifactsMap(appName, userId, sessionId) + .computeIfAbsent(filename, unused -> new ArrayList<>()); if (versions.isEmpty()) { return Maybe.empty(); @@ -97,13 +91,7 @@ public Single listArtifactKeys( String appName, String userId, String sessionId) { return Single.just( ListArtifactsResponse.builder() - .filenames( - ImmutableList.copyOf( - artifacts - .getOrDefault(appName, new HashMap<>()) - .getOrDefault(userId, new HashMap<>()) - .getOrDefault(sessionId, new HashMap<>()) - .keySet())) + .filenames(ImmutableList.copyOf(getArtifactsMap(appName, userId, sessionId).keySet())) .build()); } @@ -115,11 +103,7 @@ public Single listArtifactKeys( @Override public Completable deleteArtifact( String appName, String userId, String sessionId, String filename) { - artifacts - .getOrDefault(appName, new HashMap<>()) - .getOrDefault(userId, new HashMap<>()) - .getOrDefault(sessionId, new HashMap<>()) - .remove(filename); + getArtifactsMap(appName, userId, sessionId).remove(filename); return Completable.complete(); } @@ -132,15 +116,19 @@ public Completable deleteArtifact( public Single> listVersions( String appName, String userId, String sessionId, String filename) { int size = - artifacts - .getOrDefault(appName, new HashMap<>()) - .getOrDefault(userId, new HashMap<>()) - .getOrDefault(sessionId, new HashMap<>()) - .getOrDefault(filename, new ArrayList<>()) + getArtifactsMap(appName, userId, sessionId) + .computeIfAbsent(filename, unused -> new ArrayList<>()) .size(); if (size == 0) { return Single.just(ImmutableList.of()); } return Single.just(IntStream.range(0, size).boxed().collect(toImmutableList())); } + + private Map> getArtifactsMap(String appName, String userId, String sessionId) { + return artifacts + .computeIfAbsent(appName, unused -> new HashMap<>()) + .computeIfAbsent(userId, unused -> new HashMap<>()) + .computeIfAbsent(sessionId, unused -> new HashMap<>()); + } } diff --git a/core/src/main/java/com/google/adk/codeexecutors/CodeExecutorContext.java b/core/src/main/java/com/google/adk/codeexecutors/CodeExecutorContext.java index bad83ebc0..a34102225 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/CodeExecutorContext.java +++ b/core/src/main/java/com/google/adk/codeexecutors/CodeExecutorContext.java @@ -89,7 +89,8 @@ public void setExecutionId(String sessionId) { * @return A list of processed file names in the code executor context. */ public List getProcessedFileNames() { - return (List) this.context.getOrDefault(PROCESSED_FILE_NAMES_KEY, new ArrayList<>()); + return (List) + this.context.computeIfAbsent(PROCESSED_FILE_NAMES_KEY, unused -> new ArrayList<>()); } /** @@ -100,7 +101,7 @@ public List getProcessedFileNames() { public void addProcessedFileNames(List fileNames) { List processedFileNames = (List) - this.context.computeIfAbsent(PROCESSED_FILE_NAMES_KEY, k -> new ArrayList<>()); + this.context.computeIfAbsent(PROCESSED_FILE_NAMES_KEY, unused -> new ArrayList<>()); processedFileNames.addAll(fileNames); } @@ -126,7 +127,7 @@ public List getInputFiles() { public void addInputFiles(List inputFiles) { List> fileMaps = (List>) - this.sessionState.computeIfAbsent(INPUT_FILE_KEY, k -> new ArrayList<>()); + this.sessionState.computeIfAbsent(INPUT_FILE_KEY, unused -> new ArrayList<>()); for (File inputFile : inputFiles) { fileMaps.add( objectMapper.convertValue(inputFile, new TypeReference>() {})); @@ -166,7 +167,7 @@ public int getErrorCount(String invocationId) { public void incrementErrorCount(String invocationId) { Map errorCounts = (Map) - this.sessionState.computeIfAbsent(ERROR_COUNT_KEY, k -> new HashMap<>()); + this.sessionState.computeIfAbsent(ERROR_COUNT_KEY, unused -> new HashMap<>()); errorCounts.put(invocationId, getErrorCount(invocationId) + 1); } @@ -176,9 +177,6 @@ public void incrementErrorCount(String invocationId) { * @param invocationId The invocation ID to reset the error count for. */ public void resetErrorCount(String invocationId) { - if (!this.sessionState.containsKey(ERROR_COUNT_KEY)) { - return; - } Map errorCounts = (Map) this.sessionState.get(ERROR_COUNT_KEY); if (errorCounts != null) { @@ -198,9 +196,10 @@ public void updateCodeExecutionResult( String invocationId, String code, String resultStdout, String resultStderr) { Map>> codeExecutionResults = (Map>>) - this.sessionState.computeIfAbsent(CODE_EXECUTION_RESULTS_KEY, k -> new HashMap<>()); + this.sessionState.computeIfAbsent( + CODE_EXECUTION_RESULTS_KEY, unused -> new HashMap<>()); List> resultsForInvocation = - codeExecutionResults.computeIfAbsent(invocationId, k -> new ArrayList<>()); + codeExecutionResults.computeIfAbsent(invocationId, unused -> new ArrayList<>()); Map newResult = new HashMap<>(); newResult.put("code", code); newResult.put("result_stdout", resultStdout); @@ -210,6 +209,7 @@ public void updateCodeExecutionResult( } private Map getCodeExecutorContext(Map sessionState) { - return (Map) sessionState.computeIfAbsent(CONTEXT_KEY, k -> new HashMap<>()); + return (Map) + sessionState.computeIfAbsent(CONTEXT_KEY, unused -> new HashMap<>()); } } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index ce7687c3d..3bb57faee 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -52,7 +52,6 @@ import io.reactivex.rxjava3.disposables.Disposable; import io.reactivex.rxjava3.functions.Function; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -332,7 +331,7 @@ private static Maybe> processFunctionLive( ActiveStreamingTool activeTool = invocationContext .activeStreamingTools() - .getOrDefault(tool.name(), new ActiveStreamingTool(subscription)); + .computeIfAbsent(tool.name(), unused -> new ActiveStreamingTool(subscription)); activeTool.task(subscription); invocationContext.activeStreamingTools().put(tool.name(), activeTool); @@ -607,12 +606,7 @@ private static Event buildResponseEvent( .invocationId(invocationContext.invocationId()) .author(invocationContext.agent().name()) .branch(invocationContext.branch()) - .content( - Optional.of( - Content.builder() - .role("user") - .parts(Collections.singletonList(partFunctionResponse)) - .build())) + .content(Content.builder().role("user").parts(partFunctionResponse).build()) .actions(toolContext.eventActions()) .build(); Tracing.traceToolResponse(invocationContext, event.id(), event); diff --git a/tutorials/live-audio-single-agent/src/main/java/com/google/adk/tutorials/LiveAudioSingleAgent.java b/tutorials/live-audio-single-agent/src/main/java/com/google/adk/tutorials/LiveAudioSingleAgent.java index 85e8914a7..a1342a936 100644 --- a/tutorials/live-audio-single-agent/src/main/java/com/google/adk/tutorials/LiveAudioSingleAgent.java +++ b/tutorials/live-audio-single-agent/src/main/java/com/google/adk/tutorials/LiveAudioSingleAgent.java @@ -30,13 +30,12 @@ public class LiveAudioSingleAgent { .model("gemini-2.0-flash-live-001") .description("A helpful weather assistant that provides weather information.") .instruction( - "You are a friendly weather assistant. When users ask about weather, " - + "you MUST call the getWeather tool with the location name. " - + "Extract the location from the user's question. " - + "ALWAYS use the getWeather tool to get accurate information - never make up weather data. " - + "After getting the tool result, provide a friendly and descriptive response. " - + "For general conversation or greetings, respond naturally and helpfully. " - + "Do NOT use code execution for anything.") + "You are a friendly weather assistant. When users ask about weather, you MUST call" + + " the getWeather tool with the location name. Extract the location from the" + + " user's question. ALWAYS use the getWeather tool to get accurate information -" + + " never make up weather data. After getting the tool result, provide a friendly" + + " and descriptive response. For general conversation or greetings, respond" + + " naturally and helpfully. Do NOT use code execution for anything.") .tools(FunctionTool.create(LiveAudioSingleAgent.class, "getWeather")) .build(); @@ -89,16 +88,17 @@ public static Map getWeather( String normalizedLocation = location.toLowerCase().trim(); - return weatherData.getOrDefault( + return weatherData.computeIfAbsent( normalizedLocation, - Map.of( - "status", - "error", - "report", - String.format( - "Weather information for '%s' is not available. Try New York, London, Tokyo, or" - + " Sydney.", - location))); + unused -> + Map.of( + "status", + "error", + "report", + String.format( + "Weather information for '%s' is not available. Try New York, London, Tokyo," + + " or Sydney.", + location))); } public static void main(String[] args) { From 8f7d7eac95cc606b5c5716612d0b08c41f951167 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 28 Jan 2026 13:18:35 -0800 Subject: [PATCH 04/63] feat: Add event compaction config to InvocationContext This CL integrates event compaction configuration into the InvocationContext, allowing agent to access the configs during execution. PiperOrigin-RevId: 862366865 --- .../google/adk/agents/InvocationContext.java | 24 +++++++++++++++++++ .../java/com/google/adk/runner/Runner.java | 1 + 2 files changed, 25 insertions(+) diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index ace00db4c..ed9b21062 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -25,6 +25,7 @@ import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; +import com.google.adk.summarizer.EventsCompactionConfig; import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.InlineMe; @@ -53,6 +54,7 @@ public class InvocationContext { private final Map agentStates; private final Map endOfAgents; private final ResumabilityConfig resumabilityConfig; + @Nullable private final EventsCompactionConfig eventsCompactionConfig; private final InvocationCostManager invocationCostManager; private Optional branch; @@ -76,6 +78,7 @@ protected InvocationContext(Builder builder) { this.agentStates = builder.agentStates; this.endOfAgents = builder.endOfAgents; this.resumabilityConfig = builder.resumabilityConfig; + this.eventsCompactionConfig = builder.eventsCompactionConfig; this.invocationCostManager = builder.invocationCostManager; } @@ -356,6 +359,11 @@ public boolean isResumable() { return resumabilityConfig.isResumable(); } + /** Returns the events compaction configuration for the current agent run. */ + public Optional eventsCompactionConfig() { + return Optional.ofNullable(eventsCompactionConfig); + } + /** Returns whether to pause the invocation right after this [event]. */ public boolean shouldPauseInvocation(Event event) { if (!isResumable()) { @@ -427,6 +435,7 @@ private Builder(InvocationContext context) { this.agentStates = new ConcurrentHashMap<>(context.agentStates); this.endOfAgents = new ConcurrentHashMap<>(context.endOfAgents); this.resumabilityConfig = context.resumabilityConfig; + this.eventsCompactionConfig = context.eventsCompactionConfig; this.invocationCostManager = context.invocationCostManager; } @@ -446,6 +455,7 @@ private Builder(InvocationContext context) { private Map agentStates = new ConcurrentHashMap<>(); private Map endOfAgents = new ConcurrentHashMap<>(); private ResumabilityConfig resumabilityConfig = new ResumabilityConfig(); + @Nullable private EventsCompactionConfig eventsCompactionConfig; private InvocationCostManager invocationCostManager = new InvocationCostManager(); /** @@ -670,6 +680,18 @@ public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) { return this; } + /** + * Sets the events compaction configuration for the current agent run. + * + * @param eventsCompactionConfig the events compaction configuration. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder eventsCompactionConfig(@Nullable EventsCompactionConfig eventsCompactionConfig) { + this.eventsCompactionConfig = eventsCompactionConfig; + return this; + } + /** * Builds the {@link InvocationContext} instance. * @@ -705,6 +727,7 @@ public boolean equals(Object o) { && Objects.equals(agentStates, that.agentStates) && Objects.equals(endOfAgents, that.endOfAgents) && Objects.equals(resumabilityConfig, that.resumabilityConfig) + && Objects.equals(eventsCompactionConfig, that.eventsCompactionConfig) && Objects.equals(invocationCostManager, that.invocationCostManager); } @@ -727,6 +750,7 @@ public int hashCode() { agentStates, endOfAgents, resumabilityConfig, + eventsCompactionConfig, invocationCostManager); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 574c3dcf0..66bb58606 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -634,6 +634,7 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { .agent(rootAgent) .session(session) .resumabilityConfig(this.resumabilityConfig) + .eventsCompactionConfig(this.eventsCompactionConfig) .agent(this.findAgentToRun(session, rootAgent)); } From e0fd53c462169637b616fe20103d0a449e47f2df Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 29 Jan 2026 06:43:00 -0800 Subject: [PATCH 05/63] ADK changes PiperOrigin-RevId: 862713740 --- .../com/google/adk/a2a/AgentExecutor.java | 185 ++++++++++++++++++ .../adk/a2a/converters/EventConverter.java | 13 ++ .../adk/a2a/converters/PartConverter.java | 19 +- 3 files changed, 214 insertions(+), 3 deletions(-) create mode 100644 a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java diff --git a/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java new file mode 100644 index 000000000..6df01694a --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java @@ -0,0 +1,185 @@ +package com.google.adk.a2a; + +import com.google.adk.a2a.converters.EventConverter; +import com.google.adk.a2a.converters.PartConverter; +import com.google.adk.agents.RunConfig; +import com.google.adk.events.Event; +import com.google.adk.runner.Runner; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.genai.types.Content; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventQueue; +import io.a2a.server.tasks.TaskUpdater; +import io.a2a.spec.InvalidAgentResponseError; +import io.a2a.spec.Message; +import io.a2a.spec.Part; +import io.a2a.spec.TextPart; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.disposables.CompositeDisposable; +import io.reactivex.rxjava3.disposables.Disposable; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Implementation of the A2A AgentExecutor interface that uses ADK to execute agent tasks. + * + *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not + * use in production code. + */ +public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor { + + private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class); + private static final String USER_ID_PREFIX = "A2A_USER_"; + private static final RunConfig DEFAULT_RUN_CONFIG = + RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build(); + + private final Runner runner; + private final Map activeTasks = new ConcurrentHashMap<>(); + + private AgentExecutor(Runner runner) { + this.runner = runner; + } + + /** Builder for {@link AgentExecutor}. */ + public static class Builder { + private Runner runner; + + @CanIgnoreReturnValue + public Builder runner(Runner runner) { + this.runner = runner; + return this; + } + + @CanIgnoreReturnValue + public AgentExecutor build() { + if (runner == null) { + throw new IllegalStateException("Runner must be provided."); + } + return new AgentExecutor(runner); + } + } + + @Override + public void cancel(RequestContext ctx, EventQueue eventQueue) { + TaskUpdater updater = new TaskUpdater(ctx, eventQueue); + updater.cancel(); + cleanupTask(ctx.getTaskId()); + } + + @Override + public void execute(RequestContext ctx, EventQueue eventQueue) { + TaskUpdater updater = new TaskUpdater(ctx, eventQueue); + Message message = ctx.getMessage(); + if (message == null) { + throw new IllegalArgumentException("Message cannot be null"); + } + + // Submits a new task if there is no active task. + if (ctx.getTask() == null) { + updater.submit(); + } + + // Group all reactive work for this task into one container + CompositeDisposable taskDisposables = new CompositeDisposable(); + // Check if the task with the task id is already running, put if absent. + if (activeTasks.putIfAbsent(ctx.getTaskId(), taskDisposables) != null) { + throw new IllegalStateException(String.format("Task %s already running", ctx.getTaskId())); + } + + EventProcessor p = new EventProcessor(); + Content content = PartConverter.messageToContent(message); + + taskDisposables.add( + prepareSession(ctx, runner.sessionService()) + .flatMapPublisher( + session -> { + updater.startWork(); + return runner.runAsync(getUserId(ctx), session.id(), content, DEFAULT_RUN_CONFIG); + }) + .subscribe( + event -> { + p.process(event, updater); + }, + error -> { + logger.error("Runner failed with {}", error); + updater.fail(failedMessage(ctx, error)); + cleanupTask(ctx.getTaskId()); + }, + () -> { + updater.complete(); + cleanupTask(ctx.getTaskId()); + })); + } + + private void cleanupTask(String taskId) { + Disposable d = activeTasks.remove(taskId); + if (d != null) { + d.dispose(); // Stops all streams in the CompositeDisposable + } + } + + private String getUserId(RequestContext ctx) { + return USER_ID_PREFIX + ctx.getContextId(); + } + + private Maybe prepareSession(RequestContext ctx, BaseSessionService service) { + return service + .getSession(runner.appName(), getUserId(ctx), ctx.getContextId(), Optional.empty()) + .switchIfEmpty( + Maybe.defer( + () -> { + return service.createSession(runner.appName(), getUserId(ctx)).toMaybe(); + })); + } + + private static Message failedMessage(RequestContext context, Throwable e) { + return new Message.Builder() + .messageId(UUID.randomUUID().toString()) + .contextId(context.getContextId()) + .taskId(context.getTaskId()) + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart(e.getMessage()))) + .build(); + } + + // Processor that will process all events related to the one runner invocation. + private static class EventProcessor { + + // All artifacts related to the invocation should have the same artifact id. + private EventProcessor() { + artifactId = UUID.randomUUID().toString(); + } + + private final String artifactId; + + private void process(Event event, TaskUpdater updater) { + if (event.errorCode().isPresent()) { + throw new InvalidAgentResponseError( + null, // Uses default code -32006 + "Agent returned an error: " + event.errorCode().get(), + null); + } + + ImmutableList> parts = EventConverter.contentToParts(event.content()); + + // Mark all parts as partial if the event is partial. + if (event.partial().orElse(false)) { + parts.forEach( + part -> { + Map metadata = part.getMetadata(); + metadata.put("adk_partial", true); + }); + } + + updater.addArtifact(parts, artifactId, null, ImmutableMap.of()); + } + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java index cd8bcefb0..f5b1178c0 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java @@ -1,7 +1,10 @@ package com.google.adk.a2a.converters; +import static com.google.common.collect.ImmutableList.toImmutableList; + import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; +import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; import io.a2a.spec.Message; @@ -37,6 +40,16 @@ public enum AggregationMode { EXTERNAL_HANDOFF } + public static ImmutableList> contentToParts(Optional content) { + if (content.isPresent() && content.get().parts().isPresent()) { + return content.get().parts().get().stream() + .map(PartConverter::fromGenaiPart) + .flatMap(Optional::stream) + .collect(toImmutableList()); + } + return ImmutableList.of(); + } + public static Optional convertEventsToA2AMessage(InvocationContext context) { return convertEventsToA2AMessage(context, AggregationMode.AS_IS); } diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 0b5ea5503..c6ef06400 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Blob; +import com.google.genai.types.Content; import com.google.genai.types.FileData; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; @@ -16,6 +17,7 @@ import io.a2a.spec.FilePart; import io.a2a.spec.FileWithBytes; import io.a2a.spec.FileWithUri; +import io.a2a.spec.Message; import io.a2a.spec.TextPart; import java.util.Base64; import java.util.HashMap; @@ -181,6 +183,17 @@ private static Optional convertDataPartToGenAiPart( } } + /** + * Converts an A2A Message to a Google GenAI Content object. + * + * @param message The A2A Message to convert. + * @return The converted Google GenAI Content object. + */ + public static Content messageToContent(Message message) { + ImmutableList parts = toGenaiParts(message.getParts()); + return Content.builder().role("user").parts(parts).build(); + } + /** * Creates an A2A DataPart from a Google GenAI FunctionResponse. * @@ -227,7 +240,7 @@ public static Optional> fromGenaiPart(Part part) { } if (part.text().isPresent()) { - return Optional.of(new TextPart(part.text().get())); + return Optional.of(new TextPart(part.text().get(), new HashMap<>())); } if (part.fileData().isPresent()) { @@ -235,7 +248,7 @@ public static Optional> fromGenaiPart(Part part) { String uri = fileData.fileUri().orElse(null); String mime = fileData.mimeType().orElse(null); String name = fileData.displayName().orElse(null); - return Optional.of(new FilePart(new FileWithUri(mime, name, uri))); + return Optional.of(new FilePart(new FileWithUri(mime, name, uri), new HashMap<>())); } if (part.inlineData().isPresent()) { @@ -244,7 +257,7 @@ public static Optional> fromGenaiPart(Part part) { String encoded = bytes != null ? Base64.getEncoder().encodeToString(bytes) : null; String mime = blob.mimeType().orElse(null); String name = blob.displayName().orElse(null); - return Optional.of(new FilePart(new FileWithBytes(mime, name, encoded))); + return Optional.of(new FilePart(new FileWithBytes(mime, name, encoded), new HashMap<>())); } if (part.functionCall().isPresent() || part.functionResponse().isPresent()) { From 3e74c9a960cba6582e914d36925516039d57913c Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 29 Jan 2026 10:12:33 -0800 Subject: [PATCH 06/63] feat: Refine bug and feature request issue templates PiperOrigin-RevId: 862788975 --- .github/ISSUE_TEMPLATE/bug_report.md | 72 ++++++++++++++++++----- .github/ISSUE_TEMPLATE/feature_request.md | 39 +++++++++--- 2 files changed, 89 insertions(+), 22 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 1abd4a94a..f2f55f079 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -6,30 +6,72 @@ labels: '' assignees: '' --- - -** Please make sure you read the contribution guide and file the issues in the rigth place. ** +**Please make sure you read the contribution guide and file the issues in the + right place. ** [Contribution guide.](https://google.github.io/adk-docs/contributing-guide/) -**Describe the bug** +## 🔴 Required Information +*Please ensure all items in this section are completed to allow for efficient +triaging. Requests without complete information may be rejected / deprioritized. +If an item is not applicable to you - please mark it as N/A + +**Describe the Bug:** A clear and concise description of what the bug is. -**To Reproduce** -Steps to reproduce the behavior: +**Steps to Reproduce:** +Please provide a numbered list of steps to reproduce the behavior: 1. Install '...' 2. Run '....' 3. Open '....' -4. See error +4. Provide error or stacktrace -**Expected behavior** +**Expected Behavior:** A clear and concise description of what you expected to happen. -**Screenshots** -If applicable, add screenshots to help explain your problem. +**Observed Behavior:** +What actually happened? Include error messages or crash stack traces here. + +**Environment Details:** + + - ADK Library Version (see maven dependency): + - OS: [e.g., macOS, Linux, Windows] + - TS Version (tsc --version): + +**Model Information:** + + - Which model is being used: (e.g., gemini-2.5-pro) + +--- + +## 🟡 Optional Information +*Providing this information greatly speeds up the resolution process.* + +**Regression:** +Did this work in a previous version of ADK? (Yes/No) If so, which one? + +**Logs:** +Please attach relevant logs. Wrap them in code blocks (```) or attach a +text file. +```text +// Paste logs here +``` + +**Screenshots / Video:** +If applicable, add screenshots or screen recordings to help explain +your problem. + +**Additional Context:** +Add any other context about the problem here. + +**Minimal Reproduction Code:** +Please provide a code snippet or a link to a Gist/repo that isolates the issue. +``` +// Code snippet here +``` -**Desktop (please complete the following information):** - - OS: [e.g. iOS] - - Java version: - - ADK version(see maven dependency): +**How often has this issue occurred?:** -**Additional context** -Add any other context about the problem here. \ No newline at end of file + - Always (100%) + - Often (50%+) + - Intermittently (<50%) + - Once / Rare \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index facf41f1e..6af25148f 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -10,14 +10,39 @@ assignees: '' ** Please make sure you read the contribution guide and file the issues in the right place. ** [Contribution guide.](https://google.github.io/adk-docs/contributing-guide/) -**Is your feature request related to a problem? Please describe.** -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] +## 🔴 Required Information +*Please ensure all items in this section are completed to allow for efficient +triaging. Requests without complete information may be rejected / deprioritized. +If an item is not applicable to you - please mark it as N/A* -**Describe the solution you'd like** -A clear and concise description of what you want to happen. +### Is your feature request related to a specific problem? +Please describe the problem you are trying to solve. (Ex: "I'm always frustrated +when I have to manually handle X...") -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. +### Describe the Solution You'd Like +A clear and concise description of the feature or API change you want. +Be specific about input/outputs if this involves an API change. -**Additional context** +### Impact on your work +How does this feature impact your work and what are you trying to achieve? +If this is critical for you, tell us if there is a timeline by when you need +this feature. + +### Willingness to contribute +Are you interested in implementing this feature yourself or submitting a PR? +(Yes/No) + +--- + +## 🟡 Recommended Information + +### Describe Alternatives You've Considered +A clear and concise description of any alternative solutions or workarounds +you've considered and why they didn't work for you. + +### Proposed API / Implementation +If you have ideas on how this should look in code, please share a +pseudo-code example. + +### Additional Context Add any other context or screenshots about the feature request here. \ No newline at end of file From 7019d39e490cef1b4b443d1755547a3a701bc964 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 29 Jan 2026 11:14:23 -0800 Subject: [PATCH 07/63] fix: recursively extract input/output schema for AgentTool PiperOrigin-RevId: 862818185 --- .../java/com/google/adk/tools/AgentTool.java | 46 +++-- .../com/google/adk/tools/AgentToolTest.java | 173 +++++++++++++++++- 2 files changed, 206 insertions(+), 13 deletions(-) diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index 1a8dbc527..2a50605a1 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -37,6 +37,7 @@ import com.google.genai.types.Part; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Single; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -83,15 +84,42 @@ BaseAgent getAgent() { return agent; } + private Optional getInputSchema(BaseAgent agent) { + BaseAgent currentAgent = agent; + while (true) { + if (currentAgent instanceof LlmAgent llmAgent) { + return llmAgent.inputSchema(); + } + List subAgents = currentAgent.subAgents(); + if (subAgents == null || subAgents.isEmpty()) { + return Optional.empty(); + } + // For composite agents, check the first sub-agent. + currentAgent = subAgents.get(0); + } + } + + private Optional getOutputSchema(BaseAgent agent) { + BaseAgent currentAgent = agent; + while (true) { + if (currentAgent instanceof LlmAgent llmAgent) { + return llmAgent.outputSchema(); + } + List subAgents = currentAgent.subAgents(); + if (subAgents == null || subAgents.isEmpty()) { + return Optional.empty(); + } + // For composite agents, check the last sub-agent. + currentAgent = subAgents.get(subAgents.size() - 1); + } + } + @Override public Optional declaration() { FunctionDeclaration.Builder builder = FunctionDeclaration.builder().description(this.description()).name(this.name()); - Optional agentInputSchema = Optional.empty(); - if (agent instanceof LlmAgent llmAgent) { - agentInputSchema = llmAgent.inputSchema(); - } + Optional agentInputSchema = getInputSchema(agent); if (agentInputSchema.isPresent()) { builder.parameters(agentInputSchema.get()); @@ -113,10 +141,7 @@ public Single> runAsync(Map args, ToolContex toolContext.setActions(toolContext.actions().toBuilder().skipSummarization(true).build()); } - Optional agentInputSchema = Optional.empty(); - if (agent instanceof LlmAgent llmAgent) { - agentInputSchema = llmAgent.inputSchema(); - } + Optional agentInputSchema = getInputSchema(agent); final Content content; if (agentInputSchema.isPresent()) { @@ -163,10 +188,7 @@ public Single> runAsync(Map args, ToolContex } String output = outputText.get(); - Optional agentOutputSchema = Optional.empty(); - if (agent instanceof LlmAgent llmAgent) { - agentOutputSchema = llmAgent.outputSchema(); - } + Optional agentOutputSchema = getOutputSchema(agent); if (agentOutputSchema.isPresent()) { return SchemaUtils.validateOutputSchema(output, agentOutputSchema.get()); diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index d43d9d03a..c961e654a 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -21,10 +21,12 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.SequentialAgent; import com.google.adk.models.LlmResponse; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; @@ -451,7 +453,176 @@ public void call_withStateDeltaInResponse_propagatesStateDelta() throws Exceptio assertThat(toolContext.state()).containsEntry("test_key", "test_value"); } - private static ToolContext createToolContext(LlmAgent agent) { + @Test + public void + declaration_sequentialAgentWithFirstSubAgentInputSchema_returnsDeclarationWithSchema() { + Schema inputSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "query", + Schema.builder().type("STRING").build(), + "language", + Schema.builder().type("STRING").build())) + .required(ImmutableList.of("query", "language")) + .build(); + LlmAgent firstAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("first_agent") + .inputSchema(inputSchema) + .build(); + LlmAgent secondAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("second_agent") + .build(); + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("sequence") + .description("Process the query through multiple steps") + .subAgents(ImmutableList.of(firstAgent, secondAgent)) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("sequence"); + assertThat(declaration.description().get()) + .isEqualTo("Process the query through multiple steps"); + assertThat(declaration.parameters().get()).isEqualTo(inputSchema); + } + + @Test + public void declaration_sequentialAgentWithoutInputSchema_fallsBackToRequest() { + LlmAgent firstAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("first_agent") + .build(); + LlmAgent secondAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("second_agent") + .build(); + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("sequence") + .description("Process the query through multiple steps") + .subAgents(ImmutableList.of(firstAgent, secondAgent)) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("sequence"); + assertThat(declaration.description().get()) + .isEqualTo("Process the query through multiple steps"); + assertThat(declaration.parameters().get()) + .isEqualTo( + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("request", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("request")) + .build()); + } + + @Test + public void call_sequentialAgentWithLastSubAgentOutputSchema_successful() throws Exception { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "is_valid", + Schema.builder().type("BOOLEAN").build(), + "message", + Schema.builder().type("STRING").build())) + .required(ImmutableList.of("is_valid", "message")) + .build(); + LlmAgent firstAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("first_agent") + .build(); + LlmAgent secondAgent = + createTestAgentBuilder( + createTestLlm( + LlmResponse.builder() + .content( + Content.fromParts( + Part.fromText( + "{\"is_valid\": true, " + "\"message\": \"success\"}"))) + .build())) + .name("second_agent") + .outputSchema(outputSchema) + .build(); + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("sequence") + .description("Process the query through multiple steps") + .subAgents(ImmutableList.of(firstAgent, secondAgent)) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + ToolContext toolContext = createToolContext(sequentialAgent); + + Map result = + agentTool.runAsync(ImmutableMap.of("request", "test"), toolContext).blockingGet(); + + assertThat(result).containsExactly("is_valid", true, "message", "success"); + } + + @Test + public void declaration_nestedSequentialAgentInputSchema_returnsDeclarationWithSchema() { + Schema inputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("deep_query", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("deep_query")) + .build(); + LlmAgent innerAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("inner_agent") + .inputSchema(inputSchema) + .build(); + SequentialAgent innerSequence = + SequentialAgent.builder() + .name("inner_sequence") + .subAgents(ImmutableList.of(innerAgent)) + .build(); + SequentialAgent outerSequence = + SequentialAgent.builder() + .name("outer_sequence") + .description("Nested sequence") + .subAgents(ImmutableList.of(innerSequence)) + .build(); + AgentTool agentTool = AgentTool.create(outerSequence); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("outer_sequence"); + assertThat(declaration.parameters().get()).isEqualTo(inputSchema); + } + + @Test + public void declaration_emptySequentialAgent_fallsBackToRequest() { + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("empty_sequence") + .description("An empty sequence") + .subAgents(ImmutableList.of()) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("empty_sequence"); + assertThat(declaration.parameters().get()) + .isEqualTo( + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("request", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("request")) + .build()); + } + + private static ToolContext createToolContext(BaseAgent agent) { return ToolContext.builder( InvocationContext.builder() .invocationId(InvocationContext.newInvocationContextId()) From 577072c878edfe53480bd3cfa44354ea228c8e0b Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 29 Jan 2026 23:16:58 -0800 Subject: [PATCH 08/63] refactor: LlmAgent: Unwrap List from Optional, improve test coverage PiperOrigin-RevId: 863081612 --- .../java/com/google/adk/agents/LlmAgent.java | 60 +++++++++---------- .../adk/agents/ConfigAgentUtilsTest.java | 31 +++++----- .../com/google/adk/agents/LlmAgentTest.java | 14 +++++ 3 files changed, 60 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 444985971..1f16d7c00 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -17,6 +17,7 @@ package com.google.adk.agents; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNullElse; import static java.util.stream.Collectors.joining; import com.fasterxml.jackson.core.JsonProcessingException; @@ -103,12 +104,12 @@ public enum IncludeContents { private final Optional maxSteps; private final boolean disallowTransferToParent; private final boolean disallowTransferToPeers; - private final Optional> beforeModelCallback; - private final Optional> afterModelCallback; - private final Optional> onModelErrorCallback; - private final Optional> beforeToolCallback; - private final Optional> afterToolCallback; - private final Optional> onToolErrorCallback; + private final ImmutableList beforeModelCallback; + private final ImmutableList afterModelCallback; + private final ImmutableList onModelErrorCallback; + private final ImmutableList beforeToolCallback; + private final ImmutableList afterToolCallback; + private final ImmutableList onToolErrorCallback; private final Optional inputSchema; private final Optional outputSchema; private final Optional executor; @@ -126,29 +127,28 @@ protected LlmAgent(Builder builder) { builder.beforeAgentCallback, builder.afterAgentCallback); this.model = Optional.ofNullable(builder.model); - this.instruction = - builder.instruction == null ? new Instruction.Static("") : builder.instruction; + this.instruction = requireNonNullElse(builder.instruction, new Instruction.Static("")); this.globalInstruction = - builder.globalInstruction == null ? new Instruction.Static("") : builder.globalInstruction; + requireNonNullElse(builder.globalInstruction, new Instruction.Static("")); this.generateContentConfig = Optional.ofNullable(builder.generateContentConfig); this.exampleProvider = Optional.ofNullable(builder.exampleProvider); - this.includeContents = - builder.includeContents != null ? builder.includeContents : IncludeContents.DEFAULT; + this.includeContents = requireNonNullElse(builder.includeContents, IncludeContents.DEFAULT); this.planning = builder.planning != null && builder.planning; this.maxSteps = Optional.ofNullable(builder.maxSteps); this.disallowTransferToParent = builder.disallowTransferToParent; this.disallowTransferToPeers = builder.disallowTransferToPeers; - this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback); - this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback); - this.onModelErrorCallback = Optional.ofNullable(builder.onModelErrorCallback); - this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback); - this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback); - this.onToolErrorCallback = Optional.ofNullable(builder.onToolErrorCallback); + this.beforeModelCallback = requireNonNullElse(builder.beforeModelCallback, ImmutableList.of()); + this.afterModelCallback = requireNonNullElse(builder.afterModelCallback, ImmutableList.of()); + this.onModelErrorCallback = + requireNonNullElse(builder.onModelErrorCallback, ImmutableList.of()); + this.beforeToolCallback = requireNonNullElse(builder.beforeToolCallback, ImmutableList.of()); + this.afterToolCallback = requireNonNullElse(builder.afterToolCallback, ImmutableList.of()); + this.onToolErrorCallback = requireNonNullElse(builder.onToolErrorCallback, ImmutableList.of()); this.inputSchema = Optional.ofNullable(builder.inputSchema); this.outputSchema = Optional.ofNullable(builder.outputSchema); this.executor = Optional.ofNullable(builder.executor); this.outputKey = Optional.ofNullable(builder.outputKey); - this.toolsUnion = builder.toolsUnion != null ? builder.toolsUnion : ImmutableList.of(); + this.toolsUnion = requireNonNullElse(builder.toolsUnion, ImmutableList.of()); this.toolsets = extractToolsets(this.toolsUnion); this.codeExecutor = Optional.ofNullable(builder.codeExecutor); @@ -841,27 +841,27 @@ public boolean disallowTransferToPeers() { return disallowTransferToPeers; } - public Optional> beforeModelCallback() { + public List beforeModelCallback() { return beforeModelCallback; } - public Optional> afterModelCallback() { + public List afterModelCallback() { return afterModelCallback; } - public Optional> beforeToolCallback() { + public List beforeToolCallback() { return beforeToolCallback; } - public Optional> afterToolCallback() { + public List afterToolCallback() { return afterToolCallback; } - public Optional> onModelErrorCallback() { + public List onModelErrorCallback() { return onModelErrorCallback; } - public Optional> onToolErrorCallback() { + public List onToolErrorCallback() { return onToolErrorCallback; } @@ -871,7 +871,7 @@ public Optional> onToolErrorCallback() { *

This method is only for use by Agent Development Kit. */ public List canonicalBeforeModelCallbacks() { - return beforeModelCallback.orElse(ImmutableList.of()); + return beforeModelCallback; } /** @@ -880,7 +880,7 @@ public List canonicalBeforeModelCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalAfterModelCallbacks() { - return afterModelCallback.orElse(ImmutableList.of()); + return afterModelCallback; } /** @@ -889,7 +889,7 @@ public List canonicalAfterModelCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalOnModelErrorCallbacks() { - return onModelErrorCallback.orElse(ImmutableList.of()); + return onModelErrorCallback; } /** @@ -898,7 +898,7 @@ public List canonicalOnModelErrorCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalBeforeToolCallbacks() { - return beforeToolCallback.orElse(ImmutableList.of()); + return beforeToolCallback; } /** @@ -907,7 +907,7 @@ public List canonicalBeforeToolCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalAfterToolCallbacks() { - return afterToolCallback.orElse(ImmutableList.of()); + return afterToolCallback; } /** @@ -916,7 +916,7 @@ public List canonicalAfterToolCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalOnToolErrorCallbacks() { - return onToolErrorCallback.orElse(ImmutableList.of()); + return onToolErrorCallback; } public Optional inputSchema() { diff --git a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java index 4825efca1..11e07a094 100644 --- a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java +++ b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java @@ -1161,20 +1161,25 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks() String pfx = "test.callbacks."; registry.register( - pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty()); + pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty()); registry.register( - pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty()); - registry.register(pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (ctx) -> Maybe.empty()); + pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty()); registry.register( - pfx + "before_model_1", (Callbacks.BeforeModelCallback) (ctx, req) -> Maybe.empty()); + pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (unusedCtx) -> Maybe.empty()); registry.register( - pfx + "after_model_1", (Callbacks.AfterModelCallback) (ctx, resp) -> Maybe.empty()); + pfx + "before_model_1", + (Callbacks.BeforeModelCallback) (unusedCtx, unusedReq) -> Maybe.empty()); + registry.register( + pfx + "after_model_1", + (Callbacks.AfterModelCallback) (unusedCtx, unusedResp) -> Maybe.empty()); registry.register( pfx + "before_tool_1", - (Callbacks.BeforeToolCallback) (inv, tool, args, toolCtx) -> Maybe.empty()); + (Callbacks.BeforeToolCallback) + (unusedInv, unusedTool, unusedArgs, unusedToolCtx) -> Maybe.empty()); registry.register( pfx + "after_tool_1", - (Callbacks.AfterToolCallback) (inv, tool, args, toolCtx, resp) -> Maybe.empty()); + (Callbacks.AfterToolCallback) + (unusedInv, unusedTool, unusedArgs, unusedToolCtx, unusedResp) -> Maybe.empty()); File configFile = tempFolder.newFile("with_callbacks.yaml"); Files.writeString( @@ -1209,15 +1214,11 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks() assertThat(agent.afterAgentCallback()).isPresent(); assertThat(agent.afterAgentCallback().get()).hasSize(1); - assertThat(llm.beforeModelCallback()).isPresent(); - assertThat(llm.beforeModelCallback().get()).hasSize(1); - assertThat(llm.afterModelCallback()).isPresent(); - assertThat(llm.afterModelCallback().get()).hasSize(1); + assertThat(llm.beforeModelCallback()).hasSize(1); + assertThat(llm.afterModelCallback()).hasSize(1); - assertThat(llm.beforeToolCallback()).isPresent(); - assertThat(llm.beforeToolCallback().get()).hasSize(1); - assertThat(llm.afterToolCallback()).isPresent(); - assertThat(llm.afterToolCallback().get()).hasSize(1); + assertThat(llm.beforeToolCallback()).hasSize(1); + assertThat(llm.afterToolCallback()).hasSize(1); } @Test diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 519c90558..8a2ff6df8 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -341,6 +341,13 @@ public void canonicalCallbacks_returnsEmptyListWhenNull() { assertThat(agent.canonicalBeforeToolCallbacks()).isEmpty(); assertThat(agent.canonicalAfterToolCallbacks()).isEmpty(); assertThat(agent.canonicalOnToolErrorCallbacks()).isEmpty(); + + assertThat(agent.beforeModelCallback()).isEmpty(); + assertThat(agent.afterModelCallback()).isEmpty(); + assertThat(agent.onModelErrorCallback()).isEmpty(); + assertThat(agent.beforeToolCallback()).isEmpty(); + assertThat(agent.afterToolCallback()).isEmpty(); + assertThat(agent.onToolErrorCallback()).isEmpty(); } @Test @@ -371,5 +378,12 @@ public void canonicalCallbacks_returnsListWhenPresent() { assertThat(agent.canonicalBeforeToolCallbacks()).containsExactly(btc); assertThat(agent.canonicalAfterToolCallbacks()).containsExactly(atc); assertThat(agent.canonicalOnToolErrorCallbacks()).containsExactly(otec); + + assertThat(agent.beforeModelCallback()).containsExactly(bmc); + assertThat(agent.afterModelCallback()).containsExactly(amc); + assertThat(agent.onModelErrorCallback()).containsExactly(omec); + assertThat(agent.beforeToolCallback()).containsExactly(btc); + assertThat(agent.afterToolCallback()).containsExactly(atc); + assertThat(agent.onToolErrorCallback()).containsExactly(otec); } } From 588b00bbd327e257a78271bf2d929bc52875115f Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 29 Jan 2026 23:32:03 -0800 Subject: [PATCH 09/63] feat: Add tokenThreshold and eventRetentionSize to EventsCompactionConfig This interface change allows intra-invocation compaction with configurable trigger threshold and number of events to be preserved from being compacted. PiperOrigin-RevId: 863086157 --- .../adk/summarizer/EventsCompactionConfig.java | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java b/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java index 766041304..04dc11e10 100644 --- a/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java +++ b/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java @@ -27,11 +27,24 @@ * compacted range. This creates an overlap between consecutive compacted summaries, maintaining * context. * @param summarizer An event summarizer to use for compaction. + * @param tokenThreshold The number of tokens above which compaction will be triggered. If null, no + * token limit will be enforced. It will trigger compaction within the invocation. + * @param eventRetentionSize The maximum number of events to retain and preserve from compaction. If + * null, no event retention limit will be enforced. */ public record EventsCompactionConfig( - int compactionInterval, int overlapSize, @Nullable BaseEventSummarizer summarizer) { + int compactionInterval, + int overlapSize, + @Nullable BaseEventSummarizer summarizer, + @Nullable Integer tokenThreshold, + @Nullable Integer eventRetentionSize) { public EventsCompactionConfig(int compactionInterval, int overlapSize) { - this(compactionInterval, overlapSize, null); + this(compactionInterval, overlapSize, null, null, null); + } + + public EventsCompactionConfig( + int compactionInterval, int overlapSize, @Nullable BaseEventSummarizer summarizer) { + this(compactionInterval, overlapSize, summarizer, null, null); } } From 22f8ac8b9d62611eaf9f545e6e8ee376709faf8c Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 2 Feb 2026 00:59:53 -0800 Subject: [PATCH 10/63] ADK changes PiperOrigin-RevId: 864223068 --- contrib/samples/a2a_basic/A2AAgent.java | 6 ++---- contrib/samples/a2a_basic/A2AAgentRun.java | 12 +++++------- .../adk/models/springai/SpringAIIntegrationTest.java | 3 ++- .../integrations/AnthropicApiIntegrationTest.java | 6 +++--- .../integrations/GeminiApiIntegrationTest.java | 6 +++--- .../integrations/OpenAiApiIntegrationTest.java | 6 +++--- 6 files changed, 18 insertions(+), 21 deletions(-) diff --git a/contrib/samples/a2a_basic/A2AAgent.java b/contrib/samples/a2a_basic/A2AAgent.java index fa4932cf9..e4e79a4eb 100644 --- a/contrib/samples/a2a_basic/A2AAgent.java +++ b/contrib/samples/a2a_basic/A2AAgent.java @@ -1,8 +1,5 @@ package com.example.a2a_basic; -import java.util.ArrayList; -import java.util.Random; - import com.google.adk.a2a.RemoteA2AAgent; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.LlmAgent; @@ -10,7 +7,6 @@ import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - import io.a2a.client.Client; import io.a2a.client.config.ClientConfig; import io.a2a.client.http.A2ACardResolver; @@ -18,6 +14,8 @@ import io.a2a.client.transport.jsonrpc.JSONRPCTransport; import io.a2a.client.transport.jsonrpc.JSONRPCTransportConfig; import io.a2a.spec.AgentCard; +import java.util.ArrayList; +import java.util.Random; /** Provides local roll logic plus a remote A2A agent for the demo. */ public final class A2AAgent { diff --git a/contrib/samples/a2a_basic/A2AAgentRun.java b/contrib/samples/a2a_basic/A2AAgentRun.java index 12d515f62..b4a63b026 100644 --- a/contrib/samples/a2a_basic/A2AAgentRun.java +++ b/contrib/samples/a2a_basic/A2AAgentRun.java @@ -1,11 +1,5 @@ package com.example.a2a_basic; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.TimeUnit; - import com.google.adk.agents.BaseAgent; import com.google.adk.agents.RunConfig; import com.google.adk.artifacts.InMemoryArtifactService; @@ -15,8 +9,12 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; - import io.reactivex.rxjava3.core.Flowable; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; /** Main class to demonstrate running the A2A agent with sequential inputs. */ public final class A2AAgentRun { diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java index a0db38550..6843c8eaa 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java @@ -123,7 +123,8 @@ public ChatResponse call(Prompt prompt) { // Subsequent calls - provide final answer message = new AssistantMessage( - "The weather in Paris is beautiful and sunny with temperatures from 10°C in the morning up to 24°C in the afternoon."); + "The weather in Paris is beautiful and sunny with temperatures from 10°C in" + + " the morning up to 24°C in the afternoon."); } Generation generation = new Generation(message); diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java index 1e2ac82f6..f21b07ae9 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java @@ -145,9 +145,9 @@ void testAgentWithToolsAndRealApi() { .model(new SpringAI(anthropicModel, CLAUDE_MODEL)) .instruction( """ - You are a helpful assistant. - When asked about weather, you MUST use the getWeatherInfo function to get current conditions. - """) + You are a helpful assistant. + When asked about weather, you MUST use the getWeatherInfo function to get current conditions. + """) .tools(FunctionTool.create(WeatherTools.class, "getWeatherInfo")) .build(); diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/GeminiApiIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/GeminiApiIntegrationTest.java index 47f3d3b11..5414cdf99 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/GeminiApiIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/GeminiApiIntegrationTest.java @@ -157,9 +157,9 @@ void testAgentWithToolsAndRealApi() { .model(new SpringAI(geminiModel, GEMINI_MODEL)) .instruction( """ - You are a helpful assistant. - When asked about weather, you MUST use the getWeatherInfo function to get current conditions. - """) + You are a helpful assistant. + When asked about weather, you MUST use the getWeatherInfo function to get current conditions. + """) .tools(FunctionTool.create(WeatherTools.class, "getWeatherInfo")) .build(); diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/OpenAiApiIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/OpenAiApiIntegrationTest.java index 894ffaba6..eb956f828 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/OpenAiApiIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/OpenAiApiIntegrationTest.java @@ -133,9 +133,9 @@ void testAgentWithToolsAndRealApi() { .model(new SpringAI(openAiModel, GPT_MODEL)) .instruction( """ - You are a helpful assistant. - When asked about weather, use the getWeatherInfo function to get current conditions. - """) + You are a helpful assistant. + When asked about weather, use the getWeatherInfo function to get current conditions. + """) .tools(FunctionTool.create(WeatherTools.class, "getWeatherInfo")) .build(); From d11bedf42976242d1c3dd6b99ebae0babe59535c Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 2 Feb 2026 01:02:30 -0800 Subject: [PATCH 11/63] fix: Fixing a regression in InMemorySessionService PiperOrigin-RevId: 864224090 --- .../adk/sessions/InMemorySessionService.java | 2 +- .../com/google/adk/agents/LlmAgentTest.java | 55 +++++++++++++ .../sessions/InMemorySessionServiceTest.java | 80 ++++++++++++++----- .../com/google/adk/testing/TestUtils.java | 15 ++++ 4 files changed, 129 insertions(+), 23 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index b658f6767..060fcaf60 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -259,7 +259,7 @@ public Single appendEvent(Session session, Event event) { .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .put(userStateKey, value); } - } else if (!key.startsWith(State.TEMP_PREFIX)) { + } else { if (value == State.REMOVED) { session.state().remove(key); } else { diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 8a2ff6df8..ae50b5b8e 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -34,8 +34,11 @@ import com.google.adk.agents.Callbacks.OnToolErrorCallback; import com.google.adk.events.Event; import com.google.adk.models.LlmRegistry; +import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.models.Model; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; import com.google.adk.testing.TestUtils.EchoTool; import com.google.adk.tools.BaseTool; @@ -49,6 +52,7 @@ import io.reactivex.rxjava3.core.Single; import java.util.List; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -386,4 +390,55 @@ public void canonicalCallbacks_returnsListWhenPresent() { assertThat(agent.afterToolCallback()).containsExactly(atc); assertThat(agent.onToolErrorCallback()).containsExactly(otec); } + + @Test + public void run_sequentialAgents_shareTempStateViaSession() { + // 1. Setup Session Service and Session + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService + .createSession("app", "user", new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + // 2. Agent 1: runs and produces output "value1" to state "temp:key1" + Content model1Content = Content.fromParts(Part.fromText("value1")); + TestLlm testLlm1 = createTestLlm(createLlmResponse(model1Content)); + LlmAgent agent1 = + createTestAgentBuilder(testLlm1).name("agent1").outputKey("temp:key1").build(); + InvocationContext invocationContext1 = createInvocationContext(agent1, sessionService, session); + + List events1 = agent1.runAsync(invocationContext1).toList().blockingGet(); + assertThat(events1).hasSize(1); + Event event1 = events1.get(0); + assertThat(event1.actions()).isNotNull(); + assertThat(event1.actions().stateDelta()).containsEntry("temp:key1", "value1"); + + // 3. Simulate orchestrator: append event1 to session, updating its state + var unused = sessionService.appendEvent(session, event1).blockingGet(); + assertThat(session.state()).containsEntry("temp:key1", "value1"); + + // 4. Agent 2: uses Instruction.Provider to read "temp:key1" from session state + // and generates an instruction based on it. + TestLlm testLlm2 = + createTestLlm(createLlmResponse(Content.fromParts(Part.fromText("response2")))); + LlmAgent agent2 = + createTestAgentBuilder(testLlm2) + .name("agent2") + .instruction( + new Instruction.Provider( + ctx -> + Single.just( + "Instruction for Agent2 based on Agent1 output: " + + ctx.state().get("temp:key1")))) + .build(); + InvocationContext invocationContext2 = createInvocationContext(agent2, sessionService, session); + List events2 = agent2.runAsync(invocationContext2).toList().blockingGet(); + assertThat(events2).hasSize(1); + + // 5. Verify that agent2's LLM received an instruction containing agent1's output + assertThat(testLlm2.getRequests()).hasSize(1); + LlmRequest request2 = testLlm2.getRequests().get(0); + assertThat(request2.getFirstSystemInstruction().get()) + .contains("Instruction for Agent2 based on Agent1 output: value1"); + } } diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 97b182249..6223dd2f0 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -89,8 +89,8 @@ public void lifecycle_listSessions() { ConcurrentMap stateDelta = new ConcurrentHashMap<>(); stateDelta.put("sessionKey", "sessionValue"); - stateDelta.put("app:appKey", "appValue"); - stateDelta.put("user:userKey", "userValue"); + stateDelta.put("_app_appKey", "appValue"); + stateDelta.put("_user_userKey", "userValue"); stateDelta.put("temp:tempKey", "tempValue"); Event event = @@ -106,9 +106,9 @@ public void lifecycle_listSessions() { assertThat(listedSession.id()).isEqualTo(session.id()); assertThat(listedSession.events()).isEmpty(); assertThat(listedSession.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(listedSession.state()).containsEntry("app:appKey", "appValue"); - assertThat(listedSession.state()).containsEntry("user:userKey", "userValue"); - assertThat(listedSession.state()).doesNotContainKey("temp:tempKey"); + assertThat(listedSession.state()).containsEntry("_app_appKey", "appValue"); + assertThat(listedSession.state()).containsEntry("_user_userKey", "userValue"); + assertThat(listedSession.state()).containsEntry("temp:tempKey", "tempValue"); } @Test @@ -136,8 +136,8 @@ public void appendEvent_updatesSessionState() { ConcurrentMap stateDelta = new ConcurrentHashMap<>(); stateDelta.put("sessionKey", "sessionValue"); - stateDelta.put("app:appKey", "appValue"); - stateDelta.put("user:userKey", "userValue"); + stateDelta.put("_app_appKey", "appValue"); + stateDelta.put("_user_userKey", "userValue"); stateDelta.put("temp:tempKey", "tempValue"); Event event = @@ -148,9 +148,9 @@ public void appendEvent_updatesSessionState() { // After appendEvent, session state in memory should contain session-specific state from delta // and merged global state. assertThat(session.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(session.state()).containsEntry("app:appKey", "appValue"); - assertThat(session.state()).containsEntry("user:userKey", "userValue"); - assertThat(session.state()).doesNotContainKey("temp:tempKey"); + assertThat(session.state()).containsEntry("_app_appKey", "appValue"); + assertThat(session.state()).containsEntry("_user_userKey", "userValue"); + assertThat(session.state()).containsEntry("temp:tempKey", "tempValue"); // getSession should return session with merged state. Session retrievedSession = @@ -158,9 +158,9 @@ public void appendEvent_updatesSessionState() { .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) .blockingGet(); assertThat(retrievedSession.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(retrievedSession.state()).containsEntry("app:appKey", "appValue"); - assertThat(retrievedSession.state()).containsEntry("user:userKey", "userValue"); - assertThat(retrievedSession.state()).doesNotContainKey("temp:tempKey"); + assertThat(retrievedSession.state()).containsEntry("_app_appKey", "appValue"); + assertThat(retrievedSession.state()).containsEntry("_user_userKey", "userValue"); + assertThat(retrievedSession.state()).containsEntry("temp:tempKey", "tempValue"); } @Test @@ -173,8 +173,8 @@ public void appendEvent_removesState() { ConcurrentMap stateDeltaAdd = new ConcurrentHashMap<>(); stateDeltaAdd.put("sessionKey", "sessionValue"); - stateDeltaAdd.put("app:appKey", "appValue"); - stateDeltaAdd.put("user:userKey", "userValue"); + stateDeltaAdd.put("_app_appKey", "appValue"); + stateDeltaAdd.put("_user_userKey", "userValue"); stateDeltaAdd.put("temp:tempKey", "tempValue"); Event eventAdd = @@ -188,15 +188,15 @@ public void appendEvent_removesState() { .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) .blockingGet(); assertThat(retrievedSessionAdd.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(retrievedSessionAdd.state()).containsEntry("app:appKey", "appValue"); - assertThat(retrievedSessionAdd.state()).containsEntry("user:userKey", "userValue"); - assertThat(retrievedSessionAdd.state()).doesNotContainKey("temp:tempKey"); + assertThat(retrievedSessionAdd.state()).containsEntry("_app_appKey", "appValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("_user_userKey", "userValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("temp:tempKey", "tempValue"); // Prepare and append event to remove state ConcurrentMap stateDeltaRemove = new ConcurrentHashMap<>(); stateDeltaRemove.put("sessionKey", State.REMOVED); - stateDeltaRemove.put("app:appKey", State.REMOVED); - stateDeltaRemove.put("user:userKey", State.REMOVED); + stateDeltaRemove.put("_app_appKey", State.REMOVED); + stateDeltaRemove.put("_user_userKey", State.REMOVED); stateDeltaRemove.put("temp:tempKey", State.REMOVED); Event eventRemove = @@ -212,8 +212,44 @@ public void appendEvent_removesState() { .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) .blockingGet(); assertThat(retrievedSessionRemove.state()).doesNotContainKey("sessionKey"); - assertThat(retrievedSessionRemove.state()).doesNotContainKey("app:appKey"); - assertThat(retrievedSessionRemove.state()).doesNotContainKey("user:userKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("_app_appKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("_user_userKey"); assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey"); } + + @Test + public void sequentialAgents_shareTempState() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService + .createSession("app", "user", new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + // Agent 1 writes to temp state + ConcurrentMap stateDelta1 = new ConcurrentHashMap<>(); + stateDelta1.put("temp:agent1_output", "data"); + Event event1 = + Event.builder().actions(EventActions.builder().stateDelta(stateDelta1).build()).build(); + var unused = sessionService.appendEvent(session, event1).blockingGet(); + + // Verify agent 1 output is in session state + assertThat(session.state()).containsEntry("temp:agent1_output", "data"); + + // Agent 2 reads "agent1_output", processes it, writes "agent2_output", and removes + // "agent1_output" + ConcurrentMap stateDelta2 = new ConcurrentHashMap<>(); + stateDelta2.put("temp:agent2_output", "processed_data"); + stateDelta2.put("temp:agent1_output", State.REMOVED); + Event event2 = + Event.builder().actions(EventActions.builder().stateDelta(stateDelta2).build()).build(); + unused = sessionService.appendEvent(session, event2).blockingGet(); + + // Verify final state after agent 2 processing + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSession.state()).doesNotContainKey("temp:agent1_output"); + assertThat(retrievedSession.state()).containsEntry("temp:agent2_output", "processed_data"); + } } diff --git a/core/src/test/java/com/google/adk/testing/TestUtils.java b/core/src/test/java/com/google/adk/testing/TestUtils.java index 2bdcf1fbd..df94b76b2 100644 --- a/core/src/test/java/com/google/adk/testing/TestUtils.java +++ b/core/src/test/java/com/google/adk/testing/TestUtils.java @@ -30,7 +30,9 @@ import com.google.adk.events.EventCompaction; import com.google.adk.models.BaseLlm; import com.google.adk.models.LlmResponse; +import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; @@ -68,6 +70,19 @@ public static InvocationContext createInvocationContext(BaseAgent agent) { return createInvocationContext(agent, RunConfig.builder().build()); } + public static InvocationContext createInvocationContext( + BaseAgent agent, BaseSessionService sessionService, Session session) { + return InvocationContext.builder() + .sessionService(sessionService) + .artifactService(new InMemoryArtifactService()) + .invocationId("invocationId") + .agent(agent) + .session(session) + .userContent(Content.fromParts(Part.fromText("user content"))) + .runConfig(RunConfig.builder().build()) + .build(); + } + public static Event createEvent(String id) { return Event.builder() .id(id) From c4c2194f7242c5c1cbb4a0cf59ae529ed837d565 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 2 Feb 2026 02:48:49 -0800 Subject: [PATCH 12/63] refactor: updating Exceptions to be more specific in applicationintegrationtoolset PiperOrigin-RevId: 864263802 --- .../IntegrationClient.java | 8 +++--- .../IntegrationConnectorTool.java | 25 ++++++++++--------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationClient.java b/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationClient.java index b1958d56a..3b63429a9 100644 --- a/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationClient.java +++ b/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationClient.java @@ -148,7 +148,7 @@ private void validate() { } } - String generateOpenApiSpec() throws Exception { + String generateOpenApiSpec() throws IOException, InterruptedException { String url = String.format( "https://%s-integrations.googleapis.com/v1/projects/%s/locations/%s:generateOpenApiSpec", @@ -179,7 +179,7 @@ String generateOpenApiSpec() throws Exception { httpClient.send(requestBuilder.build(), HttpResponse.BodyHandlers.ofString()); if (response.statusCode() < 200 || response.statusCode() >= 300) { - throw new Exception("Error fetching OpenAPI spec. Status: " + response.statusCode()); + throw new IOException("Error fetching OpenAPI spec. Status: " + response.statusCode()); } return response.body(); } @@ -343,7 +343,7 @@ ObjectNode getOpenApiSpecForConnection(String toolName, String toolInstructions) return connectorSpec; } - String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) throws Exception { + String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) throws IOException { JsonNode topLevelNode = objectMapper.readTree(openApiSchemaString); JsonNode specNode = topLevelNode.path("openApiSpec"); if (specNode.isMissingNode() || !specNode.isTextual()) { @@ -372,7 +372,7 @@ String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) thr } } } - throw new Exception("Could not find operationId for pathUrl: " + pathUrl); + throw new IOException("Could not find operationId for pathUrl: " + pathUrl); } ConnectionsClient createConnectionsClient() { diff --git a/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationConnectorTool.java b/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationConnectorTool.java index bc1357106..be93582e7 100644 --- a/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationConnectorTool.java +++ b/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationConnectorTool.java @@ -129,7 +129,7 @@ public class IntegrationConnectorTool extends BaseTool { this.credentialsHelper = Preconditions.checkNotNull(credentialsHelper); } - Schema toGeminiSchema(String openApiSchema, String operationId) throws Exception { + Schema toGeminiSchema(String openApiSchema, String operationId) throws IOException { String resolvedSchemaString = getResolvedRequestSchemaByOperationId(openApiSchema, operationId); return Schema.fromJson(resolvedSchemaString); } @@ -148,7 +148,7 @@ public Optional declaration() { .parameters(parametersSchema) .build(); return Optional.of(declaration); - } catch (Exception e) { + } catch (IOException e) { logger.error("Failed to get OpenAPI spec", e); return Optional.empty(); } @@ -175,20 +175,21 @@ public Single> runAsync(Map args, ToolContex try { String response = executeIntegration(args); return ImmutableMap.of("result", response); - } catch (Exception e) { + } catch (IOException | InterruptedException e) { logger.error("Failed to execute integration", e); return ImmutableMap.of("error", e.getMessage()); } }); } - private String executeIntegration(Map args) throws Exception { + private String executeIntegration(Map args) + throws IOException, InterruptedException { String url = String.format("https://integrations.googleapis.com%s", this.pathUrl); String jsonRequestBody; try { jsonRequestBody = objectMapper.writeValueAsString(args); } catch (IOException e) { - throw new Exception("Error converting args to JSON: " + e.getMessage(), e); + throw new IOException("Error converting args to JSON: " + e.getMessage(), e); } Credentials credentials = credentialsHelper.getGoogleCredentials(this.serviceAccountJson); HttpRequest.Builder requestBuilder = @@ -203,7 +204,7 @@ private String executeIntegration(Map args) throws Exception { httpClient.send(requestBuilder.build(), HttpResponse.BodyHandlers.ofString()); if (response.statusCode() < 200 || response.statusCode() >= 300) { - throw new Exception( + throw new IOException( "Error executing integration. Status: " + response.statusCode() + " , Response: " @@ -212,7 +213,7 @@ private String executeIntegration(Map args) throws Exception { return response.body(); } - String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) throws Exception { + String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) throws IOException { JsonNode topLevelNode = objectMapper.readTree(openApiSchemaString); JsonNode specNode = topLevelNode.path("openApiSpec"); if (specNode.isMissingNode() || !specNode.isTextual()) { @@ -254,11 +255,11 @@ String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) thr } } } - throw new Exception("Could not find operationId for pathUrl: " + pathUrl); + throw new IOException("Could not find operationId for pathUrl: " + pathUrl); } private String getResolvedRequestSchemaByOperationId( - String openApiSchemaString, String operationId) throws Exception { + String openApiSchemaString, String operationId) throws IOException { JsonNode topLevelNode = objectMapper.readTree(openApiSchemaString); JsonNode specNode = topLevelNode.path("openApiSpec"); if (specNode.isMissingNode() || !specNode.isTextual()) { @@ -268,13 +269,13 @@ private String getResolvedRequestSchemaByOperationId( JsonNode rootNode = objectMapper.readTree(specNode.asText()); JsonNode operationNode = findOperationNodeById(rootNode, operationId); if (operationNode == null) { - throw new Exception("Could not find operation with operationId: " + operationId); + throw new IOException("Could not find operation with operationId: " + operationId); } JsonNode requestSchemaNode = operationNode.path("requestBody").path("content").path("application/json").path("schema"); if (requestSchemaNode.isMissingNode()) { - throw new Exception("Could not find request body schema for operationId: " + operationId); + throw new IOException("Could not find request body schema for operationId: " + operationId); } JsonNode resolvedSchema = resolveRefs(requestSchemaNode, rootNode); @@ -355,7 +356,7 @@ private JsonNode resolveRefs(JsonNode currentNode, JsonNode rootNode) { } private String getOperationDescription(String openApiSchemaString, String operationId) - throws Exception { + throws IOException { JsonNode topLevelNode = objectMapper.readTree(openApiSchemaString); JsonNode specNode = topLevelNode.path("openApiSpec"); if (specNode.isMissingNode() || !specNode.isTextual()) { From e19d20a73508b577f1d26002d6c73f1077cc3176 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 2 Feb 2026 05:32:24 -0800 Subject: [PATCH 13/63] refactor: Runner and App now use `? extends Plugin` PiperOrigin-RevId: 864312028 --- core/src/main/java/com/google/adk/apps/App.java | 12 ++++++------ .../com/google/adk/plugins/PluginManager.java | 17 +++++------------ .../com/google/adk/runner/InMemoryRunner.java | 4 ++-- .../main/java/com/google/adk/runner/Runner.java | 14 +++++++------- 4 files changed, 20 insertions(+), 27 deletions(-) diff --git a/core/src/main/java/com/google/adk/apps/App.java b/core/src/main/java/com/google/adk/apps/App.java index 3b1f0613a..5be72bb5c 100644 --- a/core/src/main/java/com/google/adk/apps/App.java +++ b/core/src/main/java/com/google/adk/apps/App.java @@ -17,7 +17,7 @@ package com.google.adk.apps; import com.google.adk.agents.BaseAgent; -import com.google.adk.plugins.BasePlugin; +import com.google.adk.plugins.Plugin; import com.google.adk.summarizer.EventsCompactionConfig; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -38,14 +38,14 @@ public class App { private final String name; private final BaseAgent rootAgent; - private final ImmutableList plugins; + private final ImmutableList plugins; @Nullable private final EventsCompactionConfig eventsCompactionConfig; @Nullable private final ResumabilityConfig resumabilityConfig; private App( String name, BaseAgent rootAgent, - List plugins, + List plugins, @Nullable EventsCompactionConfig eventsCompactionConfig, @Nullable ResumabilityConfig resumabilityConfig) { this.name = name; @@ -63,7 +63,7 @@ public BaseAgent rootAgent() { return rootAgent; } - public ImmutableList plugins() { + public ImmutableList plugins() { return plugins; } @@ -81,7 +81,7 @@ public ResumabilityConfig resumabilityConfig() { public static class Builder { private String name; private BaseAgent rootAgent; - private List plugins = ImmutableList.of(); + private List plugins = ImmutableList.of(); @Nullable private EventsCompactionConfig eventsCompactionConfig; @Nullable private ResumabilityConfig resumabilityConfig; @@ -98,7 +98,7 @@ public Builder rootAgent(BaseAgent rootAgent) { } @CanIgnoreReturnValue - public Builder plugins(List plugins) { + public Builder plugins(List plugins) { this.plugins = plugins; return this; } diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index 76452af64..d7ce6b819 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -41,16 +41,14 @@ *

The PluginManager is an internal class that orchestrates the invocation of plugin callbacks at * key points in the SDK's execution lifecycle. */ -public class PluginManager implements Plugin { +public class PluginManager extends BasePlugin { private static final Logger logger = LoggerFactory.getLogger(PluginManager.class); - private final List plugins; + private final List plugins = new ArrayList<>(); public PluginManager(List plugins) { - this.plugins = new ArrayList<>(); + super("PluginManager"); if (plugins != null) { - for (var plugin : plugins) { - this.registerPlugin(plugin); - } + plugins.forEach(this::registerPlugin); } } @@ -58,11 +56,6 @@ public PluginManager() { this(null); } - @Override - public String getName() { - return "PluginManager"; - } - /** * Registers a new plugin. * @@ -259,7 +252,7 @@ private Maybe runMaybeCallbacks( callbackExecutor .apply(plugin) .doOnSuccess( - r -> + unused -> logger.debug( "Plugin '{}' returned a value for callback '{}', exiting " + "early.", diff --git a/core/src/main/java/com/google/adk/runner/InMemoryRunner.java b/core/src/main/java/com/google/adk/runner/InMemoryRunner.java index 793f24c38..58741003c 100644 --- a/core/src/main/java/com/google/adk/runner/InMemoryRunner.java +++ b/core/src/main/java/com/google/adk/runner/InMemoryRunner.java @@ -19,7 +19,7 @@ import com.google.adk.agents.BaseAgent; import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.memory.InMemoryMemoryService; -import com.google.adk.plugins.BasePlugin; +import com.google.adk.plugins.Plugin; import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; import java.util.List; @@ -37,7 +37,7 @@ public InMemoryRunner(BaseAgent agent, String appName) { this(agent, appName, ImmutableList.of()); } - public InMemoryRunner(BaseAgent agent, String appName, List plugins) { + public InMemoryRunner(BaseAgent agent, String appName, List plugins) { super( agent, appName, diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 66bb58606..51696f1af 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -30,7 +30,7 @@ import com.google.adk.events.EventActions; import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.Model; -import com.google.adk.plugins.BasePlugin; +import com.google.adk.plugins.Plugin; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.InMemorySessionService; @@ -84,7 +84,7 @@ public static class Builder { private BaseArtifactService artifactService = new InMemoryArtifactService(); private BaseSessionService sessionService = new InMemorySessionService(); @Nullable private BaseMemoryService memoryService = null; - private List plugins = ImmutableList.of(); + private List plugins = ImmutableList.of(); @CanIgnoreReturnValue public Builder app(App app) { @@ -126,7 +126,7 @@ public Builder memoryService(BaseMemoryService memoryService) { } @CanIgnoreReturnValue - public Builder plugins(List plugins) { + public Builder plugins(List plugins) { Preconditions.checkState(this.app == null, "plugins() cannot be called when app is set."); this.plugins = plugins; return this; @@ -135,7 +135,7 @@ public Builder plugins(List plugins) { public Runner build() { BaseAgent buildAgent; String buildAppName; - List buildPlugins; + List buildPlugins; ResumabilityConfig buildResumabilityConfig; EventsCompactionConfig buildEventsCompactionConfig; @@ -224,7 +224,7 @@ public Runner( BaseArtifactService artifactService, BaseSessionService sessionService, @Nullable BaseMemoryService memoryService, - List plugins) { + List plugins) { this( agent, appName, @@ -247,7 +247,7 @@ public Runner( BaseArtifactService artifactService, BaseSessionService sessionService, @Nullable BaseMemoryService memoryService, - List plugins, + List plugins, ResumabilityConfig resumabilityConfig) { this( agent, @@ -272,7 +272,7 @@ protected Runner( BaseArtifactService artifactService, BaseSessionService sessionService, @Nullable BaseMemoryService memoryService, - List plugins, + List plugins, ResumabilityConfig resumabilityConfig, @Nullable EventsCompactionConfig eventsCompactionConfig) { this.agent = agent; From 720262d2414b3dc9c9198c0e56a277f906806805 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 2 Feb 2026 05:49:23 -0800 Subject: [PATCH 14/63] refactor: Using a component provider to break up class dependency cycles reducing cyclical dependencies would help create something akin to "adk-light" PiperOrigin-RevId: 864316690 --- .../utils/AdditionalAdkComponentProvider.java | 53 +++++++++++++ .../adk/utils/AdkComponentProvider.java | 64 ++++++++++++++++ .../google/adk/utils/ComponentRegistry.java | 53 +++++-------- .../adk/utils/CoreAdkComponentProvider.java | 75 +++++++++++++++++++ .../com.google.adk.utils.AdkComponentProvider | 1 + 5 files changed, 213 insertions(+), 33 deletions(-) create mode 100644 core/src/main/java/com/google/adk/utils/AdditionalAdkComponentProvider.java create mode 100644 core/src/main/java/com/google/adk/utils/AdkComponentProvider.java create mode 100644 core/src/main/java/com/google/adk/utils/CoreAdkComponentProvider.java create mode 100644 core/src/main/resources/META-INF/services/com.google.adk.utils.AdkComponentProvider diff --git a/core/src/main/java/com/google/adk/utils/AdditionalAdkComponentProvider.java b/core/src/main/java/com/google/adk/utils/AdditionalAdkComponentProvider.java new file mode 100644 index 000000000..c94a18f55 --- /dev/null +++ b/core/src/main/java/com/google/adk/utils/AdditionalAdkComponentProvider.java @@ -0,0 +1,53 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.utils; + +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.GoogleMapsTool; +import com.google.adk.tools.GoogleSearchTool; +import com.google.adk.tools.mcp.McpToolset; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Provides ADK components that are part of core. */ +public final class AdditionalAdkComponentProvider implements AdkComponentProvider { + + /** + * Returns tool instances for {@link GoogleSearchTool} and {@link GoogleMapsTool}. + * + * @return a map of tool instances. + */ + @Override + public Map getToolInstances() { + Map toolInstances = new HashMap<>(); + toolInstances.put("google_search", GoogleSearchTool.INSTANCE); + toolInstances.put("google_maps_grounding", GoogleMapsTool.INSTANCE); + return toolInstances; + } + + /** + * Returns toolset classes for {@link McpToolset}. + * + * @return a list of toolset classes. + */ + @Override + public List> getToolsetClasses() { + return Arrays.asList(McpToolset.class); + } +} diff --git a/core/src/main/java/com/google/adk/utils/AdkComponentProvider.java b/core/src/main/java/com/google/adk/utils/AdkComponentProvider.java new file mode 100644 index 000000000..173edbb26 --- /dev/null +++ b/core/src/main/java/com/google/adk/utils/AdkComponentProvider.java @@ -0,0 +1,64 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.utils; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.List; +import java.util.Map; + +/** Service provider interface for ADK components to be registered in {@link ComponentRegistry}. */ +public interface AdkComponentProvider { + + /** + * Returns a list of agent classes to register. + * + * @return a list of agent classes. + */ + default List> getAgentClasses() { + return ImmutableList.of(); + } + + /** + * Returns a list of tool classes to register. + * + * @return a list of tool classes. + */ + default List> getToolClasses() { + return ImmutableList.of(); + } + + /** + * Returns a list of toolset classes to register. + * + * @return a list of toolset classes. + */ + default List> getToolsetClasses() { + return ImmutableList.of(); + } + + /** + * Returns a map of tool instances to register, with tool name as key. + * + * @return a map of tool instances. + */ + default Map getToolInstances() { + return ImmutableMap.of(); + } +} diff --git a/core/src/main/java/com/google/adk/utils/ComponentRegistry.java b/core/src/main/java/com/google/adk/utils/ComponentRegistry.java index 0a9f55b16..3b2d0d14a 100644 --- a/core/src/main/java/com/google/adk/utils/ComponentRegistry.java +++ b/core/src/main/java/com/google/adk/utils/ComponentRegistry.java @@ -23,25 +23,13 @@ import com.google.adk.agents.BaseAgent; import com.google.adk.agents.Callbacks; import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.LoopAgent; -import com.google.adk.agents.ParallelAgent; -import com.google.adk.agents.SequentialAgent; -import com.google.adk.tools.AgentTool; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; -import com.google.adk.tools.ExampleTool; -import com.google.adk.tools.ExitLoopTool; -import com.google.adk.tools.GoogleMapsTool; -import com.google.adk.tools.GoogleSearchTool; -import com.google.adk.tools.LoadArtifactsTool; -import com.google.adk.tools.LongRunningFunctionTool; -import com.google.adk.tools.UrlContextTool; -import com.google.adk.tools.mcp.McpToolset; import java.util.Map; import java.util.Optional; +import java.util.ServiceLoader; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import javax.annotation.Nonnull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -103,46 +91,45 @@ protected ComponentRegistry() { /** Initializes the registry with base pre-wired ADK instances. */ private void initializePreWiredEntries() { - registerAdkAgentClass(LlmAgent.class); - registerAdkAgentClass(LoopAgent.class); - registerAdkAgentClass(ParallelAgent.class); - registerAdkAgentClass(SequentialAgent.class); - - registerAdkToolInstance("google_search", GoogleSearchTool.INSTANCE); - registerAdkToolInstance("load_artifacts", LoadArtifactsTool.INSTANCE); - registerAdkToolInstance("exit_loop", ExitLoopTool.INSTANCE); - registerAdkToolInstance("url_context", UrlContextTool.INSTANCE); - registerAdkToolInstance("google_maps_grounding", GoogleMapsTool.INSTANCE); - - registerAdkToolClass(AgentTool.class); - registerAdkToolClass(LongRunningFunctionTool.class); - registerAdkToolClass(ExampleTool.class); - - registerAdkToolsetClass(McpToolset.class); - // TODO: add all python tools that also exist in Java. + // Core components are registered first. + AdkComponentProvider coreProvider = new CoreAdkComponentProvider(); + registerProvider(coreProvider); + ServiceLoader loader = ServiceLoader.load(AdkComponentProvider.class); + for (AdkComponentProvider provider : loader) { + registerProvider(provider); + } logger.debug("Initialized base pre-wired entries in ComponentRegistry"); } + private void registerProvider(AdkComponentProvider provider) { + provider.getAgentClasses().forEach(this::registerAdkAgentClass); + provider.getToolClasses().forEach(this::registerAdkToolClass); + provider.getToolsetClasses().forEach(this::registerAdkToolsetClass); + provider.getToolInstances().forEach(this::registerAdkToolInstance); + logger.info("Registered components from " + provider.getClass().getName()); + } + private void registerAdkAgentClass(Class agentClass) { registry.put(agentClass.getName(), agentClass); // For python compatibility, also register the name used in ADK Python. registry.put("google.adk.agents." + agentClass.getSimpleName(), agentClass); } - private void registerAdkToolInstance(String name, @Nonnull Object toolInstance) { + private void registerAdkToolInstance(String name, BaseTool toolInstance) { registry.put(name, toolInstance); // For python compatibility, also register the name used in ADK Python. registry.put("google.adk.tools." + name, toolInstance); } - private void registerAdkToolClass(@Nonnull Class toolClass) { + private void registerAdkToolClass(Class toolClass) { registry.put(toolClass.getName(), toolClass); // For python compatibility, also register the name used in ADK Python. registry.put("google.adk.tools." + toolClass.getSimpleName(), toolClass); + registry.put(toolClass.getSimpleName(), toolClass); } - private void registerAdkToolsetClass(@Nonnull Class toolsetClass) { + private void registerAdkToolsetClass(Class toolsetClass) { registry.put(toolsetClass.getName(), toolsetClass); // For python compatibility, also register the name used in ADK Python. registry.put("google.adk.tools." + toolsetClass.getSimpleName(), toolsetClass); diff --git a/core/src/main/java/com/google/adk/utils/CoreAdkComponentProvider.java b/core/src/main/java/com/google/adk/utils/CoreAdkComponentProvider.java new file mode 100644 index 000000000..455b2cf95 --- /dev/null +++ b/core/src/main/java/com/google/adk/utils/CoreAdkComponentProvider.java @@ -0,0 +1,75 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.utils; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.LoopAgent; +import com.google.adk.agents.ParallelAgent; +import com.google.adk.agents.SequentialAgent; +import com.google.adk.tools.AgentTool; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ExampleTool; +import com.google.adk.tools.ExitLoopTool; +import com.google.adk.tools.LoadArtifactsTool; +import com.google.adk.tools.LongRunningFunctionTool; +import com.google.adk.tools.UrlContextTool; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Provides ADK components that are part of core. */ +public class CoreAdkComponentProvider implements AdkComponentProvider { + + /** + * Returns agent classes for {@link LlmAgent}, {@link LoopAgent}, {@link ParallelAgent} and {@link + * SequentialAgent}. + * + * @return a list of agent classes. + */ + @Override + public List> getAgentClasses() { + return Arrays.asList( + LlmAgent.class, LoopAgent.class, ParallelAgent.class, SequentialAgent.class); + } + + /** + * Returns tool classes for {@link AgentTool}, {@link LongRunningFunctionTool} and {@link + * ExampleTool}. + * + * @return a list of tool classes. + */ + @Override + public List> getToolClasses() { + return Arrays.asList(AgentTool.class, LongRunningFunctionTool.class, ExampleTool.class); + } + + /** + * Returns tool instances for {@link LoadArtifactsTool}, {@link ExitLoopTool} and {@link + * UrlContextTool}. + * + * @return a map of tool instances. + */ + @Override + public Map getToolInstances() { + Map toolInstances = new HashMap<>(); + toolInstances.put("load_artifacts", LoadArtifactsTool.INSTANCE); + toolInstances.put("exit_loop", ExitLoopTool.INSTANCE); + toolInstances.put("url_context", UrlContextTool.INSTANCE); + return toolInstances; + } +} diff --git a/core/src/main/resources/META-INF/services/com.google.adk.utils.AdkComponentProvider b/core/src/main/resources/META-INF/services/com.google.adk.utils.AdkComponentProvider new file mode 100644 index 000000000..795480cc8 --- /dev/null +++ b/core/src/main/resources/META-INF/services/com.google.adk.utils.AdkComponentProvider @@ -0,0 +1 @@ +com.google.adk.utils.AdditionalAdkComponentProvider From 59e87d319887c588a1ed7d4ca247cd31dffba2c6 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 2 Feb 2026 20:44:40 -0800 Subject: [PATCH 15/63] feat: Adding a new `ArtifactService.saveAndReloadArtifact()` method The `saveAndReloadArtifact()` enables a save without a second i/o call just to get the full file path. PiperOrigin-RevId: 864659692 --- .../adk/artifacts/BaseArtifactService.java | 22 ++++ .../adk/artifacts/GcsArtifactService.java | 103 +++++++++++++----- .../artifacts/InMemoryArtifactService.java | 10 ++ .../adk/artifacts/GcsArtifactServiceTest.java | 37 +++++++ .../InMemoryArtifactServiceTest.java | 82 ++++++++++++++ 5 files changed, 226 insertions(+), 28 deletions(-) create mode 100644 core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java diff --git a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java index 847e88dd9..b6a3cee23 100644 --- a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java @@ -39,6 +39,28 @@ public interface BaseArtifactService { Single saveArtifact( String appName, String userId, String sessionId, String filename, Part artifact); + /** + * Saves an artifact and returns it with fileData if available. + * + *

Implementations should override this default method for efficiency, as the default performs + * two I/O operations (save then load). + * + * @param appName the app name + * @param userId the user ID + * @param sessionId the session ID + * @param filename the filename + * @param artifact the artifact to save + * @return the saved artifact with fileData if available. + */ + default Single saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifact(appName, userId, sessionId, filename, artifact) + .flatMap( + version -> + loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) + .toSingle()); + } + /** * Gets an artifact. * diff --git a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java index 1bfef8cf8..b9bc49a02 100644 --- a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java @@ -18,6 +18,7 @@ import static java.util.Collections.max; +import com.google.auto.value.AutoValue; import com.google.cloud.storage.Blob; import com.google.cloud.storage.BlobId; import com.google.cloud.storage.BlobInfo; @@ -27,6 +28,7 @@ import com.google.common.base.Splitter; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import com.google.genai.types.FileData; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; @@ -108,34 +110,8 @@ private String getBlobName( @Override public Single saveArtifact( String appName, String userId, String sessionId, String filename, Part artifact) { - return listVersions(appName, userId, sessionId, filename) - .map(versions -> versions.isEmpty() ? 0 : max(versions) + 1) - .map( - nextVersion -> { - String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion); - BlobId blobId = BlobId.of(bucketName, blobName); - - BlobInfo blobInfo = - BlobInfo.newBuilder(blobId) - .setContentType(artifact.inlineData().get().mimeType().orElse(null)) - .build(); - - try { - byte[] dataToSave = - artifact - .inlineData() - .get() - .data() - .orElseThrow( - () -> - new IllegalArgumentException( - "Saveable artifact data must be non-empty.")); - storageClient.create(blobInfo, dataToSave); - return nextVersion; - } catch (StorageException e) { - throw new VerifyException("Failed to save artifact to GCS", e); - } - }); + return saveArtifactAndReturnBlob(appName, userId, sessionId, filename, artifact) + .map(SaveResult::version); } /** @@ -275,4 +251,75 @@ public Single> listVersions( return Single.just(ImmutableList.of()); } } + + @Override + public Single saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifactAndReturnBlob(appName, userId, sessionId, filename, artifact) + .flatMap( + blob -> { + Blob savedBlob = blob.blob(); + String resultMimeType = + Optional.ofNullable(savedBlob.getContentType()) + .or( + () -> + artifact.inlineData().flatMap(com.google.genai.types.Blob::mimeType)) + .orElse("application/octet-stream"); + return Single.just( + Part.builder() + .fileData( + FileData.builder() + .fileUri("gs://" + savedBlob.getBucket() + "/" + savedBlob.getName()) + .mimeType(resultMimeType) + .build()) + .build()); + }); + } + + @AutoValue + abstract static class SaveResult { + static SaveResult create(Blob blob, int version) { + return new AutoValue_GcsArtifactService_SaveResult(blob, version); + } + + abstract Blob blob(); + + abstract int version(); + } + + private Single saveArtifactAndReturnBlob( + String appName, String userId, String sessionId, String filename, Part artifact) { + return listVersions(appName, userId, sessionId, filename) + .map(versions -> versions.isEmpty() ? 0 : max(versions) + 1) + .map( + nextVersion -> { + if (artifact.inlineData().isEmpty()) { + throw new IllegalArgumentException("Saveable artifact must have inline data."); + } + + String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion); + BlobId blobId = BlobId.of(bucketName, blobName); + + BlobInfo blobInfo = + BlobInfo.newBuilder(blobId) + .setContentType(artifact.inlineData().get().mimeType().orElse(null)) + .build(); + + try { + byte[] dataToSave = + artifact + .inlineData() + .get() + .data() + .orElseThrow( + () -> + new IllegalArgumentException( + "Saveable artifact data must be non-empty.")); + Blob blob = storageClient.create(blobInfo, dataToSave); + return SaveResult.create(blob, nextVersion); + } catch (StorageException e) { + throw new VerifyException("Failed to save artifact to GCS", e); + } + }); + } } diff --git a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java index 27b85136d..5808f7083 100644 --- a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java @@ -125,6 +125,16 @@ public Single> listVersions( return Single.just(IntStream.range(0, size).boxed().collect(toImmutableList())); } + @Override + public Single saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifact(appName, userId, sessionId, filename, artifact) + .flatMap( + version -> + loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) + .toSingle()); + } + private Map> getArtifactsMap(String appName, String userId, String sessionId) { return artifacts .computeIfAbsent(appName, unused -> new HashMap<>()) diff --git a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java index 1df66c36d..40493bf3a 100644 --- a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java @@ -31,6 +31,7 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -76,6 +77,7 @@ private Blob mockBlob(String name, String contentType, byte[] content) { when(blob.exists()).thenReturn(true); BlobId blobId = BlobId.of(BUCKET_NAME, name); when(blob.getBlobId()).thenReturn(blobId); + when(blob.getBucket()).thenReturn(BUCKET_NAME); return blob; } @@ -89,6 +91,8 @@ public void save_firstVersion_savesCorrectly() { BlobInfo.newBuilder(expectedBlobId).setContentType("application/octet-stream").build(); when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/octet-stream", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); @@ -109,6 +113,8 @@ public void save_subsequentVersion_savesCorrectly() { Blob blobV0 = mockBlob(blobNameV0, "text/plain", new byte[] {1}); when(mockBlobPage.iterateAll()).thenReturn(Collections.singletonList(blobV0)); + Blob savedBlob = mockBlob(expectedBlobNameV1, "image/png", new byte[] {4, 5}); + when(mockStorage.create(eq(expectedBlobInfoV1), eq(new byte[] {4, 5}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); @@ -126,6 +132,8 @@ public void save_userNamespace_savesCorrectly() { BlobInfo.newBuilder(expectedBlobId).setContentType("application/json").build(); when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/json", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, USER_FILENAME, artifact).blockingGet(); @@ -330,7 +338,36 @@ public void listVersions_noVersions_returnsEmptyList() { assertThat(versions).isEmpty(); } + @Test + public void saveAndReloadArtifact_savesAndReturnsFileData() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "application/octet-stream"); + String expectedBlobName = + String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + BlobId expectedBlobId = BlobId.of(BUCKET_NAME, expectedBlobName); + BlobInfo expectedBlobInfo = + BlobInfo.newBuilder(expectedBlobId).setContentType("application/octet-stream").build(); + + when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/octet-stream", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); + + Optional result = + asOptional( + service.saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact)); + + assertThat(result).isPresent(); + assertThat(result.get().fileData()).isPresent(); + assertThat(result.get().fileData().get().fileUri()) + .hasValue("gs://" + BUCKET_NAME + "/" + expectedBlobName); + assertThat(result.get().fileData().get().mimeType()).hasValue("application/octet-stream"); + verify(mockStorage).create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3})); + } + private static Optional asOptional(Maybe maybe) { return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet(); } + + private static Optional asOptional(Single single) { + return Optional.of(single.blockingGet()); + } } diff --git a/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java new file mode 100644 index 000000000..4cb493277 --- /dev/null +++ b/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.artifacts; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link InMemoryArtifactService}. */ +@RunWith(JUnit4.class) +public class InMemoryArtifactServiceTest { + + private static final String APP_NAME = "test-app"; + private static final String USER_ID = "test-user"; + private static final String SESSION_ID = "test-session"; + private static final String FILENAME = "test-file.txt"; + + private InMemoryArtifactService service; + + @Before + public void setUp() { + service = new InMemoryArtifactService(); + } + + @Test + public void saveArtifact_savesAndReturnsVersion() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "text/plain"); + int version = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); + assertThat(version).isEqualTo(0); + } + + @Test + public void loadArtifact_loadsLatest() { + Part artifact1 = Part.fromBytes(new byte[] {1}, "text/plain"); + Part artifact2 = Part.fromBytes(new byte[] {1, 2}, "text/plain"); + var unused1 = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact1).blockingGet(); + var unused2 = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact2).blockingGet(); + Optional result = + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.empty())); + assertThat(result).hasValue(artifact2); + } + + @Test + public void saveAndReloadArtifact_reloadsArtifact() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "text/plain"); + Optional result = + asOptional( + service.saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact)); + assertThat(result).hasValue(artifact); + } + + private static Optional asOptional(Maybe maybe) { + return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet(); + } + + private static Optional asOptional(Single single) { + return Optional.of(single.blockingGet()); + } +} From efe58d6e0e5e0ff35d39e56bcb0f57cc6ccc7ccc Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 3 Feb 2026 01:15:01 -0800 Subject: [PATCH 16/63] feat: Introduce TailRetentionEventCompactor to compact and retain the tail of the event stream Provide a way to manage the size of an event stream Specifically, it: * Keeps the retentionSize most recent events raw. * Compacts all events that never compacted and older than the retained tail, including the most recent compaction events, into a new summary event. * Appends this new summary event to the end of the event stream. PiperOrigin-RevId: 864748009 --- .../TailRetentionEventCompactor.java | 193 ++++++++++++++ .../TailRetentionEventCompactorTest.java | 246 ++++++++++++++++++ 2 files changed, 439 insertions(+) create mode 100644 core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java create mode 100644 core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java diff --git a/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java new file mode 100644 index 000000000..c13a49cc1 --- /dev/null +++ b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java @@ -0,0 +1,193 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.summarizer; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.adk.events.Event; +import com.google.adk.events.EventCompaction; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.Session; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.ListIterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class performs event compaction by retaining the tail of the event stream. + * + *

+ * + *

This compactor produces a rolling summary. Each new compaction event includes the content of + * the previous compaction event (if any) along with new events, effectively superseding all prior + * compactions. + */ +public final class TailRetentionEventCompactor implements EventCompactor { + + private static final Logger logger = LoggerFactory.getLogger(TailRetentionEventCompactor.class); + + private final BaseEventSummarizer summarizer; + private final int retentionSize; + + public TailRetentionEventCompactor(BaseEventSummarizer summarizer, int retentionSize) { + this.summarizer = summarizer; + this.retentionSize = retentionSize; + } + + @Override + public Completable compact(Session session, BaseSessionService sessionService) { + checkArgument(summarizer != null, "Missing BaseEventSummarizer for event compaction"); + logger.debug("Running tail retention event compaction for session {}", session.id()); + + return Completable.fromMaybe( + getCompactionEvents(session.events()) + .flatMap(summarizer::summarizeEvents) + .flatMapSingle(e -> sessionService.appendEvent(session, e))); + } + + /** + * Identifies events to be compacted based on the tail retention strategy. + * + *

This method iterates backwards through the event list to find the most recent compaction + * event (if any) and collects all uncompacted events that occurred after the range covered by + * that compaction. It then applies the retention policy, excluding the most recent {@code + * retentionSize} events from being compacted. + * + *

Basic Scenario: + * + *

+ * + *

Advanced Scenario (Handling Gaps): + * + *

Consider an edge case where retention size is 3. Event E4 appears before the last compaction + * event (C2) and even the one prior (C1), but remains uncompacted and must be included in the + * third compaction (C3). + * + *

+ * + *

Execution with Retention = 3: + * + *

    + *
  1. The method scans backward: E9, C2, E7, E6, C1, E4... + *
  2. C2 is identified as the most recent compaction event (end timestamp T=3). + *
  3. E9, E7, E6 are collected as they are newer than T=3. + *
  4. C1 is ignored as we only care about the boundary set by the latest compaction. + *
  5. E4 (T=4) is collected because it is newer than T=3. + *
  6. Scanning stops at E3 as it is covered by C2 (timestamp <= T=3). + *
  7. The initial list of events to summarize: [E9, E7, E6, E4]. + *
  8. After appending the compaction event C2, the list becomes: [E9, E7, E6, E4, C2] + *
  9. Reversing the list: [C2, E4, E6, E7, E9]. + *
  10. Applying retention (keep last 3): E6, E7, E9 are removed from the summary list. + *
  11. Final Output: {@code [C2, E4]}. E4 and the previous summary C2 will be compacted + * together. The new compaction event will cover the range from the start of the included + * compaction event (C2, T=1) to the end of the new events (E4, T=4). + *
+ */ + private Maybe> getCompactionEvents(List events) { + long compactionEndTimestamp = Long.MIN_VALUE; + Event lastCompactionEvent = null; + List eventsToSummarize = new ArrayList<>(); + + // Iterate backwards from the end of the window to summarize. + // We use a single loop to: + // 1. Collect all raw events that happened after the latest compaction. + // 2. Identify the latest compaction event to establish the stop condition (boundary). + ListIterator iter = events.listIterator(events.size()); + while (iter.hasPrevious()) { + Event event = iter.previous(); + + if (!isCompactEvent(event)) { + // Only include events that are strictly after the last compaction range. + if (event.timestamp() > compactionEndTimestamp) { + eventsToSummarize.add(event); + continue; + } else { + // Exit early if we have reached the last event of last compaction range. + break; + } + } + + EventCompaction compaction = event.actions().compaction().orElse(null); + // We use the most recent compaction event to define the time boundary. Any subsequent (older) + // compaction events are ignored. + if (lastCompactionEvent == null) { + compactionEndTimestamp = compaction.endTimestamp(); + lastCompactionEvent = event; + } + } + + // If there are not enough events to summarize, we can return early. + if (eventsToSummarize.size() <= retentionSize) { + return Maybe.empty(); + } + + // Add the last compaction event to the list of events to summarize. + // This is to ensure that the last compaction event is included in the summary. + if (lastCompactionEvent != null) { + EventCompaction compaction = lastCompactionEvent.actions().compaction().get(); + eventsToSummarize.add( + lastCompactionEvent.toBuilder() + .content(compaction.compactedContent()) + // Use the start timestamp so that the new summary covers the entire range. + .timestamp(compaction.startTimestamp()) + .build()); + } + + Collections.reverse(eventsToSummarize); + + // Apply retention: keep the most recent 'retentionSize' events out of the summary. + // We do this by removing them from the list of events to be summarized. + eventsToSummarize + .subList(eventsToSummarize.size() - retentionSize, eventsToSummarize.size()) + .clear(); + return Maybe.just(eventsToSummarize); + } + + private static boolean isCompactEvent(Event event) { + return event.actions() != null && event.actions().compaction().isPresent(); + } +} diff --git a/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java new file mode 100644 index 000000000..b4a6c3474 --- /dev/null +++ b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java @@ -0,0 +1,246 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.summarizer; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.events.EventCompaction; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.List; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class TailRetentionEventCompactorTest { + + @Rule public final MockitoRule mockito = MockitoJUnit.rule(); + @Mock private BaseSessionService mockSessionService; + @Mock private BaseEventSummarizer mockSummarizer; + @Captor private ArgumentCaptor> eventListCaptor; + + @Test + public void compact_notEnoughEvents_doesNothing() { + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), createEvent(2, "Event2"), createEvent(3, "Event3")); + Session session = Session.builder("id").events(events).build(); + + // Retention size 5 > 3 events + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 5); + + compactor.compact(session, mockSessionService).test().assertComplete(); + + verify(mockSummarizer, never()).summarizeEvents(any()); + verify(mockSessionService, never()).appendEvent(any(), any()); + } + + @Test + public void compact_respectRetentionSize_summarizesCorrectEvents() { + // Retention size is 2. + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), createEvent(2, "Retain1"), createEvent(3, "Retain2")); + Session session = Session.builder("id").events(events).build(); + Event compactedEvent = createCompactedEvent(1, 1, "Summary", 4); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); + when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); + + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2); + + compactor.compact(session, mockSessionService).test().assertComplete(); + + verify(mockSummarizer).summarizeEvents(eventListCaptor.capture()); + List summarizedEvents = eventListCaptor.getValue(); + assertThat(summarizedEvents).hasSize(1); + assertThat(getPromptText(summarizedEvents.get(0))).isEqualTo("Event1"); + + verify(mockSessionService).appendEvent(eq(session), eq(compactedEvent)); + } + + @Test + public void compact_withRetainedEventsPhysicallyBeforeCompaction_includesThem() { + // Simulating the user's specific case with retention size 1: + // "event1, event2, event3, compaction1-2 ... event3 is retained so it is before compaction + // event" + // + // Timeline: + // T=1: E1 + // T=2: E2 + // T=3: E3 + // T=4: C1 (Covers T=1 to T=2). + // + // Note: C1 was inserted *after* E3 in the list. + // List order: E1, E2, E3, C1. + // + // If we have more events: + // T=5: E5 + // T=6: E6 + // + // Retained: E6. + // Summary Input: C1, E3, E5. (E1, E2 covered by C1). + ImmutableList events = + ImmutableList.of( + createEvent(1, "E1"), + createEvent(2, "E2"), + createEvent(3, "E3"), + createCompactedEvent( + /* startTimestamp= */ 1, /* endTimestamp= */ 2, "C1", /* eventTimestamp= */ 4), + createEvent(5, "E5"), + createEvent(6, "E6")); + Session session = Session.builder("id").events(events).build(); + Event compactedEvent = createCompactedEvent(1, 5, "Summary C1-E5", 7); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); + when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); + + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 1); + + compactor.compact(session, mockSessionService).test().assertComplete(); + + verify(mockSummarizer).summarizeEvents(eventListCaptor.capture()); + List summarizedEvents = eventListCaptor.getValue(); + assertThat(summarizedEvents).hasSize(3); + + // Check first event is reconstructed C1 + Event reconstructedC1 = summarizedEvents.get(0); + assertThat(getPromptText(reconstructedC1)).isEqualTo("C1"); + // Verify timestamp is reset to startTimestamp (1) + assertThat(reconstructedC1.timestamp()).isEqualTo(1); + + // Check second event is E3 + Event e3 = summarizedEvents.get(1); + assertThat(getPromptText(e3)).isEqualTo("E3"); + assertThat(e3.timestamp()).isEqualTo(3); + + // Check third event is E5 + Event e5 = summarizedEvents.get(2); + assertThat(getPromptText(e5)).isEqualTo("E5"); + assertThat(e5.timestamp()).isEqualTo(5); + } + + @Test + public void compact_withMultipleCompactionEvents_respectsCompactionBoundary() { + // T=1: E1 + // T=2: E2, retained by C1 + // T=3: E3, retained by C1 + // T=4: E4, retained by C1 and C2 + // T=5: C1 (Covers T=1) + // T=6: E6, retained by C2 + // T=7: E7, retained by C2 + // T=8: C2 (Covers T=1 to T=3) since it covers C1 which starts at T=1. + // T=9: E9 + + // Retention = 3. + // Expected to summarize: C2, E4. (E1 covered by C1 - ignored, E2, E3 covered by C2). + // E6, E7, E9 are retained. + + ImmutableList events = + ImmutableList.of( + createEvent(1, "E1"), + createEvent(2, "E2"), + createEvent(3, "E3"), + createEvent(4, "E4"), + createCompactedEvent( + /* startTimestamp= */ 1, /* endTimestamp= */ 1, "C1", /* eventTimestamp= */ 5), + createEvent(6, "E6"), + createEvent(7, "E7"), + createCompactedEvent( + /* startTimestamp= */ 1, /* endTimestamp= */ 3, "C2", /* eventTimestamp= */ 8), + createEvent(9, "E9")); + Session session = Session.builder("id").events(events).build(); + Event compactedEvent = createCompactedEvent(1, 4, "Summary C2-E4", 10); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); + when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); + + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 3); + + compactor.compact(session, mockSessionService).test().assertComplete(); + + verify(mockSummarizer).summarizeEvents(eventListCaptor.capture()); + List summarizedEvents = eventListCaptor.getValue(); + + assertThat(summarizedEvents).hasSize(2); + + // Check first event is reconstructed C2 + Event reconstructedC2 = summarizedEvents.get(0); + assertThat(getPromptText(reconstructedC2)).isEqualTo("C2"); + // Verify timestamp is reset to startTimestamp (1), not event timestamp (8) or end timestamp (3) + assertThat(reconstructedC2.timestamp()).isEqualTo(1); + + // Check second event is E4 + Event e4 = summarizedEvents.get(1); + assertThat(e4.timestamp()).isEqualTo(4); + } + + private static Event createEvent(long timestamp, String text) { + return Event.builder() + .timestamp(timestamp) + .content(Content.builder().parts(Part.fromText(text)).build()) + .build(); + } + + private static String getPromptText(Event event) { + return event + .content() + .flatMap(Content::parts) + .flatMap(parts -> parts.stream().findFirst()) + .flatMap(Part::text) + .orElseThrow(); + } + + private Event createCompactedEvent( + long startTimestamp, long endTimestamp, String content, long eventTimestamp) { + return Event.builder() + .timestamp(eventTimestamp) + .actions( + EventActions.builder() + .compaction( + EventCompaction.builder() + .startTimestamp(startTimestamp) + .endTimestamp(endTimestamp) + .compactedContent( + Content.builder() + .role("model") + .parts(Part.builder().text(content).build()) + .build()) + .build()) + .build()) + .build(); + } +} From 0d0d514e72fe3643253bbc088fe59fed3841e9b4 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 3 Feb 2026 08:06:53 -0800 Subject: [PATCH 17/63] refactor: refactors parts of the ADK codebase to improve null safety and consistency This CL refactors parts of the ADK codebase to improve null safety and consistency. The main changes include: 1. **`BaseAgent`**: * `beforeAgentCallback` and `afterAgentCallback` fields and their accessors now use `ImmutableList` (defaulting to empty) instead of `Optional`. * `findAgent` and `findSubAgent` now return `Optional`, with `findSubAgent` being reimplemented using Java Streams. 2. **`BaseAgentConfig`**: Getters for `subAgents`, `beforeAgentCallbacks`, and `afterAgentCallbacks` now return an empty list if the underlying field is null. 3. **`CallbackUtil`**: `getBeforeAgentCallbacks` and `getAfterAgentCallbacks` return `ImmutableList.of()` instead of `null` for null inputs. 4. **`LlmAgent`**: The `codeExecutor()` method now returns `Optional`. These changes necessitate updates in `BaseLlmFlow`, `CodeExecution`, and `Runner` to handle the new `Optional` return types. PiperOrigin-RevId: 864883202 --- .../java/com/google/adk/agents/BaseAgent.java | 63 +++++++------ .../google/adk/agents/BaseAgentConfig.java | 25 +++-- .../com/google/adk/agents/CallbackUtil.java | 92 +++++++++---------- .../java/com/google/adk/agents/LlmAgent.java | 5 +- .../adk/flows/llmflows/BaseLlmFlow.java | 12 +-- .../adk/flows/llmflows/CodeExecution.java | 55 ++++++----- .../java/com/google/adk/runner/Runner.java | 8 +- .../com/google/adk/agents/BaseAgentTest.java | 8 +- .../adk/agents/ConfigAgentUtilsTest.java | 6 +- 9 files changed, 143 insertions(+), 131 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 948d5ebac..646072537 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -57,10 +57,10 @@ public abstract class BaseAgent { */ private BaseAgent parentAgent; - private final List subAgents; + private final ImmutableList subAgents; - private final Optional> beforeAgentCallback; - private final Optional> afterAgentCallback; + private final ImmutableList beforeAgentCallback; + private final ImmutableList afterAgentCallback; /** * Creates a new BaseAgent. @@ -82,9 +82,13 @@ public BaseAgent( this.name = name; this.description = description; this.parentAgent = null; - this.subAgents = subAgents != null ? subAgents : ImmutableList.of(); - this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback); - this.afterAgentCallback = Optional.ofNullable(afterAgentCallback); + this.subAgents = subAgents == null ? ImmutableList.of() : ImmutableList.copyOf(subAgents); + this.beforeAgentCallback = + beforeAgentCallback == null + ? ImmutableList.of() + : ImmutableList.copyOf(beforeAgentCallback); + this.afterAgentCallback = + afterAgentCallback == null ? ImmutableList.of() : ImmutableList.copyOf(afterAgentCallback); // Establish parent relationships for all sub-agents if needed. for (BaseAgent subAgent : this.subAgents) { @@ -144,38 +148,38 @@ public BaseAgent rootAgent() { /** * Finds an agent (this or descendant) by name. * - * @return the agent or descendant with the given name, or {@code null} if not found. + * @return an {@link Optional} containing the agent or descendant with the given name, or {@link + * Optional#empty()} if not found. */ - public BaseAgent findAgent(String name) { + public Optional findAgent(String name) { if (this.name().equals(name)) { - return this; + return Optional.of(this); } return findSubAgent(name); } - /** Recursively search sub agent by name. */ - public @Nullable BaseAgent findSubAgent(String name) { - for (BaseAgent subAgent : subAgents) { - if (subAgent.name().equals(name)) { - return subAgent; - } - BaseAgent result = subAgent.findSubAgent(name); - if (result != null) { - return result; - } - } - return null; + /** + * Recursively search sub agent by name. + * + * @return an {@link Optional} containing the sub agent with the given name, or {@link + * Optional#empty()} if not found. + */ + public Optional findSubAgent(String name) { + return subAgents.stream() + .map(subAgent -> subAgent.findAgent(name)) + .flatMap(Optional::stream) + .findFirst(); } public List subAgents() { return subAgents; } - public Optional> beforeAgentCallback() { + public ImmutableList beforeAgentCallback() { return beforeAgentCallback; } - public Optional> afterAgentCallback() { + public ImmutableList afterAgentCallback() { return afterAgentCallback; } @@ -184,8 +188,8 @@ public Optional> afterAgentCallback() { * *

This method is only for use by Agent Development Kit. */ - public List canonicalBeforeAgentCallbacks() { - return beforeAgentCallback.orElse(ImmutableList.of()); + public ImmutableList canonicalBeforeAgentCallbacks() { + return beforeAgentCallback; } /** @@ -193,8 +197,8 @@ public List canonicalBeforeAgentCallbacks() { * *

This method is only for use by Agent Development Kit. */ - public List canonicalAfterAgentCallbacks() { - return afterAgentCallback.orElse(ImmutableList.of()); + public ImmutableList canonicalAfterAgentCallbacks() { + return afterAgentCallback; } /** @@ -239,8 +243,7 @@ public Flowable runAsync(InvocationContext parentContext) { () -> callCallback( beforeCallbacksToFunctions( - invocationContext.pluginManager(), - beforeAgentCallback.orElse(ImmutableList.of())), + invocationContext.pluginManager(), beforeAgentCallback), invocationContext) .flatMapPublisher( beforeEventOpt -> { @@ -257,7 +260,7 @@ public Flowable runAsync(InvocationContext parentContext) { callCallback( afterCallbacksToFunctions( invocationContext.pluginManager(), - afterAgentCallback.orElse(ImmutableList.of())), + afterAgentCallback), invocationContext) .flatMapPublisher(Flowable::fromOptional)); diff --git a/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java b/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java index e38895afb..40ed58937 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java @@ -16,6 +16,7 @@ package com.google.adk.agents; +import com.google.common.collect.ImmutableList; import java.util.List; /** @@ -27,11 +28,11 @@ public class BaseAgentConfig { private String name; private String description = ""; private String agentClass; - private List subAgents; + private ImmutableList subAgents = ImmutableList.of(); // Callback configuration (names resolved via ComponentRegistry) - private List beforeAgentCallbacks; - private List afterAgentCallbacks; + private ImmutableList beforeAgentCallbacks = ImmutableList.of(); + private ImmutableList afterAgentCallbacks = ImmutableList.of(); /** Reference to a callback stored in the ComponentRegistry. */ public static class CallbackRef { @@ -131,27 +132,33 @@ public String agentClass() { return agentClass; } - public List subAgents() { + public ImmutableList subAgents() { return subAgents; } public void setSubAgents(List subAgents) { - this.subAgents = subAgents; + this.subAgents = subAgents == null ? ImmutableList.of() : ImmutableList.copyOf(subAgents); } - public List beforeAgentCallbacks() { + public ImmutableList beforeAgentCallbacks() { return beforeAgentCallbacks; } public void setBeforeAgentCallbacks(List beforeAgentCallbacks) { - this.beforeAgentCallbacks = beforeAgentCallbacks; + this.beforeAgentCallbacks = + beforeAgentCallbacks == null + ? ImmutableList.of() + : ImmutableList.copyOf(beforeAgentCallbacks); } - public List afterAgentCallbacks() { + public ImmutableList afterAgentCallbacks() { return afterAgentCallbacks; } public void setAfterAgentCallbacks(List afterAgentCallbacks) { - this.afterAgentCallbacks = afterAgentCallbacks; + this.afterAgentCallbacks = + afterAgentCallbacks == null + ? ImmutableList.of() + : ImmutableList.copyOf(afterAgentCallbacks); } } diff --git a/core/src/main/java/com/google/adk/agents/CallbackUtil.java b/core/src/main/java/com/google/adk/agents/CallbackUtil.java index 728fd1d5a..11740ae9c 100644 --- a/core/src/main/java/com/google/adk/agents/CallbackUtil.java +++ b/core/src/main/java/com/google/adk/agents/CallbackUtil.java @@ -26,7 +26,8 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Maybe; import java.util.List; -import org.jspecify.annotations.Nullable; +import java.util.function.Function; +import java.util.stream.Stream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,65 +38,62 @@ public final class CallbackUtil { /** * Normalizes before-agent callbacks. * - * @param beforeAgentCallback Callback list (sync or async). - * @return normalized async callbacks, or null if input is null. + * @param beforeAgentCallbacks Callback list (sync or async). + * @return normalized async callbacks, or empty list if input is null. */ @CanIgnoreReturnValue - public static @Nullable ImmutableList getBeforeAgentCallbacks( - List beforeAgentCallback) { - if (beforeAgentCallback == null) { - return null; - } else if (beforeAgentCallback.isEmpty()) { - return ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (BeforeAgentCallbackBase callback : beforeAgentCallback) { - if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) { - builder.add(beforeAgentCallbackInstance); - } else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) { - builder.add( - (callbackContext) -> - Maybe.fromOptional(beforeAgentCallbackSyncInstance.call(callbackContext))); - } else { - logger.warn( - "Invalid beforeAgentCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - } - return builder.build(); - } + public static ImmutableList getBeforeAgentCallbacks( + List beforeAgentCallbacks) { + return getCallbacks( + beforeAgentCallbacks, + BeforeAgentCallback.class, + BeforeAgentCallbackSync.class, + sync -> (callbackContext -> Maybe.fromOptional(sync.call(callbackContext))), + "beforeAgentCallbacks"); } /** * Normalizes after-agent callbacks. * * @param afterAgentCallback Callback list (sync or async). - * @return normalized async callbacks, or null if input is null. + * @return normalized async callbacks, or empty list if input is null. */ @CanIgnoreReturnValue - public static @Nullable ImmutableList getAfterAgentCallbacks( + public static ImmutableList getAfterAgentCallbacks( List afterAgentCallback) { - if (afterAgentCallback == null) { - return null; - } else if (afterAgentCallback.isEmpty()) { + return getCallbacks( + afterAgentCallback, + AfterAgentCallback.class, + AfterAgentCallbackSync.class, + sync -> (callbackContext -> Maybe.fromOptional(sync.call(callbackContext))), + "afterAgentCallback"); + } + + private static ImmutableList getCallbacks( + List callbacks, + Class asyncClass, + Class syncClass, + Function converter, + String callbackTypeForLogging) { + if (callbacks == null) { return ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (AfterAgentCallbackBase callback : afterAgentCallback) { - if (callback instanceof AfterAgentCallback afterAgentCallbackInstance) { - builder.add(afterAgentCallbackInstance); - } else if (callback instanceof AfterAgentCallbackSync afterAgentCallbackSyncInstance) { - builder.add( - (callbackContext) -> - Maybe.fromOptional(afterAgentCallbackSyncInstance.call(callbackContext))); - } else { - logger.warn( - "Invalid afterAgentCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - } - return builder.build(); } + return callbacks.stream() + .flatMap( + callback -> { + if (asyncClass.isInstance(callback)) { + return Stream.of(asyncClass.cast(callback)); + } else if (syncClass.isInstance(callback)) { + return Stream.of(converter.apply(syncClass.cast(callback))); + } else { + logger.warn( + "Invalid {} callback type: {}. Ignoring this callback.", + callbackTypeForLogging, + callback.getClass().getName()); + return Stream.empty(); + } + }) + .collect(ImmutableList.toImmutableList()); } private CallbackUtil() {} diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 1f16d7c00..87967bb6d 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -935,9 +935,8 @@ public Optional outputKey() { return outputKey; } - @Nullable - public BaseCodeExecutor codeExecutor() { - return codeExecutor.orElse(null); + public Optional codeExecutor() { + return codeExecutor; } public Model resolvedModel() { diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 46b3f1952..cfbadb9fe 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -388,15 +388,15 @@ private Flowable runOneStep(InvocationContext context) { String agentToTransfer = event.actions().transferToAgent().get(); logger.debug("Transferring to agent: {}", agentToTransfer); BaseAgent rootAgent = context.agent().rootAgent(); - BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer); - if (nextAgent == null) { + Optional nextAgent = rootAgent.findAgent(agentToTransfer); + if (nextAgent.isEmpty()) { String errorMsg = "Agent not found for transfer: " + agentToTransfer; logger.error(errorMsg); return postProcessedEvents.concatWith( Flowable.error(new IllegalStateException(errorMsg))); } return postProcessedEvents.concatWith( - Flowable.defer(() -> nextAgent.runAsync(context))); + Flowable.defer(() -> nextAgent.get().runAsync(context))); } return postProcessedEvents; }); @@ -574,14 +574,14 @@ public void onError(Throwable e) { Flowable events = Flowable.just(event); if (event.actions().transferToAgent().isPresent()) { BaseAgent rootAgent = invocationContext.agent().rootAgent(); - BaseAgent nextAgent = + Optional nextAgent = rootAgent.findAgent(event.actions().transferToAgent().get()); - if (nextAgent == null) { + if (nextAgent.isEmpty()) { throw new IllegalStateException( "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = - nextAgent.runLive(invocationContext); + nextAgent.get().runLive(invocationContext); events = Flowable.concat(events, nextAgentEvents); } return events; diff --git a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java index 64d95cef3..bb1789609 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java @@ -22,7 +22,6 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; -import com.google.adk.codeexecutors.BaseCodeExecutor; import com.google.adk.codeexecutors.BuiltInCodeExecutor; import com.google.adk.codeexecutors.CodeExecutionUtils; import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionInput; @@ -108,12 +107,12 @@ private static class CodeExecutionRequestProcessor implements RequestProcessor { public Single processRequest( InvocationContext invocationContext, LlmRequest llmRequest) { if (!(invocationContext.agent() instanceof LlmAgent llmAgent) - || llmAgent.codeExecutor() == null) { + || llmAgent.codeExecutor().isEmpty()) { return Single.just( RequestProcessor.RequestProcessingResult.create(llmRequest, ImmutableList.of())); } - if (llmAgent.codeExecutor() instanceof BuiltInCodeExecutor builtInCodeExecutor) { + if (llmAgent.codeExecutor().get() instanceof BuiltInCodeExecutor builtInCodeExecutor) { var llmRequestBuilder = llmRequest.toBuilder(); builtInCodeExecutor.processLlmRequest(llmRequestBuilder); LlmRequest updatedLlmRequest = llmRequestBuilder.build(); @@ -124,21 +123,27 @@ public Single processRequest( Flowable preprocessorEvents = runPreProcessor(invocationContext, llmRequest); // Convert the code execution parts to text parts. - if (llmAgent.codeExecutor() != null) { - BaseCodeExecutor baseCodeExecutor = llmAgent.codeExecutor(); - List updatedContents = new ArrayList<>(); - for (Content content : llmRequest.contents()) { - List delimiters = - !baseCodeExecutor.codeBlockDelimiters().isEmpty() - ? baseCodeExecutor.codeBlockDelimiters().get(0) - : ImmutableList.of("", ""); - updatedContents.add( - CodeExecutionUtils.convertCodeExecutionParts( - content, delimiters, baseCodeExecutor.executionResultDelimiters())); - } - llmRequest = llmRequest.toBuilder().contents(updatedContents).build(); - } - final LlmRequest finalLlmRequest = llmRequest; + final LlmRequest finalLlmRequest = + llmAgent + .codeExecutor() + .map( + baseCodeExecutor -> { + List delimiters = + !baseCodeExecutor.codeBlockDelimiters().isEmpty() + ? baseCodeExecutor.codeBlockDelimiters().get(0) + : ImmutableList.of("", ""); + ImmutableList updatedContents = + llmRequest.contents().stream() + .map( + content -> + CodeExecutionUtils.convertCodeExecutionParts( + content, + delimiters, + baseCodeExecutor.executionResultDelimiters())) + .collect(toImmutableList()); + return llmRequest.toBuilder().contents(updatedContents).build(); + }) + .orElse(llmRequest); return preprocessorEvents .toList() .map( @@ -173,10 +178,11 @@ private static Flowable runPreProcessor( return Flowable.empty(); } - var codeExecutor = llmAgent.codeExecutor(); - if (codeExecutor == null) { + var codeExecutorOptional = llmAgent.codeExecutor(); + if (codeExecutorOptional.isEmpty()) { return Flowable.empty(); } + var codeExecutor = codeExecutorOptional.get(); if (codeExecutor instanceof BuiltInCodeExecutor) { return Flowable.empty(); @@ -268,10 +274,11 @@ private static Flowable runPostProcessor( if (!(invocationContext.agent() instanceof LlmAgent llmAgent)) { return Flowable.empty(); } - var codeExecutor = llmAgent.codeExecutor(); - if (codeExecutor == null) { + var codeExecutorOptional = llmAgent.codeExecutor(); + if (codeExecutorOptional.isEmpty()) { return Flowable.empty(); } + var codeExecutor = codeExecutorOptional.get(); if (llmResponse.content().isEmpty()) { return Flowable.empty(); } @@ -387,8 +394,8 @@ private static List extractAndReplaceInlineFiles( private static Optional getOrSetExecutionId( InvocationContext invocationContext, CodeExecutorContext codeExecutorContext) { if (!(invocationContext.agent() instanceof LlmAgent llmAgent) - || llmAgent.codeExecutor() == null - || !llmAgent.codeExecutor().stateful()) { + || llmAgent.codeExecutor().isEmpty() + || !llmAgent.codeExecutor().get().stateful()) { return Optional.empty(); } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 51696f1af..5c275ab56 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -768,14 +768,14 @@ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) { return rootAgent; } - BaseAgent agent = rootAgent.findSubAgent(author); + Optional agent = rootAgent.findSubAgent(author); - if (agent == null) { + if (agent.isEmpty()) { continue; } - if (this.isTransferableAcrossAgentTree(agent)) { - return agent; + if (this.isTransferableAcrossAgentTree(agent.get())) { + return agent.get(); } } diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 345436826..2ae53d0e1 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -59,10 +59,10 @@ public void findAgent_returnsCorrectAgent() { TestBaseAgent agent = new TestBaseAgent( TEST_AGENT_NAME, TEST_AGENT_DESCRIPTION, null, ImmutableList.of(subAgent), null, null); - assertThat(agent.findAgent("subSubAgent")).isEqualTo(subSubAgent); - assertThat(agent.findAgent("subAgent")).isEqualTo(subAgent); - assertThat(agent.findAgent(TEST_AGENT_NAME)).isEqualTo(agent); - assertThat(agent.findAgent("nonExistent")).isNull(); + assertThat(agent.findAgent("subSubAgent")).hasValue(subSubAgent); + assertThat(agent.findAgent("subAgent")).hasValue(subAgent); + assertThat(agent.findAgent(TEST_AGENT_NAME)).hasValue(agent); + assertThat(agent.findAgent("nonExistent")).isEmpty(); } @Test diff --git a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java index 11e07a094..4f6ea6104 100644 --- a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java +++ b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java @@ -1209,10 +1209,8 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks() assertThat(agent).isInstanceOf(LlmAgent.class); LlmAgent llm = (LlmAgent) agent; - assertThat(agent.beforeAgentCallback()).isPresent(); - assertThat(agent.beforeAgentCallback().get()).hasSize(2); - assertThat(agent.afterAgentCallback()).isPresent(); - assertThat(agent.afterAgentCallback().get()).hasSize(1); + assertThat(agent.beforeAgentCallback()).hasSize(2); + assertThat(agent.afterAgentCallback()).hasSize(1); assertThat(llm.beforeModelCallback()).hasSize(1); assertThat(llm.afterModelCallback()).hasSize(1); From 3c4420aecddc6f191bd64ab732e5bcadda6e82c3 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 3 Feb 2026 10:26:12 -0800 Subject: [PATCH 18/63] refactor: BaseAgent: Apply Java style, fix a unit test, add a unit test PiperOrigin-RevId: 864938809 --- .../java/com/google/adk/agents/BaseAgent.java | 2 +- .../com/google/adk/agents/BaseAgentTest.java | 20 +++++++++ .../com/google/adk/testing/TestCallback.java | 41 +++++++++++++------ 3 files changed, 50 insertions(+), 13 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 646072537..255d59c4d 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -76,7 +76,7 @@ public abstract class BaseAgent { public BaseAgent( String name, String description, - List subAgents, + @Nullable List subAgents, @Nullable List beforeAgentCallback, @Nullable List afterAgentCallback) { this.name = name; diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 2ae53d0e1..dec66a77c 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -316,4 +316,24 @@ public void canonicalCallbacks_returnsListWhenPresent() { assertThat(agent.canonicalBeforeAgentCallbacks()).containsExactly(bc); assertThat(agent.canonicalAfterAgentCallbacks()).containsExactly(ac); } + + @Test + public void runLive_invokesRunLiveImpl() { + var runLiveCallback = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("live_output")); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + /* beforeAgentCallbacks= */ ImmutableList.of(), + /* afterAgentCallbacks= */ ImmutableList.of(), + runLiveCallback.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + assertThat(runLiveCallback.wasCalled()).isTrue(); + } } diff --git a/core/src/test/java/com/google/adk/testing/TestCallback.java b/core/src/test/java/com/google/adk/testing/TestCallback.java index 04f83ed9b..434d85e6f 100644 --- a/core/src/test/java/com/google/adk/testing/TestCallback.java +++ b/core/src/test/java/com/google/adk/testing/TestCallback.java @@ -102,63 +102,80 @@ public Supplier> asRunAsyncImplSupplier(String contentText) { return asRunAsyncImplSupplier(Content.fromParts(Part.fromText(contentText))); } + /** + * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} + * with an event containing the given content. + */ + public Supplier> asRunLiveImplSupplier(Content content) { + return () -> + Flowable.defer( + () -> { + markAsCalled(); + return Flowable.just(Event.builder().content(content).build()); + }); + } + @SuppressWarnings("unchecked") // This cast is safe if T is Content. public BeforeAgentCallback asBeforeAgentCallback() { - return ctx -> (Maybe) callMaybe(); + return (unusedCtx) -> (Maybe) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is Content. public BeforeAgentCallbackSync asBeforeAgentCallbackSync() { - return ctx -> (Optional) callOptional(); + return (unusedCtx) -> (Optional) callOptional(); } @SuppressWarnings("unchecked") // This cast is safe if T is Content. public AfterAgentCallback asAfterAgentCallback() { - return ctx -> (Maybe) callMaybe(); + return (unusedCtx) -> (Maybe) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is Content. public AfterAgentCallbackSync asAfterAgentCallbackSync() { - return ctx -> (Optional) callOptional(); + return (unusedCtx) -> (Optional) callOptional(); } @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. public BeforeModelCallback asBeforeModelCallback() { - return (ctx, req) -> (Maybe) callMaybe(); + return (unusedCtx, unusedReq) -> (Maybe) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. public BeforeModelCallbackSync asBeforeModelCallbackSync() { - return (ctx, req) -> (Optional) callOptional(); + return (unusedCtx, unusedReq) -> (Optional) callOptional(); } @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. public AfterModelCallback asAfterModelCallback() { - return (ctx, res) -> (Maybe) callMaybe(); + return (unusedCtx, unusedRes) -> (Maybe) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. public AfterModelCallbackSync asAfterModelCallbackSync() { - return (ctx, res) -> (Optional) callOptional(); + return (unusedCtx, unusedRes) -> (Optional) callOptional(); } @SuppressWarnings("unchecked") // This cast is safe if T is Map. public BeforeToolCallback asBeforeToolCallback() { - return (invCtx, tool, toolArgs, toolCtx) -> (Maybe>) callMaybe(); + return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx) -> + (Maybe>) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is Map. public BeforeToolCallbackSync asBeforeToolCallbackSync() { - return (invCtx, tool, toolArgs, toolCtx) -> (Optional>) callOptional(); + return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx) -> + (Optional>) callOptional(); } @SuppressWarnings("unchecked") // This cast is safe if T is Map. public AfterToolCallback asAfterToolCallback() { - return (invCtx, tool, toolArgs, toolCtx, res) -> (Maybe>) callMaybe(); + return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx, unusedRes) -> + (Maybe>) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is Map. public AfterToolCallbackSync asAfterToolCallbackSync() { - return (invCtx, tool, toolArgs, toolCtx, res) -> (Optional>) callOptional(); + return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx, unusedRes) -> + (Optional>) callOptional(); } } From 9901307b1cb9be75f2262f116388f93cdcf3eeb6 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 3 Feb 2026 11:44:19 -0800 Subject: [PATCH 19/63] feat: Add token usage threshold to TailRetentionEventCompactor PiperOrigin-RevId: 864976785 --- .../TailRetentionEventCompactor.java | 45 +++++++- .../TailRetentionEventCompactorTest.java | 108 ++++++++++++++++-- 2 files changed, 139 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java index c13a49cc1..b084de860 100644 --- a/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java +++ b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java @@ -22,12 +22,15 @@ import com.google.adk.events.EventCompaction; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; +import com.google.common.collect.Lists; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.ListIterator; +import java.util.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,6 +41,7 @@ *

  • Keeps the {@code retentionSize} most recent events raw. *
  • Compacts all events that never compacted and older than the retained tail, including the * most recent compaction event, into a new summary event. + *
  • Triggers compaction only if the prompt token count exceeds the {@code tokenThreshold}. *
  • The new summary event is generated by the {@link BaseEventSummarizer}. *
  • Appends this new summary event to the end of the event stream. * @@ -52,10 +56,15 @@ public final class TailRetentionEventCompactor implements EventCompactor { private final BaseEventSummarizer summarizer; private final int retentionSize; + private final int tokenThreshold; - public TailRetentionEventCompactor(BaseEventSummarizer summarizer, int retentionSize) { + public TailRetentionEventCompactor( + BaseEventSummarizer summarizer, int retentionSize, int tokenThreshold) { + checkArgument(tokenThreshold >= 0, "tokenThreshold must be non-negative"); + checkArgument(retentionSize >= 0, "retentionSize must be non-negative"); this.summarizer = summarizer; this.retentionSize = retentionSize; + this.tokenThreshold = tokenThreshold; } @Override @@ -63,10 +72,36 @@ public Completable compact(Session session, BaseSessionService sessionService) { checkArgument(summarizer != null, "Missing BaseEventSummarizer for event compaction"); logger.debug("Running tail retention event compaction for session {}", session.id()); - return Completable.fromMaybe( - getCompactionEvents(session.events()) - .flatMap(summarizer::summarizeEvents) - .flatMapSingle(e -> sessionService.appendEvent(session, e))); + return Maybe.just(session.events()) + .filter(this::shouldCompact) + .flatMap(events -> getCompactionEvents(events)) + .flatMap(summarizer::summarizeEvents) + .flatMapSingle(e -> sessionService.appendEvent(session, e)) + .ignoreElement(); + } + + private boolean shouldCompact(List events) { + int count = getLatestPromptTokenCount(events).orElse(0); + + // TODO b/480013930 - Add a way to estimate the prompt token if the usage metadata is not + // available. + if (count <= tokenThreshold) { + logger.debug( + "Skipping compaction. Prompt token count {} is within threshold {}", + count, + tokenThreshold); + return false; + } + return true; + } + + private Optional getLatestPromptTokenCount(List events) { + return Lists.reverse(events).stream() + .map(Event::usageMetadata) + .flatMap(Optional::stream) + .map(GenerateContentResponseUsageMetadata::promptTokenCount) + .flatMap(Optional::stream) + .findFirst(); } /** diff --git a/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java index b4a6c3474..3260fbe1e 100644 --- a/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java +++ b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java @@ -17,6 +17,7 @@ package com.google.adk.summarizer; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.never; @@ -30,6 +31,7 @@ import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; @@ -52,15 +54,91 @@ public class TailRetentionEventCompactorTest { @Mock private BaseEventSummarizer mockSummarizer; @Captor private ArgumentCaptor> eventListCaptor; + @Test + public void constructor_negativeTokenThreshold_throwsException() { + assertThat( + assertThrows( + IllegalArgumentException.class, + () -> new TailRetentionEventCompactor(mockSummarizer, 2, -1))) + .hasMessageThat() + .contains("tokenThreshold must be non-negative"); + } + + @Test + public void constructor_negativeRetentionSize_throwsException() { + assertThat( + assertThrows( + IllegalArgumentException.class, + () -> new TailRetentionEventCompactor(mockSummarizer, -1, 100))) + .hasMessageThat() + .contains("retentionSize must be non-negative"); + } + + @Test + // TODO: b/480013930 - Add a test case for estimating the prompt token if the usage metadata is + // not available. + public void compaction_skippedWhenTokenUsageMissing() { + EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100); + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + createEvent(3, "Retain2")); // No usage metadata + Session session = Session.builder("id").events(events).build(); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer, never()).summarizeEvents(any()); + verify(mockSessionService, never()).appendEvent(any(), any()); + } + + @Test + public void compaction_skippedWhenTokenUsageBelowThreshold() { + // Threshold is 300, usage is 200. + EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 300); + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + withUsage(createEvent(3, "Retain2"), 200)); + Session session = Session.builder("id").events(events).build(); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer, never()).summarizeEvents(any()); + verify(mockSessionService, never()).appendEvent(any(), any()); + } + + @Test + public void compaction_happensWhenTokenUsageAboveThreshold() { + // Threshold is 300, usage is 400. + EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 300); + Event event3 = withUsage(createEvent(3, "Retain2"), 400); + ImmutableList events = + ImmutableList.of(createEvent(1, "Event1"), createEvent(2, "Retain1"), event3); + Session session = Session.builder("id").events(events).build(); + Event summaryEvent = createEvent(4, "Summary"); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(summaryEvent)); + when(mockSessionService.appendEvent(any(), any())).thenReturn(Single.just(summaryEvent)); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer).summarizeEvents(any()); + verify(mockSessionService).appendEvent(eq(session), eq(summaryEvent)); + } + @Test public void compact_notEnoughEvents_doesNothing() { ImmutableList events = ImmutableList.of( - createEvent(1, "Event1"), createEvent(2, "Event2"), createEvent(3, "Event3")); + createEvent(1, "Event1"), + createEvent(2, "Event2"), + withUsage(createEvent(3, "Event3"), 200)); Session session = Session.builder("id").events(events).build(); - // Retention size 5 > 3 events - TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 5); + // Retention size 5 > 3 events. Token usage 200 > threshold 100. + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 5, 100); compactor.compact(session, mockSessionService).test().assertComplete(); @@ -73,14 +151,17 @@ public void compact_respectRetentionSize_summarizesCorrectEvents() { // Retention size is 2. ImmutableList events = ImmutableList.of( - createEvent(1, "Event1"), createEvent(2, "Retain1"), createEvent(3, "Retain2")); + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + withUsage(createEvent(3, "Retain2"), 200)); Session session = Session.builder("id").events(events).build(); Event compactedEvent = createCompactedEvent(1, 1, "Summary", 4); when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); - TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2); + // Token usage 200 > threshold 100. + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100); compactor.compact(session, mockSessionService).test().assertComplete(); @@ -121,14 +202,15 @@ public void compact_withRetainedEventsPhysicallyBeforeCompaction_includesThem() createCompactedEvent( /* startTimestamp= */ 1, /* endTimestamp= */ 2, "C1", /* eventTimestamp= */ 4), createEvent(5, "E5"), - createEvent(6, "E6")); + withUsage(createEvent(6, "E6"), 200)); Session session = Session.builder("id").events(events).build(); Event compactedEvent = createCompactedEvent(1, 5, "Summary C1-E5", 7); when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); - TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 1); + // Token usage 200 > threshold 100. + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 1, 100); compactor.compact(session, mockSessionService).test().assertComplete(); @@ -181,14 +263,15 @@ public void compact_withMultipleCompactionEvents_respectsCompactionBoundary() { createEvent(7, "E7"), createCompactedEvent( /* startTimestamp= */ 1, /* endTimestamp= */ 3, "C2", /* eventTimestamp= */ 8), - createEvent(9, "E9")); + withUsage(createEvent(9, "E9"), 200)); Session session = Session.builder("id").events(events).build(); Event compactedEvent = createCompactedEvent(1, 4, "Summary C2-E4", 10); when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); - TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 3); + // Token usage 200 > threshold 100. + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 3, 100); compactor.compact(session, mockSessionService).test().assertComplete(); @@ -224,6 +307,13 @@ private static String getPromptText(Event event) { .orElseThrow(); } + private Event withUsage(Event event, int tokens) { + return event.toBuilder() + .usageMetadata( + GenerateContentResponseUsageMetadata.builder().promptTokenCount(tokens).build()) + .build(); + } + private Event createCompactedEvent( long startTimestamp, long endTimestamp, String content, long eventTimestamp) { return Event.builder() From af1fafed0470c8afe81679a495ed61664a2cee1a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 3 Feb 2026 15:13:16 -0800 Subject: [PATCH 20/63] feat: Add Compaction RequestProcessor for event compaction in llm flow PiperOrigin-RevId: 865067193 --- .../google/adk/flows/llmflows/Compaction.java | 59 +++++++ .../adk/flows/llmflows/CompactionTest.java | 157 ++++++++++++++++++ 2 files changed, 216 insertions(+) create mode 100644 core/src/main/java/com/google/adk/flows/llmflows/Compaction.java create mode 100644 core/src/test/java/com/google/adk/flows/llmflows/CompactionTest.java diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Compaction.java b/core/src/main/java/com/google/adk/flows/llmflows/Compaction.java new file mode 100644 index 000000000..6646f0ff7 --- /dev/null +++ b/core/src/main/java/com/google/adk/flows/llmflows/Compaction.java @@ -0,0 +1,59 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.flows.llmflows; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.summarizer.EventsCompactionConfig; +import com.google.adk.summarizer.TailRetentionEventCompactor; +import com.google.common.collect.ImmutableList; +import io.reactivex.rxjava3.core.Single; +import java.util.Optional; + +/** Request processor that performs event compaction. */ +public class Compaction implements RequestProcessor { + + @Override + public Single processRequest( + InvocationContext context, LlmRequest request) { + Optional configOpt = context.eventsCompactionConfig(); + + if (configOpt.isEmpty()) { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())); + } + + EventsCompactionConfig config = configOpt.get(); + + if (config.tokenThreshold() == null || config.eventRetentionSize() == null) { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())); + } + + // Extract out the retention size and token threshold from the new config. + int retentionSize = config.eventRetentionSize(); + int tokenThreshold = config.tokenThreshold(); + + // Summarizer will not be missing since the runner will always add a default one if missing. + TailRetentionEventCompactor compactor = + new TailRetentionEventCompactor(config.summarizer(), retentionSize, tokenThreshold); + + return compactor + .compact(context.session(), context.sessionService()) + .andThen( + Single.just( + RequestProcessor.RequestProcessingResult.create(request, ImmutableList.of()))); + } +} diff --git a/core/src/test/java/com/google/adk/flows/llmflows/CompactionTest.java b/core/src/test/java/com/google/adk/flows/llmflows/CompactionTest.java new file mode 100644 index 000000000..3ceba5641 --- /dev/null +++ b/core/src/test/java/com/google/adk/flows/llmflows/CompactionTest.java @@ -0,0 +1,157 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.flows.llmflows; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.Session; +import com.google.adk.summarizer.BaseEventSummarizer; +import com.google.adk.summarizer.EventsCompactionConfig; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CompactionTest { + + private InvocationContext context; + private LlmRequest request; + private Session session; + private BaseSessionService sessionService; + private BaseEventSummarizer summarizer; + + @Before + public void setUp() { + context = mock(InvocationContext.class); + request = LlmRequest.builder().build(); + session = Session.builder("test-session").build(); + sessionService = mock(BaseSessionService.class); + summarizer = mock(BaseEventSummarizer.class); + + when(context.session()).thenReturn(session); + when(context.sessionService()).thenReturn(sessionService); + } + + @Test + public void processRequest_noConfig_doesNothing() { + when(context.eventsCompactionConfig()).thenReturn(Optional.empty()); + + Compaction compaction = new Compaction(); + compaction + .processRequest(context, request) + .test() + .assertNoErrors() + .assertValue(r -> r.updatedRequest() == request); + + verify(sessionService, never()).appendEvent(any(), any()); + } + + @Test + public void processRequest_withConfig_triggersCompaction() { + // Setup config with threshold 100 + EventsCompactionConfig config = new EventsCompactionConfig(5, 1, summarizer, 100, 2); + when(context.eventsCompactionConfig()).thenReturn(Optional.of(config)); + + // Setup events with usage > 100 to trigger compaction + Event event1 = mock(Event.class); + Event event2 = mock(Event.class); + Event event3 = mock(Event.class); + when(event3.usageMetadata()) + .thenReturn( + Optional.of( + GenerateContentResponseUsageMetadata.builder().promptTokenCount(200).build())); + + session = + Session.builder("test-session").events(ImmutableList.of(event1, event2, event3)).build(); + when(context.session()).thenReturn(session); + + // Summarizer mock + Event summaryEvent = mock(Event.class); + when(summarizer.summarizeEvents(any())).thenReturn(Maybe.just(summaryEvent)); + when(sessionService.appendEvent(any(), any())).thenReturn(Single.just(summaryEvent)); + + Compaction compaction = new Compaction(); + compaction + .processRequest(context, request) + .test() + .assertNoErrors() + .assertValue(r -> r.updatedRequest() == request); + + // Verify compaction happened and result was appended + verify(sessionService).appendEvent(eq(session), eq(summaryEvent)); + } + + @Test + public void processRequest_withConfig_skipsCompactionIfBelowThreshold() { + // Setup config with threshold 500 + EventsCompactionConfig config = new EventsCompactionConfig(5, 1, summarizer, 500, 2); + when(context.eventsCompactionConfig()).thenReturn(Optional.of(config)); + + // Setup events with usage 200 (below 500) + Event event3 = mock(Event.class); + when(event3.usageMetadata()) + .thenReturn( + Optional.of( + GenerateContentResponseUsageMetadata.builder().promptTokenCount(200).build())); + + session = Session.builder("test-session").events(ImmutableList.of(event3)).build(); + when(context.session()).thenReturn(session); + + Compaction compaction = new Compaction(); + compaction + .processRequest(context, request) + .test() + .assertNoErrors() + .assertValue(r -> r.updatedRequest() == request); + + // Verify NO compaction + verify(sessionService, never()).appendEvent(any(), any()); + } + + @Test + public void processRequest_withConfig_nullRetentionSize_doesNothing() { + // Setup config with retentionSize = null + EventsCompactionConfig config = new EventsCompactionConfig(5, 1, summarizer, 100, null); + when(context.eventsCompactionConfig()).thenReturn(Optional.of(config)); + + Compaction compaction = new Compaction(); + compaction + .processRequest(context, request) + .test() + .assertNoErrors() + .assertValue(r -> r.updatedRequest() == request); + + // Verify NO compaction and session.events() is not called + verify(sessionService, never()).appendEvent(any(), any()); + verify(context, never()).session(); + } +} From 66e22964e67d0756e3351dae93e18aa5ae73f22e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 3 Feb 2026 16:09:33 -0800 Subject: [PATCH 21/63] feat: Reorder compaction events in chronological order PiperOrigin-RevId: 865089943 --- .../google/adk/flows/llmflows/Contents.java | 155 +++++++++++++----- .../adk/flows/llmflows/ContentsTest.java | 99 +++++++++++ 2 files changed, 217 insertions(+), 37 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 0c415f1a8..171dab972 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -38,7 +38,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.ListIterator; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -175,9 +174,18 @@ private boolean isEmptyContent(Event event) { /** * Filters events that are covered by compaction events by identifying compacted ranges and - * filters out events that are covered by compaction summaries + * filters out events that are covered by compaction summaries. Also filters out redundant + * compaction events (i.e., those fully covered by a later compaction event). * - *

    Example of input + *

    Compaction events are inserted into the stream relative to the events they cover. + * Specifically, a compaction event is placed immediately before the first retained event that + * follows the compaction range (or at the end of the covered range if no events are retained). + * This ensures a logical flow of "Summary of History" -> "Recent/Retained Events". + * + *

    Case 1: Sliding Window + Retention + * + *

    Compaction events have some overlap but do not fully cover each other. Therefore, all + * compaction events are preserved, as well as the final retained events. * *

        * [
    @@ -185,7 +193,7 @@ private boolean isEmptyContent(Event event) {
        *   event_2(timestamp=2),
        *   compaction_1(event_1, event_2, timestamp=3, content=summary_1_2, startTime=1, endTime=2),
        *   event_3(timestamp=4),
    -   *   compaction_2(event_2, event_3, timestamp=5, content=summary_2_3, startTime=2, endTime=3),
    +   *   compaction_2(event_2, event_3, timestamp=5, content=summary_2_3, startTime=2, endTime=4),
        *   event_4(timestamp=6)
        * ]
        * 
    @@ -200,50 +208,123 @@ private boolean isEmptyContent(Event event) { * ] * * - * Compaction events are always strictly in order based on event timestamp. + *

    Case 2: Rolling Summary + Retention + * + *

    The newer compaction event fully covers the older one. Therefore, the older compaction event + * is removed, leaving only the latest summary and the final retained events. + * + *

    +   * [
    +   *   event_1(timestamp=1),
    +   *   event_2(timestamp=2),
    +   *   event_3(timestamp=3),
    +   *   event_4(timestamp=4),
    +   *   compaction_1(event_1, timestamp=5, content=summary_1, startTime=1, endTime=1),
    +   *   event_6(timestamp=6),
    +   *   event_7(timestamp=7),
    +   *   compaction_2(compaction_1, event_2, event_3, timestamp=8, content=summary_1_3, startTime=1, endTime=3),
    +   *   event_9(timestamp=9)
    +   * ]
    +   * 
    + * + * Will result in the following events output + * + *
    +   * [
    +   *   compaction_2,
    +   *   event_4,
    +   *   event_6,
    +   *   event_7,
    +   *   event_9
    +   * ]
    +   * 
    * * @param events the list of event to filter. * @return a new list with compaction applied. */ private List processCompactionEvent(List events) { + // Step 1: Split events into compaction events and regular events. + List compactionEvents = new ArrayList<>(); + List regularEvents = new ArrayList<>(); + for (Event event : events) { + if (event.actions().compaction().isPresent()) { + compactionEvents.add(event); + } else { + regularEvents.add(event); + } + } + + // Step 2: Remove redundant compaction events (overlapping ones). + compactionEvents = removeOverlappingCompactions(compactionEvents); + + // Step 3: Merge regular events and compaction events based on timestamps. + // We iterate backwards from the latest to the earliest event. List result = new ArrayList<>(); - ListIterator iter = events.listIterator(events.size()); - Long lastCompactionStartTime = null; - Long lastCompactionEndTime = null; - - while (iter.hasPrevious()) { - Event event = iter.previous(); - EventCompaction compaction = event.actions().compaction().orElse(null); - if (compaction == null) { - if (lastCompactionStartTime == null - || event.timestamp() < lastCompactionStartTime - || (lastCompactionEndTime != null && event.timestamp() > lastCompactionEndTime)) { - result.add(event); - } - continue; + int c = compactionEvents.size() - 1; + int e = regularEvents.size() - 1; + while (e >= 0 && c >= 0) { + Event event = regularEvents.get(e); + EventCompaction compaction = compactionEvents.get(c).actions().compaction().get(); + + if (event.timestamp() >= compaction.startTimestamp() + && event.timestamp() <= compaction.endTimestamp()) { + // If the event is covered by compaction, skip it. + e--; + } else if (event.timestamp() > compaction.endTimestamp()) { + // If the event is after compaction, keep it. + result.add(event); + e--; + } else { + // Otherwise the event is before the compaction, let's move to the next compaction event; + result.add(createCompactionEvent(compactionEvents.get(c))); + c--; + } + } + // Flush any remaining compactions. + while (c >= 0) { + result.add(createCompactionEvent(compactionEvents.get(c))); + c--; + } + // Flush any remaining regular events. + while (e >= 0) { + result.add(regularEvents.get(e)); + e--; + } + return Lists.reverse(result); + } + + private static List removeOverlappingCompactions(List events) { + List result = new ArrayList<>(); + // Iterate backwards to prioritize later compactions + for (int i = events.size() - 1; i >= 0; i--) { + Event current = events.get(i); + EventCompaction c = current.actions().compaction().get(); + + // Check if this compaction is covered by the last compaction we've already kept. + boolean covered = false; + if (!result.isEmpty()) { + EventCompaction lastKept = Iterables.getLast(result).actions().compaction().get(); + covered = + c.startTimestamp() >= lastKept.startTimestamp() + && c.endTimestamp() <= lastKept.endTimestamp(); + } + + if (!covered) { + result.add(current); } - // Create a new event for the compaction event in the result. - result.add( - Event.builder() - .timestamp(compaction.endTimestamp()) - .author("model") - .content(compaction.compactedContent()) - .branch(event.branch()) - .invocationId(event.invocationId()) - .actions(event.actions()) - .build()); - lastCompactionStartTime = - lastCompactionStartTime == null - ? compaction.startTimestamp() - : Long.min(lastCompactionStartTime, compaction.startTimestamp()); - lastCompactionEndTime = - lastCompactionEndTime == null - ? compaction.endTimestamp() - : Long.max(lastCompactionEndTime, compaction.endTimestamp()); } return Lists.reverse(result); } + private static Event createCompactionEvent(Event event) { + EventCompaction compaction = event.actions().compaction().get(); + return event.toBuilder() + .timestamp(compaction.endTimestamp()) + .author("model") + .content(compaction.compactedContent()) + .build(); + } + /** Whether the event is a reply from another agent. */ private static boolean isOtherAgentReply(String agentName, Event event) { return !agentName.isEmpty() diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 85895088f..82a57ed4f 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -577,6 +577,105 @@ public void processRequest_compactionWithUncompactedEventsBetween() { .containsExactly("content 3", "Summary 1-2"); } + @Test + public void processRequest_rollingSummary_removesRedundancy() { + // Scenario: Rolling summary where a later summary covers a superset of the time range. + // Input: [E1(1), C1(Cover 1-1), E3(3), C2(Cover 1-3)] + // Expected: [C2] + // Explanation: C2 covers the range [1, 3], which includes the range covered by C1 [1, 1]. + // Therefore, C1 is redundant. E1 and E3 are also covered by C2. + ImmutableList events = + ImmutableList.of( + createUserEvent("e1", "E1", "inv1", 1), + createCompactedEvent(1, 1, "C1"), + createUserEvent("e3", "E3", "inv3", 3), + createCompactedEvent(1, 3, "C2")); + + List contents = runContentsProcessor(events); + assertThat(contents) + .comparingElementsUsing( + transforming((Content c) -> c.parts().get().get(0).text().get(), "content text")) + .containsExactly("C2"); + } + + @Test + public void processRequest_rollingSummaryWithRetention() { + // Input: with retention size 3: [E1, E2, E3, E4, C1(Cover 1-1), E6, E7, C2(Cover 1-3), E9] + // Expected: [C2, E4, E6, E7, E9] + ImmutableList events = + ImmutableList.of( + createUserEvent("e1", "E1", "inv1", 1), + createUserEvent("e2", "E2", "inv2", 2), + createUserEvent("e3", "E3", "inv3", 3), + createUserEvent("e4", "E4", "inv4", 4), + createCompactedEvent(1, 1, "C1"), + createUserEvent("e6", "E6", "inv6", 6), + createUserEvent("e7", "E7", "inv7", 7), + createCompactedEvent(1, 3, "C2"), + createUserEvent("e9", "E9", "inv9", 9)); + + List contents = runContentsProcessor(events); + assertThat(contents) + .comparingElementsUsing( + transforming((Content c) -> c.parts().get().get(0).text().get(), "content text")) + .containsExactly("C2", "E4", "E6", "E7", "E9"); + } + + @Test + public void processRequest_rollingSummary_preservesUncoveredHistory() { + // Input: [E1(1), E2(2), E3(3), E4(4), C1(2-2), E6(6), E7(7), C2(2-3), E9(9)] + // Expected: [E1, C2, E4, E6, E7, E9] + // E1 is before C1/C2 range, so it is preserved. + // C1 (2-2) is covered by C2 (2-3), so C1 is removed. + // E2, E3 are covered by C2. + // E4, E6, E7, E9 are retained. + ImmutableList events = + ImmutableList.of( + createUserEvent("e1", "E1", "inv1", 1), + createUserEvent("e2", "E2", "inv2", 2), + createUserEvent("e3", "E3", "inv3", 3), + createUserEvent("e4", "E4", "inv4", 4), + createCompactedEvent(2, 2, "C1"), + createUserEvent("e6", "E6", "inv6", 6), + createUserEvent("e7", "E7", "inv7", 7), + createCompactedEvent(2, 3, "C2"), + createUserEvent("e9", "E9", "inv9", 9)); + + List contents = runContentsProcessor(events); + assertThat(contents) + .comparingElementsUsing( + transforming((Content c) -> c.parts().get().get(0).text().get(), "content text")) + .containsExactly("E1", "C2", "E4", "E6", "E7", "E9"); + } + + @Test + public void processRequest_slidingWindow_preservesOverlappingCompactions() { + // Case 1: Sliding Window + Retention + // Input: [E1(1), E2(2), E3(3), C1(1-2), E4(5), C2(2-3), E5(7)] + // Overlap: C1 and C2 overlap at 2. C1 is NOT redundant (start 1 < start 2). + // Expected: [C1, C2, E4, E5] + // E1(1) covered by C1. + // E2(2) covered by C1 (and C2). + // E3(3) covered by C2. + // E4(5) retained. + // E5(7) retained. + ImmutableList events = + ImmutableList.of( + createUserEvent("e1", "E1", "inv1", 1), + createUserEvent("e2", "E2", "inv2", 2), + createUserEvent("e3", "E3", "inv3", 3), + createCompactedEvent(1, 2, "C1"), + createUserEvent("e4", "E4", "inv4", 5), + createCompactedEvent(2, 3, "C2"), + createUserEvent("e5", "E5", "inv5", 7)); + + List contents = runContentsProcessor(events); + assertThat(contents) + .comparingElementsUsing( + transforming((Content c) -> c.parts().get().get(0).text().get(), "content text")) + .containsExactly("C1", "C2", "E4", "E5"); + } + private static Event createUserEvent(String id, String text) { return Event.builder() .id(id) From 072767cd73b3472c66ad75242056d4a606fb279f Mon Sep 17 00:00:00 2001 From: Simon Su Date: Thu, 18 Dec 2025 19:39:37 +1100 Subject: [PATCH 22/63] Fix rearrangeEventsForAsyncFunctionResponsesInHistory to ensure function responses are merged --- .../google/adk/flows/llmflows/Contents.java | 3 +- .../adk/flows/llmflows/ContentsTest.java | 72 +++++++++++++++++-- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 171dab972..f45461626 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -564,8 +564,7 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( for (int i = 0; i < events.size(); i++) { Event event = events.get(i); - // Skip response events that will be processed via responseEventsBuffer - if (processedResponseIndices.contains(i)) { + if (!event.functionResponses().isEmpty()) { continue; } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 82a57ed4f..a8a862b51 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -203,9 +203,11 @@ public void rearrangeHistory_asyncFR_returnsRearrangedList() { public void rearrangeHistory_multipleFRsForSameFC_returnsMergedFR() { Event fcEvent = createFunctionCallEvent("fc1", "tool1", "call1"); Event frEvent1 = - createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "running")); + createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "pending")); Event frEvent2 = - createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "done")); + createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "running")); + Event frEvent3 = + createFunctionResponseEvent("fr3", "tool1", "call1", ImmutableMap.of("status", "done")); ImmutableList inputEvents = ImmutableList.of( createUserEvent("u1", "Query"), @@ -213,17 +215,75 @@ public void rearrangeHistory_multipleFRsForSameFC_returnsMergedFR() { createUserEvent("u2", "Wait"), frEvent1, createUserEvent("u3", "Done?"), - frEvent2); + frEvent2, + frEvent3, + createUserEvent("u4", "Follow up query")); List result = runContentsProcessor(inputEvents); - assertThat(result).hasSize(3); // u1, fc1, merged_fr + assertThat(result).hasSize(6); // u1, fc1, merged_fr, u2, u3, u4 assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); - assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // Check merged event + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // Check fcEvent Content mergedContent = result.get(2); assertThat(mergedContent.parts().get()).hasSize(1); assertThat(mergedContent.parts().get().get(0).functionResponse().get().response().get()) - .containsExactly("status", "done"); // Last FR wins + .containsExactly("status", "done"); // Last FR wins (frEvent3) + assertThat(result.get(3)).isEqualTo(inputEvents.get(2).content().get()); // u2 + assertThat(result.get(4)).isEqualTo(inputEvents.get(4).content().get()); // u3 + assertThat(result.get(5)).isEqualTo(inputEvents.get(7).content().get()); // u4 + } + + @Test + public void rearrangeHistory_multipleFRsForMultipleFC_returnsMergedFR() { + Event fcEvent1 = createFunctionCallEvent("fc1", "tool1", "call1"); + Event fcEvent2 = createFunctionCallEvent("fc2", "tool1", "call2"); + + Event frEvent1 = + createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "pending")); + Event frEvent2 = + createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "done")); + + Event frEvent3 = + createFunctionResponseEvent("fr3", "tool1", "call2", ImmutableMap.of("status", "pending")); + Event frEvent4 = + createFunctionResponseEvent("fr4", "tool1", "call2", ImmutableMap.of("status", "done")); + + ImmutableList inputEvents = + ImmutableList.of( + createUserEvent("u1", "I"), + fcEvent1, + createUserEvent("u2", "am"), + frEvent1, + createUserEvent("u3", "waiting"), + frEvent2, + createUserEvent("u4", "for"), + fcEvent2, + createUserEvent("u5", "you"), + frEvent3, + createUserEvent("u6", "to"), + frEvent4, + createUserEvent("u7", "Follow up query")); + + List result = runContentsProcessor(inputEvents); + + assertThat(result).hasSize(11); // u1, fc1, frEvent2, u2, u3, u4, fc2, frEvent4, u5, u6, u7 + assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); // u1 + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // fc1 + Content mergedContent = result.get(2); + assertThat(mergedContent.parts().get()).hasSize(1); + assertThat(mergedContent.parts().get().get(0).functionResponse().get().response().get()) + .containsExactly("status", "done"); // Last FR wins (frEvent2) + assertThat(result.get(3)).isEqualTo(inputEvents.get(2).content().get()); // u2 + assertThat(result.get(4)).isEqualTo(inputEvents.get(4).content().get()); // u3 + assertThat(result.get(5)).isEqualTo(inputEvents.get(6).content().get()); // u4 + assertThat(result.get(6)).isEqualTo(inputEvents.get(7).content().get()); // fc2 + Content mergedContent2 = result.get(7); + assertThat(mergedContent2.parts().get()).hasSize(1); + assertThat(mergedContent2.parts().get().get(0).functionResponse().get().response().get()) + .containsExactly("status", "done"); // Last FR wins (frEvent4) + assertThat(result.get(8)).isEqualTo(inputEvents.get(8).content().get()); // u5 + assertThat(result.get(9)).isEqualTo(inputEvents.get(10).content().get()); // u6 + assertThat(result.get(10)).isEqualTo(inputEvents.get(12).content().get()); // u7 } @Test From 503caa6393635a56c672a6592747bcb6e034b8a1 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 4 Feb 2026 10:12:15 -0800 Subject: [PATCH 23/63] feat: Adding validation to BaseAgent and RunConfig PiperOrigin-RevId: 865462678 --- .../java/com/google/adk/agents/BaseAgent.java | 22 +++++++++++++++++++ .../java/com/google/adk/agents/RunConfig.java | 3 +++ .../com/google/adk/agents/BaseAgentTest.java | 15 +++++++++++++ .../com/google/adk/agents/RunConfigTest.java | 8 +++++++ 4 files changed, 48 insertions(+) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 255d59c4d..e7af0d1ea 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -16,7 +16,9 @@ package com.google.adk.agents; +import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.String.format; import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; @@ -36,12 +38,17 @@ import java.util.List; import java.util.Optional; import java.util.function.Function; +import java.util.regex.Pattern; import java.util.stream.Stream; import org.jspecify.annotations.Nullable; /** Base class for all agents. */ public abstract class BaseAgent { + // Pattern for valid agent names. + private static final String IDENTIFIER_REGEX = "^_?[a-zA-Z0-9]*([. _-][a-zA-Z0-9]+)*$"; + private static final Pattern IDENTIFIER_PATTERN = Pattern.compile(IDENTIFIER_REGEX); + /** The agent's name. Must be a unique identifier within the agent tree. */ private final String name; @@ -79,6 +86,7 @@ public BaseAgent( @Nullable List subAgents, @Nullable List beforeAgentCallback, @Nullable List afterAgentCallback) { + validateAgentName(name); this.name = name; this.description = description; this.parentAgent = null; @@ -96,6 +104,20 @@ public BaseAgent( } } + private static void validateAgentName(String name) { + if (isNullOrEmpty(name)) { + throw new IllegalArgumentException("Agent name cannot be null or empty."); + } + if (!IDENTIFIER_PATTERN.matcher(name).matches()) { + throw new IllegalArgumentException( + format("Agent name '%s' does not match regex '%s'.", name, IDENTIFIER_REGEX)); + } + if (name.equals("user")) { + throw new IllegalArgumentException( + "Agent name cannot be 'user'; reserved for end-user input."); + } + } + /** * Gets the agent's unique name. * diff --git a/core/src/main/java/com/google/adk/agents/RunConfig.java b/core/src/main/java/com/google/adk/agents/RunConfig.java index 308169e36..1ca203eaf 100644 --- a/core/src/main/java/com/google/adk/agents/RunConfig.java +++ b/core/src/main/java/com/google/adk/agents/RunConfig.java @@ -134,6 +134,9 @@ public abstract Builder setInputAudioTranscription( public RunConfig build() { RunConfig runConfig = autoBuild(); + if (runConfig.maxLlmCalls() == Integer.MAX_VALUE) { + throw new IllegalArgumentException("maxLlmCalls should be less than Integer.MAX_VALUE."); + } if (runConfig.maxLlmCalls() < 0) { logger.warn( "maxLlmCalls is negative. This will result in no enforcement on total" diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index dec66a77c..4afce04ee 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -17,6 +17,7 @@ package com.google.adk.agents; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; @@ -336,4 +337,18 @@ public void runLive_invokesRunLiveImpl() { assertThat(results.get(0).content()).hasValue(runLiveImplContent); assertThat(runLiveCallback.wasCalled()).isTrue(); } + + @Test + public void constructor_invalidName_throwsIllegalArgumentException() { + assertThrows( + IllegalArgumentException.class, + () -> new TestBaseAgent("invalid name?", "description", null, null, null)); + } + + @Test + public void constructor_userName_throwsIllegalArgumentException() { + assertThrows( + IllegalArgumentException.class, + () -> new TestBaseAgent("user", "description", null, null, null)); + } } diff --git a/core/src/test/java/com/google/adk/agents/RunConfigTest.java b/core/src/test/java/com/google/adk/agents/RunConfigTest.java index 7b6e7558f..1c416aa72 100644 --- a/core/src/test/java/com/google/adk/agents/RunConfigTest.java +++ b/core/src/test/java/com/google/adk/agents/RunConfigTest.java @@ -17,6 +17,7 @@ package com.google.adk.agents; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableList; import com.google.genai.types.AudioTranscriptionConfig; @@ -114,4 +115,11 @@ public void testInputAudioTranscriptionOnly() { assertThat(runConfig.streamingMode()).isEqualTo(RunConfig.StreamingMode.BIDI); assertThat(runConfig.responseModalities()).containsExactly(new Modality(Modality.Known.AUDIO)); } + + @Test + public void testMaxLlmCalls_integerMaxValue_throwsIllegalArgumentException() { + assertThrows( + IllegalArgumentException.class, + () -> RunConfig.builder().setMaxLlmCalls(Integer.MAX_VALUE).build()); + } } From 279c977d9eefda39159dd4bd86acea03a47c6101 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 4 Feb 2026 14:07:41 -0800 Subject: [PATCH 24/63] fix: Propagate trace context across async boundaries This change ensures the trace context is propagated across asynchronous boundaries for LLM and Tool calls. PiperOrigin-RevId: 865572112 --- .../adk/flows/llmflows/BaseLlmFlow.java | 38 ++++++++++++++----- .../google/adk/flows/llmflows/Functions.java | 37 +++++++++++------- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index cfbadb9fe..8e654485c 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -144,11 +144,15 @@ protected Flowable postprocess( }) .map(ResponseProcessingResult::updatedResponse); } + Context parentContext = Context.current(); return currentLlmResponse.flatMapPublisher( - updatedResponse -> - buildPostprocessingEvents( - updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest)); + updatedResponse -> { + try (Scope scope = parentContext.makeCurrent()) { + return buildPostprocessingEvents( + updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); + } + }); } /** @@ -160,7 +164,10 @@ protected Flowable postprocess( * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ private Flowable callLlm( - InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { + InvocationContext context, + LlmRequest llmRequest, + Event eventForCallbackUsage, + Context parentTracingContext) { LlmAgent agent = (LlmAgent) context.agent(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); @@ -180,7 +187,7 @@ private Flowable callLlm( Span llmCallSpan = Tracing.getTracer() .spanBuilder("call_llm") - .setParent(Context.current()) + .setParent(parentTracingContext) .startSpan(); try (Scope scope = llmCallSpan.makeCurrent()) { @@ -333,6 +340,7 @@ private Single handleAfterModelCallback( * @throws IllegalStateException if a transfer agent is specified but not found. */ private Flowable runOneStep(InvocationContext context) { + Context parentContext = Context.current(); AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); Flowable preprocessEvents = preprocess(context, llmRequestRef); @@ -363,10 +371,12 @@ private Flowable runOneStep(InvocationContext context) { // events with fresh timestamp. mutableEventTemplate.setTimestamp(0L); - return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) + return callLlm( + context, llmRequestAfterPreprocess, mutableEventTemplate, parentContext) .concatMap( - llmResponse -> - postprocess( + llmResponse -> { + try (Scope scope = parentContext.makeCurrent()) { + return postprocess( context, mutableEventTemplate, llmRequestAfterPreprocess, @@ -380,7 +390,9 @@ private Flowable runOneStep(InvocationContext context) { + " next LlmResponse", oldId, mutableEventTemplate.id()); - })) + }); + } + }) .concatMap( event -> { Flowable postProcessedEvents = Flowable.just(event); @@ -421,6 +433,7 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return currentStepEvents; } + Context parentContext = Context.current(); return currentStepEvents.concatWith( currentStepEvents .toList() @@ -435,7 +448,12 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return Flowable.defer(() -> run(invocationContext, stepsCompleted + 1)); + return Flowable.defer( + () -> { + try (Scope scope = parentContext.makeCurrent()) { + return run(invocationContext, stepsCompleted + 1); + } + }); } })); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 3bb57faee..a6fb74d88 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -245,6 +245,7 @@ private static Function> getFunctionCallMapper( Map tools, Map toolConfirmations, boolean isLive) { + Context parentContext = Context.current(); return functionCall -> { BaseTool tool = tools.get(functionCall.name().get()); ToolContext toolContext = @@ -259,14 +260,19 @@ private static Function> getFunctionCallMapper( maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) .switchIfEmpty( Maybe.defer( - () -> - isLive + () -> { + try (Scope scope = parentContext.makeCurrent()) { + return isLive ? processFunctionLive( invocationContext, tool, toolContext, functionCall, functionArgs) - : callTool(tool, functionArgs, toolContext))); + : callTool(tool, functionArgs, toolContext); + } + })); - return postProcessFunctionResult( - maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive); + try (Scope scope = parentContext.makeCurrent()) { + return postProcessFunctionResult( + maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive); + } }; } @@ -372,6 +378,7 @@ private static Maybe postProcessFunctionResult( Map functionArgs, ToolContext toolContext, boolean isLive) { + Context parentContext = Context.current(); return maybeFunctionResult .map(Optional::of) .defaultIfEmpty(Optional.empty()) @@ -393,14 +400,17 @@ private static Maybe postProcessFunctionResult( .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) .flatMapMaybe( finalOptionalResult -> { - Map finalFunctionResult = finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); + try (Scope scope = parentContext.makeCurrent()) { + Map finalFunctionResult = + finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + Event functionResponseEvent = + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext); + return Maybe.just(functionResponseEvent); } - Event functionResponseEvent = - buildResponseEvent( - tool, finalFunctionResult, toolContext, invocationContext); - return Maybe.just(functionResponseEvent); }); }); } @@ -552,12 +562,13 @@ private static Maybe> maybeInvokeAfterToolCall( private static Maybe> callTool( BaseTool tool, Map args, ToolContext toolContext) { Tracer tracer = Tracing.getTracer(); + Context parentContext = Context.current(); return Maybe.defer( () -> { Span span = tracer .spanBuilder("tool_call [" + tool.name() + "]") - .setParent(Context.current()) + .setParent(parentContext) .startSpan(); try (Scope scope = span.makeCurrent()) { Tracing.traceToolCall(args); From 5607f644c95a053bf381c2021879e6f31d5c6bde Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 5 Feb 2026 09:08:47 -0800 Subject: [PATCH 25/63] fix: Fixing Vertex session storage PiperOrigin-RevId: 865975883 --- .../main/java/com/google/adk/sessions/SessionJsonConverter.java | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index d1a661a91..220d8f205 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -95,6 +95,7 @@ static String convertEventToJson(Event event) { actionsJson.put("artifactDelta", event.actions().artifactDelta()); actionsJson.put("transferAgent", event.actions().transferToAgent()); actionsJson.put("escalate", event.actions().escalate()); + actionsJson.put("endInvocation", event.actions().endInvocation()); actionsJson.put("requestedAuthConfigs", event.actions().requestedAuthConfigs()); actionsJson.put("requestedToolConfirmations", event.actions().requestedToolConfirmations()); actionsJson.put("compaction", event.actions().compaction()); From 76f86c54eb1a242e604f7b43e3ee18940168b6ec Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 5 Feb 2026 11:11:37 -0800 Subject: [PATCH 26/63] feat: Skip post-invocation compaction if parameters not set PiperOrigin-RevId: 866029385 --- .../java/com/google/adk/runner/Runner.java | 7 +- .../summarizer/EventsCompactionConfig.java | 8 ++- .../com/google/adk/runner/RunnerTest.java | 64 +++++++++++++++++++ 3 files changed, 76 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 5c275ab56..e543f7d69 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -580,6 +580,7 @@ private Flowable runAgentWithFreshSession( private Completable compactEvents(Session session) { return Optional.ofNullable(eventsCompactionConfig) + .filter(EventsCompactionConfig::hasSlidingWindowCompactionConfig) .map(SlidingWindowEventCompactor::new) .map(c -> c.compact(session, sessionService)) .orElseGet(Completable::complete); @@ -817,7 +818,11 @@ private static EventsCompactionConfig createEventsCompactionConfig( new IllegalArgumentException( "No BaseLlm model available for event compaction")); return new EventsCompactionConfig( - config.compactionInterval(), config.overlapSize(), summarizer); + config.compactionInterval(), + config.overlapSize(), + summarizer, + config.tokenThreshold(), + config.eventRetentionSize()); } // TODO: run statelessly diff --git a/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java b/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java index 04dc11e10..b61cd2008 100644 --- a/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java +++ b/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java @@ -33,8 +33,8 @@ * null, no event retention limit will be enforced. */ public record EventsCompactionConfig( - int compactionInterval, - int overlapSize, + @Nullable Integer compactionInterval, + @Nullable Integer overlapSize, @Nullable BaseEventSummarizer summarizer, @Nullable Integer tokenThreshold, @Nullable Integer eventRetentionSize) { @@ -47,4 +47,8 @@ public EventsCompactionConfig( int compactionInterval, int overlapSize, @Nullable BaseEventSummarizer summarizer) { this(compactionInterval, overlapSize, summarizer, null, null); } + + public boolean hasSlidingWindowCompactionConfig() { + return compactionInterval != null && compactionInterval > 0 && overlapSize != null; + } } diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index b0dbedcd6..a01f4201e 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -179,6 +179,70 @@ public void eventsCompaction_enabled() { "user: summary 2"); } + @Test + public void eventsCompaction_withNullOverlap_doesNotCompact() { + TestLlm testLlm = + createTestLlm( + createLlmResponse(createContent("llm 1")), createLlmResponse(createContent("llm 2"))); + LlmAgent agent = createTestAgent(testLlm); + + Runner runner = + Runner.builder() + .app( + App.builder() + .name(this.runner.appName()) + .rootAgent(agent) + .eventsCompactionConfig(new EventsCompactionConfig(1, null, null, null, null)) + .build()) + .sessionService(this.runner.sessionService()) + .build(); + + var unused1 = + runner.runAsync("user", session.id(), createContent("user 1")).toList().blockingGet(); + var unused2 = + runner.runAsync("user", session.id(), createContent("user 2")).toList().blockingGet(); + + Session updatedSession = + runner + .sessionService() + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(simplifyEvents(updatedSession.events())) + .containsExactly("user: user 1", "test agent: llm 1", "user: user 2", "test agent: llm 2"); + } + + @Test + public void eventsCompaction_withNullInterval_doesNotCompact() { + TestLlm testLlm = + createTestLlm( + createLlmResponse(createContent("llm 1")), createLlmResponse(createContent("llm 2"))); + LlmAgent agent = createTestAgent(testLlm); + + Runner runner = + Runner.builder() + .app( + App.builder() + .name(this.runner.appName()) + .rootAgent(agent) + .eventsCompactionConfig(new EventsCompactionConfig(null, 0, null, null, null)) + .build()) + .sessionService(this.runner.sessionService()) + .build(); + + var unused1 = + runner.runAsync("user", session.id(), createContent("user 1")).toList().blockingGet(); + var unused2 = + runner.runAsync("user", session.id(), createContent("user 2")).toList().blockingGet(); + + Session updatedSession = + runner + .sessionService() + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(simplifyEvents(updatedSession.events())) + .containsExactly("user: user 1", "test agent: llm 1", "user: user 2", "test agent: llm 2"); + } + @Test public void pluginDoesNothing() { var events = From 12defeedbaf6048bc83d484f421131051b7e81a5 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 5 Feb 2026 11:14:45 -0800 Subject: [PATCH 27/63] feat: introduces context caching configuration for apps, ported from Python ADK This change introduces ContextCacheConfig, a configuration class for enabling and controlling context caching across all agents within an ADK application. The App class is updated to include an optional ContextCacheConfig, allowing applications to specify caching parameters such as cache intervals, TTL, and minimum token thresholds for caching. PiperOrigin-RevId: 866030713 --- .../google/adk/agents/ContextCacheConfig.java | 59 +++++++++++++++++++ .../main/java/com/google/adk/apps/App.java | 21 ++++++- 2 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 core/src/main/java/com/google/adk/agents/ContextCacheConfig.java diff --git a/core/src/main/java/com/google/adk/agents/ContextCacheConfig.java b/core/src/main/java/com/google/adk/agents/ContextCacheConfig.java new file mode 100644 index 000000000..084700d54 --- /dev/null +++ b/core/src/main/java/com/google/adk/agents/ContextCacheConfig.java @@ -0,0 +1,59 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.agents; + +import java.time.Duration; + +/** + * Configuration for context caching across all agents in an app. + * + *

    This configuration enables and controls context caching behavior for all LLM agents in an app. + * When this config is present on an app, context caching is enabled for all agents. When absent + * (null), context caching is disabled. + * + *

    Context caching can significantly reduce costs and improve response times by reusing + * previously processed context across multiple requests. + * + * @param maxInvocations Maximum number of invocations to reuse the same cache before refreshing it. + * Defaults to 10. + * @param ttl Time-to-live for cache. Defaults to 1800 seconds (30 minutes). + * @param minTokens Minimum estimated request tokens required to enable caching. This compares + * against the estimated total tokens of the request (system instruction + tools + contents). + * Context cache storage may have cost. Set higher to avoid caching small requests where + * overhead may exceed benefits. Defaults to 0. + */ +public record ContextCacheConfig(int maxInvocations, Duration ttl, int minTokens) { + + public ContextCacheConfig() { + this(10, Duration.ofSeconds(1800), 0); + } + + /** Returns TTL as string format for cache creation. */ + public String getTtlString() { + return ttl.getSeconds() + "s"; + } + + @Override + public String toString() { + return "ContextCacheConfig(maxInvocations=" + + maxInvocations + + ", ttl=" + + ttl.getSeconds() + + "s, minTokens=" + + minTokens + + ")"; + } +} diff --git a/core/src/main/java/com/google/adk/apps/App.java b/core/src/main/java/com/google/adk/apps/App.java index 5be72bb5c..d6635d2e7 100644 --- a/core/src/main/java/com/google/adk/apps/App.java +++ b/core/src/main/java/com/google/adk/apps/App.java @@ -17,6 +17,7 @@ package com.google.adk.apps; import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.ContextCacheConfig; import com.google.adk.plugins.Plugin; import com.google.adk.summarizer.EventsCompactionConfig; import com.google.common.collect.ImmutableList; @@ -41,18 +42,21 @@ public class App { private final ImmutableList plugins; @Nullable private final EventsCompactionConfig eventsCompactionConfig; @Nullable private final ResumabilityConfig resumabilityConfig; + @Nullable private final ContextCacheConfig contextCacheConfig; private App( String name, BaseAgent rootAgent, List plugins, @Nullable EventsCompactionConfig eventsCompactionConfig, - @Nullable ResumabilityConfig resumabilityConfig) { + @Nullable ResumabilityConfig resumabilityConfig, + @Nullable ContextCacheConfig contextCacheConfig) { this.name = name; this.rootAgent = rootAgent; this.plugins = ImmutableList.copyOf(plugins); this.eventsCompactionConfig = eventsCompactionConfig; this.resumabilityConfig = resumabilityConfig; + this.contextCacheConfig = contextCacheConfig; } public String name() { @@ -77,6 +81,11 @@ public ResumabilityConfig resumabilityConfig() { return resumabilityConfig; } + @Nullable + public ContextCacheConfig contextCacheConfig() { + return contextCacheConfig; + } + /** Builder for {@link App}. */ public static class Builder { private String name; @@ -84,6 +93,7 @@ public static class Builder { private List plugins = ImmutableList.of(); @Nullable private EventsCompactionConfig eventsCompactionConfig; @Nullable private ResumabilityConfig resumabilityConfig; + @Nullable private ContextCacheConfig contextCacheConfig; @CanIgnoreReturnValue public Builder name(String name) { @@ -115,6 +125,12 @@ public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) { return this; } + @CanIgnoreReturnValue + public Builder contextCacheConfig(ContextCacheConfig contextCacheConfig) { + this.contextCacheConfig = contextCacheConfig; + return this; + } + public App build() { if (name == null) { throw new IllegalStateException("App name must be provided."); @@ -123,7 +139,8 @@ public App build() { throw new IllegalStateException("Root agent must be provided."); } validateAppName(name); - return new App(name, rootAgent, plugins, eventsCompactionConfig, resumabilityConfig); + return new App( + name, rootAgent, plugins, eventsCompactionConfig, resumabilityConfig, contextCacheConfig); } } From ac05fde31ec6a67baf7cacb6144f5912eca029ac Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 5 Feb 2026 11:59:49 -0800 Subject: [PATCH 28/63] feat: add eventId in CallbackContext and ToolContext This allows callbacks and tools to know which event triggered them. This is useful in telemetry and tracing. PiperOrigin-RevId: 866050613 --- .../google/adk/agents/CallbackContext.java | 19 +++++++++++++++++++ .../adk/flows/llmflows/BaseLlmFlow.java | 9 ++++++--- .../com/google/adk/tools/ToolContext.java | 18 ++++++++++++++---- 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/CallbackContext.java b/core/src/main/java/com/google/adk/agents/CallbackContext.java index f7bbdcdbe..808c737a3 100644 --- a/core/src/main/java/com/google/adk/agents/CallbackContext.java +++ b/core/src/main/java/com/google/adk/agents/CallbackContext.java @@ -31,6 +31,7 @@ public class CallbackContext extends ReadonlyContext { protected EventActions eventActions; private final State state; + private final String eventId; /** * Initializes callback context. @@ -39,9 +40,22 @@ public class CallbackContext extends ReadonlyContext { * @param eventActions Callback event actions. */ public CallbackContext(InvocationContext invocationContext, EventActions eventActions) { + this(invocationContext, eventActions, null); + } + + /** + * Initializes callback context. + * + * @param invocationContext Current invocation context. + * @param eventActions Callback event actions. + * @param eventId The ID of the event associated with this context. + */ + public CallbackContext( + InvocationContext invocationContext, EventActions eventActions, String eventId) { super(invocationContext); this.eventActions = eventActions != null ? eventActions : EventActions.builder().build(); this.state = new State(invocationContext.session().state(), this.eventActions.stateDelta()); + this.eventId = eventId; } /** Returns the delta-aware state of the current callback. */ @@ -55,6 +69,11 @@ public EventActions eventActions() { return eventActions; } + /** Returns the ID of the event associated with this context. */ + public String eventId() { + return eventId; + } + /** * Lists the filenames of the artifacts attached to the current session. * diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 8e654485c..6dfbf586c 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -237,7 +237,8 @@ private Flowable callLlm( private Single> handleBeforeModelCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); Maybe pluginResult = context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); @@ -274,7 +275,8 @@ private Maybe handleOnModelErrorCallback( Event modelResponseEvent, Throwable throwable) { Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); Exception ex = throwable instanceof Exception e ? e : new Exception(throwable); Maybe pluginResult = @@ -308,7 +310,8 @@ private Maybe handleOnModelErrorCallback( private Single handleAfterModelCallback( InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); Maybe pluginResult = context.pluginManager().afterModelCallback(callbackContext, llmResponse); diff --git a/core/src/main/java/com/google/adk/tools/ToolContext.java b/core/src/main/java/com/google/adk/tools/ToolContext.java index b421a8e58..5192d19ff 100644 --- a/core/src/main/java/com/google/adk/tools/ToolContext.java +++ b/core/src/main/java/com/google/adk/tools/ToolContext.java @@ -35,8 +35,9 @@ private ToolContext( InvocationContext invocationContext, EventActions eventActions, Optional functionCallId, - Optional toolConfirmation) { - super(invocationContext, eventActions); + Optional toolConfirmation, + @Nullable String eventId) { + super(invocationContext, eventActions, eventId); this.functionCallId = functionCallId; this.toolConfirmation = toolConfirmation; } @@ -125,7 +126,8 @@ public Builder toBuilder() { return new Builder(invocationContext) .actions(eventActions) .functionCallId(functionCallId.orElse(null)) - .toolConfirmation(toolConfirmation.orElse(null)); + .toolConfirmation(toolConfirmation.orElse(null)) + .eventId(eventId()); } @Override @@ -148,6 +150,7 @@ public static final class Builder { private EventActions eventActions = EventActions.builder().build(); // Default empty actions private Optional functionCallId = Optional.empty(); private Optional toolConfirmation = Optional.empty(); + private String eventId; private Builder(InvocationContext invocationContext) { this.invocationContext = invocationContext; @@ -171,8 +174,15 @@ public Builder toolConfirmation(ToolConfirmation toolConfirmation) { return this; } + @CanIgnoreReturnValue + public Builder eventId(String eventId) { + this.eventId = eventId; + return this; + } + public ToolContext build() { - return new ToolContext(invocationContext, eventActions, functionCallId, toolConfirmation); + return new ToolContext( + invocationContext, eventActions, functionCallId, toolConfirmation, eventId); } } } From 5dfc000c9019b4d11a33b35c71c2a04d1f657bf2 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 5 Feb 2026 12:21:38 -0800 Subject: [PATCH 29/63] feat: Adding validation to BaseAgent PiperOrigin-RevId: 866059790 --- .../java/com/google/adk/agents/BaseAgent.java | 51 +++++++++++++++++-- .../com/google/adk/agents/BaseAgentTest.java | 11 ++++ 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index e7af0d1ea..0db4dabb5 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -35,6 +35,7 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; +import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.function.Function; @@ -90,13 +91,16 @@ public BaseAgent( this.name = name; this.description = description; this.parentAgent = null; - this.subAgents = subAgents == null ? ImmutableList.of() : ImmutableList.copyOf(subAgents); + this.subAgents = (subAgents != null) ? ImmutableList.copyOf(subAgents) : ImmutableList.of(); + validateSubAgents(this.name, this.subAgents); this.beforeAgentCallback = - beforeAgentCallback == null - ? ImmutableList.of() - : ImmutableList.copyOf(beforeAgentCallback); + (beforeAgentCallback != null) + ? ImmutableList.copyOf(beforeAgentCallback) + : ImmutableList.of(); this.afterAgentCallback = - afterAgentCallback == null ? ImmutableList.of() : ImmutableList.copyOf(afterAgentCallback); + (afterAgentCallback != null) + ? ImmutableList.copyOf(afterAgentCallback) + : ImmutableList.of(); // Establish parent relationships for all sub-agents if needed. for (BaseAgent subAgent : this.subAgents) { @@ -104,6 +108,13 @@ public BaseAgent( } } + /** + * Validates the agent name. + * + * @param name The agent name to validate. + * @throws IllegalArgumentException if the agent name is null, empty, or does not match the + * identifier pattern. + */ private static void validateAgentName(String name) { if (isNullOrEmpty(name)) { throw new IllegalArgumentException("Agent name cannot be null or empty."); @@ -118,6 +129,36 @@ private static void validateAgentName(String name) { } } + /** + * Validates the sub-agents. + * + * @param name The name of the parent agent. + * @param subAgents The list of sub-agents to validate. + * @throws IllegalArgumentException if the sub-agents have duplicate names. + */ + private static void validateSubAgents( + String name, @Nullable List subAgents) { + if (subAgents == null) { + return; + } + HashSet subAgentNames = new HashSet<>(); + HashSet duplicateSubAgentNames = new HashSet<>(); + for (BaseAgent subAgent : subAgents) { + String subAgentName = subAgent.name(); + // NOTE: Mocked agents have null names because BaseAgent.name() is a final method that + // cannot be mocked. + if (subAgentName != null && !subAgentNames.add(subAgentName)) { + duplicateSubAgentNames.add(subAgentName); + } + } + if (!duplicateSubAgentNames.isEmpty()) { + throw new IllegalArgumentException( + format( + "Agent named '%s' has sub-agents with duplicate names: %s. Sub-agents: %s", + name, duplicateSubAgentNames, subAgents)); + } + } + /** * Gets the agent's unique name. * diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 4afce04ee..d435e90c3 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -351,4 +351,15 @@ public void constructor_userName_throwsIllegalArgumentException() { IllegalArgumentException.class, () -> new TestBaseAgent("user", "description", null, null, null)); } + + @Test + public void constructor_duplicateSubAgentNames_throwsIllegalArgumentException() { + TestBaseAgent subAgent1 = new TestBaseAgent("subAgent", "subAgent1", null, null, null); + TestBaseAgent subAgent2 = new TestBaseAgent("subAgent", "subAgent2", null, null, null); + assertThrows( + IllegalArgumentException.class, + () -> + new TestBaseAgent( + "agent", "description", null, ImmutableList.of(subAgent1, subAgent2), null, null)); + } } From dd601ca8ed939d42fa186113bf0dca31c6e4a6db Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 5 Feb 2026 16:14:35 -0800 Subject: [PATCH 30/63] fix: Reduce the logging level Wonder if we can downgrade those two info() level loggings. They make viewing information from our debugger CLI very difficult. PiperOrigin-RevId: 866158204 --- .../main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 6dfbf586c..fd383baf1 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -698,13 +698,13 @@ private Event buildModelResponseEvent( Event event = eventBuilder.build(); - logger.info("event: {} functionCalls: {}", event, event.functionCalls()); + logger.debug("event: {} functionCalls: {}", event, event.functionCalls()); if (!event.functionCalls().isEmpty()) { Functions.populateClientFunctionCallId(event); Set longRunningToolIds = Functions.getLongRunningFunctionCalls(event.functionCalls(), llmRequest.tools()); - logger.info("longRunningToolIds: {}", longRunningToolIds); + logger.debug("longRunningToolIds: {}", longRunningToolIds); if (!longRunningToolIds.isEmpty()) { event.setLongRunningToolIds(Optional.of(longRunningToolIds)); } From 8190ed3d78667875ee0772e52b7075dcdaa14963 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 6 Feb 2026 07:11:41 -0800 Subject: [PATCH 31/63] fix: Fixing a problem with serializing sessions that broke integration with Vertex AI Session Service PiperOrigin-RevId: 866447373 --- .../adk/sessions/SessionJsonConverter.java | 179 ++++++++++-------- .../sessions/SessionJsonConverterTest.java | 160 ++++++++++++++++ 2 files changed, 258 insertions(+), 81 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index 220d8f205..3781ef537 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -19,7 +19,6 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.JsonBaseModel; -import com.google.adk.agents.BaseAgentState; import com.google.adk.events.Event; import com.google.adk.events.EventActions; import com.google.adk.events.ToolConfirmation; @@ -28,10 +27,12 @@ import com.google.common.collect.Iterables; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.GroundingMetadata; import com.google.genai.types.Part; import java.io.UncheckedIOException; import java.time.Instant; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -57,63 +58,64 @@ private SessionJsonConverter() {} * @throws UncheckedIOException if serialization fails. */ static String convertEventToJson(Event event) { - Map metadataJson = new HashMap<>(); - metadataJson.put("partial", event.partial()); - metadataJson.put("turnComplete", event.turnComplete()); - metadataJson.put("interrupted", event.interrupted()); - metadataJson.put("branch", event.branch().orElse(null)); - metadataJson.put( - "long_running_tool_ids", - event.longRunningToolIds() != null ? event.longRunningToolIds().orElse(null) : null); - if (event.groundingMetadata() != null) { - metadataJson.put("grounding_metadata", event.groundingMetadata()); - } + return convertEventToJson(event, false); + } + /** + * Converts an {@link Event} to its JSON string representation for API transmission. + * + * @param useIsoString if true, use ISO-8601 string for timestamp; otherwise use object format. + * @return JSON string of the event. + * @throws UncheckedIOException if serialization fails. + */ + static String convertEventToJson(Event event, boolean useIsoString) { + Map metadataJson = new HashMap<>(); + event.partial().ifPresent(v -> metadataJson.put("partial", v)); + event.turnComplete().ifPresent(v -> metadataJson.put("turnComplete", v)); + event.interrupted().ifPresent(v -> metadataJson.put("interrupted", v)); + event.branch().ifPresent(v -> metadataJson.put("branch", v)); + putIfNotEmpty(metadataJson, "longRunningToolIds", event.longRunningToolIds()); + event.groundingMetadata().ifPresent(v -> metadataJson.put("groundingMetadata", v)); + event.usageMetadata().ifPresent(v -> metadataJson.put("usageMetadata", v)); Map eventJson = new HashMap<>(); eventJson.put("author", event.author()); eventJson.put("invocationId", event.invocationId()); - eventJson.put( - "timestamp", - new HashMap<>( - ImmutableMap.of( - "seconds", - event.timestamp() / 1000, - "nanos", - (event.timestamp() % 1000) * 1000000))); - if (event.errorCode().isPresent()) { - eventJson.put("errorCode", event.errorCode()); - } - if (event.errorMessage().isPresent()) { - eventJson.put("errorMessage", event.errorMessage()); + if (useIsoString) { + eventJson.put("timestamp", Instant.ofEpochMilli(event.timestamp()).toString()); + } else { + eventJson.put( + "timestamp", + new HashMap<>( + ImmutableMap.of( + "seconds", + event.timestamp() / 1000, + "nanos", + (event.timestamp() % 1000) * 1000000))); } + event.errorCode().ifPresent(errorCode -> eventJson.put("errorCode", errorCode)); + event.errorMessage().ifPresent(errorMessage -> eventJson.put("errorMessage", errorMessage)); eventJson.put("eventMetadata", metadataJson); if (event.actions() != null) { Map actionsJson = new HashMap<>(); - actionsJson.put("skipSummarization", event.actions().skipSummarization()); - actionsJson.put("stateDelta", stateDeltaToJson(event.actions().stateDelta())); - actionsJson.put("artifactDelta", event.actions().artifactDelta()); - actionsJson.put("transferAgent", event.actions().transferToAgent()); - actionsJson.put("escalate", event.actions().escalate()); - actionsJson.put("endInvocation", event.actions().endInvocation()); - actionsJson.put("requestedAuthConfigs", event.actions().requestedAuthConfigs()); - actionsJson.put("requestedToolConfirmations", event.actions().requestedToolConfirmations()); - actionsJson.put("compaction", event.actions().compaction()); - if (!event.actions().agentState().isEmpty()) { - actionsJson.put("agentState", event.actions().agentState()); - } - actionsJson.put("rewindBeforeInvocationId", event.actions().rewindBeforeInvocationId()); + EventActions actions = event.actions(); + actions.skipSummarization().ifPresent(v -> actionsJson.put("skipSummarization", v)); + actionsJson.put("stateDelta", stateDeltaToJson(actions.stateDelta())); + putIfNotEmpty(actionsJson, "artifactDelta", actions.artifactDelta()); + actions + .transferToAgent() + .ifPresent( + v -> { + actionsJson.put("transferAgent", v); + }); + actions.escalate().ifPresent(v -> actionsJson.put("escalate", v)); + actions.endInvocation().ifPresent(v -> actionsJson.put("endOfAgent", v)); + putIfNotEmpty(actionsJson, "requestedAuthConfigs", actions.requestedAuthConfigs()); + putIfNotEmpty( + actionsJson, "requestedToolConfirmations", actions.requestedToolConfirmations()); eventJson.put("actions", actionsJson); } - if (event.content().isPresent()) { - eventJson.put("content", SessionUtils.encodeContent(event.content().get())); - } - if (event.errorCode().isPresent()) { - eventJson.put("errorCode", event.errorCode().get()); - } - if (event.errorMessage().isPresent()) { - eventJson.put("errorMessage", event.errorMessage().get()); - } + event.content().ifPresent(c -> eventJson.put("content", SessionUtils.encodeContent(c))); try { return objectMapper.writeValueAsString(eventJson); } catch (JsonProcessingException e) { @@ -156,19 +158,31 @@ private static Content convertMapToContent(Object rawContentValue) { @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. static Event fromApiEvent(Map apiEvent) { EventActions.Builder eventActionsBuilder = EventActions.builder(); - if (apiEvent.get("actions") != null) { - Map actionsMap = (Map) apiEvent.get("actions"); - if (actionsMap.get("skipSummarization") != null) { - eventActionsBuilder.skipSummarization((Boolean) actionsMap.get("skipSummarization")); + Map actionsMap = (Map) apiEvent.get("actions"); + if (actionsMap != null) { + Boolean skipSummarization = (Boolean) actionsMap.get("skipSummarization"); + if (skipSummarization != null) { + eventActionsBuilder.skipSummarization(skipSummarization); } eventActionsBuilder.stateDelta(stateDeltaFromJson(actionsMap.get("stateDelta"))); + Object artifactDelta = actionsMap.get("artifactDelta"); eventActionsBuilder.artifactDelta( - actionsMap.get("artifactDelta") != null - ? convertToArtifactDeltaMap(actionsMap.get("artifactDelta")) + artifactDelta != null + ? convertToArtifactDeltaMap(artifactDelta) : new ConcurrentHashMap<>()); - eventActionsBuilder.transferToAgent((String) actionsMap.get("transferAgent")); - if (actionsMap.get("escalate") != null) { - eventActionsBuilder.escalate((Boolean) actionsMap.get("escalate")); + String transferAgent = (String) actionsMap.get("transferAgent"); + if (transferAgent == null) { + transferAgent = (String) actionsMap.get("transferToAgent"); + } + eventActionsBuilder.transferToAgent(transferAgent); + Boolean escalate = (Boolean) actionsMap.get("escalate"); + if (escalate != null) { + eventActionsBuilder.escalate(escalate); + } + Boolean endOfAgent = (Boolean) actionsMap.get("endOfAgent"); + if (endOfAgent != null) { + eventActionsBuilder.endOfAgent(endOfAgent); + eventActionsBuilder.endInvocation(endOfAgent); } eventActionsBuilder.requestedAuthConfigs( Optional.ofNullable(actionsMap.get("requestedAuthConfigs")) @@ -178,13 +192,6 @@ static Event fromApiEvent(Map apiEvent) { Optional.ofNullable(actionsMap.get("requestedToolConfirmations")) .map(SessionJsonConverter::asConcurrentMapOfToolConfirmations) .orElse(new ConcurrentHashMap<>())); - if (actionsMap.get("agentState") != null) { - eventActionsBuilder.agentState(asConcurrentMapOfAgentState(actionsMap.get("agentState"))); - } - if (actionsMap.get("rewindBeforeInvocationId") != null) { - eventActionsBuilder.rewindBeforeInvocationId( - (String) actionsMap.get("rewindBeforeInvocationId")); - } } Event event = @@ -204,11 +211,9 @@ static Event fromApiEvent(Map apiEvent) { .map(value -> new FinishReason((String) value))) .errorMessage( Optional.ofNullable(apiEvent.get("errorMessage")).map(value -> (String) value)) - .branch(Optional.ofNullable(apiEvent.get("branch")).map(value -> (String) value)) .build(); - // TODO(b/414263934): Add Event branch and grounding metadata for python parity. - if (apiEvent.get("eventMetadata") != null) { - Map eventMetadata = (Map) apiEvent.get("eventMetadata"); + Map eventMetadata = (Map) apiEvent.get("eventMetadata"); + if (eventMetadata != null) { List longRunningToolIdsList = (List) eventMetadata.get("longRunningToolIds"); GroundingMetadata groundingMetadata = null; @@ -217,6 +222,12 @@ static Event fromApiEvent(Map apiEvent) { groundingMetadata = objectMapper.convertValue(rawGroundingMetadata, GroundingMetadata.class); } + GenerateContentResponseUsageMetadata usageMetadata = null; + Object rawUsageMetadata = eventMetadata.get("usageMetadata"); + if (rawUsageMetadata != null) { + usageMetadata = + objectMapper.convertValue(rawUsageMetadata, GenerateContentResponseUsageMetadata.class); + } event = event.toBuilder() @@ -227,6 +238,7 @@ static Event fromApiEvent(Map apiEvent) { Optional.ofNullable((Boolean) eventMetadata.get("interrupted")).orElse(false)) .branch(Optional.ofNullable((String) eventMetadata.get("branch"))) .groundingMetadata(groundingMetadata) + .usageMetadata(usageMetadata) .longRunningToolIds( longRunningToolIdsList != null ? new HashSet<>(longRunningToolIdsList) : null) .build(); @@ -285,7 +297,7 @@ private static Instant convertToInstant(Object timestampObj) { * @param artifactDeltaObj The raw object from which to parse the artifact delta. * @return A {@link ConcurrentMap} representing the artifact delta. */ - @SuppressWarnings("unchecked") // Safe because we check instanceof Map before casting. + @SuppressWarnings("unchecked") private static ConcurrentMap convertToArtifactDeltaMap(Object artifactDeltaObj) { if (!(artifactDeltaObj instanceof Map)) { return new ConcurrentHashMap<>(); @@ -319,19 +331,6 @@ private static ConcurrentMap convertToArtifactDeltaMap(Object arti ConcurrentHashMap::putAll); } - @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. - private static ConcurrentMap asConcurrentMapOfAgentState(Object value) { - return ((Map) value) - .entrySet().stream() - .collect( - ConcurrentHashMap::new, - (map, entry) -> - map.put( - entry.getKey(), - objectMapper.convertValue(entry.getValue(), BaseAgentState.class)), - ConcurrentHashMap::putAll); - } - @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. private static ConcurrentMap asConcurrentMapOfToolConfirmations( Object value) { @@ -345,4 +344,22 @@ private static ConcurrentMap asConcurrentMapOfToolConf objectMapper.convertValue(entry.getValue(), ToolConfirmation.class)), ConcurrentHashMap::putAll); } + + private static void putIfNotEmpty(Map map, String key, Map values) { + if (values != null && !values.isEmpty()) { + map.put(key, values); + } + } + + private static void putIfNotEmpty( + Map map, String key, Optional> values) { + values.ifPresent(v -> putIfNotEmpty(map, key, v)); + } + + private static void putIfNotEmpty( + Map map, String key, @Nullable Collection values) { + if (values != null && !values.isEmpty()) { + map.put(key, values); + } + } } diff --git a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java index 827e810aa..3d1a845ed 100644 --- a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java +++ b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java @@ -1,6 +1,7 @@ package com.google.adk.sessions; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; @@ -8,9 +9,14 @@ import com.google.adk.JsonBaseModel; import com.google.adk.events.Event; import com.google.adk.events.EventActions; +import com.google.adk.events.ToolConfirmation; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.GroundingMetadata; import com.google.genai.types.Part; import java.time.Instant; import java.util.Collections; @@ -18,6 +24,7 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -152,6 +159,112 @@ public void fromApiEvent_fullEvent_success() { assertThat(eventActions.escalate()).hasValue(true); } + @Test + public void fromApiEvent_withTransferToAgent_success() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + apiEvent.put("timestamp", "2023-01-01T00:00:00Z"); + + Map actions = new HashMap<>(); + actions.put("transferToAgent", "agent-id"); + apiEvent.put("actions", actions); + + Event event = SessionJsonConverter.fromApiEvent(apiEvent); + + assertThat(event.actions().transferToAgent()).hasValue("agent-id"); + } + + @Test + public void convertEventToJson_complexActions_success() throws JsonProcessingException { + ConcurrentMap> authConfigs = new ConcurrentHashMap<>(); + authConfigs.put("auth1", new ConcurrentHashMap<>(ImmutableMap.of("param1", "value1"))); + + ConcurrentMap toolConfirmations = new ConcurrentHashMap<>(); + toolConfirmations.put( + "tool1", ToolConfirmation.builder().hint("hint1").confirmed(true).build()); + + EventActions actions = + EventActions.builder() + .requestedAuthConfigs(authConfigs) + .requestedToolConfirmations(toolConfirmations) + .endInvocation(true) + .build(); + + GenerateContentResponseUsageMetadata usageMetadata = + GenerateContentResponseUsageMetadata.builder().promptTokenCount(10).build(); + GroundingMetadata groundingMetadata = GroundingMetadata.builder().build(); + + Event event = + Event.builder() + .author("user") + .invocationId("inv-123") + .timestamp(Instant.parse("2023-01-01T00:00:00.123Z").toEpochMilli()) + .actions(actions) + .longRunningToolIds(ImmutableSet.of("tool-id-1")) + .usageMetadata(usageMetadata) + .groundingMetadata(groundingMetadata) + .build(); + + String json = SessionJsonConverter.convertEventToJson(event, true); + JsonNode jsonNode = objectMapper.readTree(json); + + assertThat(jsonNode.get("timestamp").asText()).isEqualTo("2023-01-01T00:00:00.123Z"); + + JsonNode eventMetadata = jsonNode.get("eventMetadata"); + assertThat(eventMetadata.get("longRunningToolIds").get(0).asText()).isEqualTo("tool-id-1"); + assertThat(eventMetadata.has("usageMetadata")).isTrue(); + assertThat(eventMetadata.has("groundingMetadata")).isTrue(); + + JsonNode actionsNode = jsonNode.get("actions"); + assertThat(actionsNode.get("requestedAuthConfigs").get("auth1").get("param1").asText()) + .isEqualTo("value1"); + assertThat(actionsNode.get("requestedToolConfirmations").get("tool1").get("hint").asText()) + .isEqualTo("hint1"); + assertThat( + actionsNode.get("requestedToolConfirmations").get("tool1").get("confirmed").asBoolean()) + .isTrue(); + assertThat(actionsNode.get("endOfAgent").asBoolean()).isTrue(); + } + + @Test + public void fromApiEvent_complexActions_success() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + apiEvent.put("timestamp", "2023-01-01T00:00:00.123Z"); + + Map actions = new HashMap<>(); + actions.put("requestedAuthConfigs", ImmutableMap.of("auth1", ImmutableMap.of("p1", "v1"))); + actions.put( + "requestedToolConfirmations", + ImmutableMap.of("tool1", ImmutableMap.of("hint", "h1", "confirmed", true))); + actions.put("endOfAgent", true); + apiEvent.put("actions", actions); + + Map eventMetadata = new HashMap<>(); + eventMetadata.put("longRunningToolIds", ImmutableList.of("tool-1")); + eventMetadata.put("usageMetadata", ImmutableMap.of("promptTokenCount", 10)); + eventMetadata.put("groundingMetadata", ImmutableMap.of()); + apiEvent.put("eventMetadata", eventMetadata); + + Event event = SessionJsonConverter.fromApiEvent(apiEvent); + + assertThat(event.timestamp()) + .isEqualTo(Instant.parse("2023-01-01T00:00:00.123Z").toEpochMilli()); + assertThat(event.longRunningToolIds().get()).containsExactly("tool-1"); + assertThat(event.usageMetadata().get().promptTokenCount()).hasValue(10); + assertThat(event.groundingMetadata()).isPresent(); + + EventActions eventActions = event.actions(); + assertThat(eventActions.requestedAuthConfigs().get("auth1")).containsEntry("p1", "v1"); + assertThat(eventActions.requestedToolConfirmations().get("tool1").hint()).isEqualTo("h1"); + assertThat(eventActions.requestedToolConfirmations().get("tool1").confirmed()).isTrue(); + assertThat(eventActions.endOfAgent()).isTrue(); + } + @Test public void fromApiEvent_minimalEvent_success() { Map apiEvent = new HashMap<>(); @@ -246,6 +359,53 @@ public void convertEventToJson_withStateRemoved_success() throws JsonProcessingE assertThat(actionsNode.get("stateDelta").get("key2").isNull()).isTrue(); } + @Test + public void fromApiEvent_withInvalidContentMap_returnsNullContent() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + apiEvent.put("timestamp", "2023-01-01T00:00:00Z"); + // Parts should be a list, not a string + apiEvent.put("content", ImmutableMap.of("parts", "invalid")); + + Event event = SessionJsonConverter.fromApiEvent(apiEvent); + + assertThat(event.content()).isEmpty(); + } + + @Test + public void fromApiEvent_withInvalidArtifactDelta_skipsInvalidEntries() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + apiEvent.put("timestamp", "2023-01-01T00:00:00Z"); + + Map artifactDelta = new HashMap<>(); + artifactDelta.put("valid", ImmutableMap.of("text", "valid_text")); + artifactDelta.put("invalid", "not-a-map"); + + Map actions = new HashMap<>(); + actions.put("artifactDelta", artifactDelta); + apiEvent.put("actions", actions); + + Event event = SessionJsonConverter.fromApiEvent(apiEvent); + + assertThat(event.actions().artifactDelta()).containsKey("valid"); + assertThat(event.actions().artifactDelta()).doesNotContainKey("invalid"); + } + + @Test + public void fromApiEvent_missingTimestamp_throwsException() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + + assertThrows(IllegalArgumentException.class, () -> SessionJsonConverter.fromApiEvent(apiEvent)); + } + @Test public void fromApiEvent_withNullStateDeltaValue_success() { Map apiEvent = new HashMap<>(); From 101adce314dd65328af6ad9281afb46f9b160c1a Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Fri, 6 Feb 2026 07:44:45 -0800 Subject: [PATCH 32/63] fix: revert: Merging of events in rearrangeEventsForAsyncFunctionResponsesInHistory PiperOrigin-RevId: 866457410 --- .../google/adk/flows/llmflows/Contents.java | 3 +- .../adk/flows/llmflows/ContentsTest.java | 72 ++----------------- 2 files changed, 8 insertions(+), 67 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index f45461626..171dab972 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -564,7 +564,8 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( for (int i = 0; i < events.size(); i++) { Event event = events.get(i); - if (!event.functionResponses().isEmpty()) { + // Skip response events that will be processed via responseEventsBuffer + if (processedResponseIndices.contains(i)) { continue; } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index a8a862b51..82a57ed4f 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -203,11 +203,9 @@ public void rearrangeHistory_asyncFR_returnsRearrangedList() { public void rearrangeHistory_multipleFRsForSameFC_returnsMergedFR() { Event fcEvent = createFunctionCallEvent("fc1", "tool1", "call1"); Event frEvent1 = - createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "pending")); + createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "running")); Event frEvent2 = - createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "running")); - Event frEvent3 = - createFunctionResponseEvent("fr3", "tool1", "call1", ImmutableMap.of("status", "done")); + createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "done")); ImmutableList inputEvents = ImmutableList.of( createUserEvent("u1", "Query"), @@ -215,75 +213,17 @@ public void rearrangeHistory_multipleFRsForSameFC_returnsMergedFR() { createUserEvent("u2", "Wait"), frEvent1, createUserEvent("u3", "Done?"), - frEvent2, - frEvent3, - createUserEvent("u4", "Follow up query")); + frEvent2); List result = runContentsProcessor(inputEvents); - assertThat(result).hasSize(6); // u1, fc1, merged_fr, u2, u3, u4 + assertThat(result).hasSize(3); // u1, fc1, merged_fr assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); - assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // Check fcEvent - Content mergedContent = result.get(2); - assertThat(mergedContent.parts().get()).hasSize(1); - assertThat(mergedContent.parts().get().get(0).functionResponse().get().response().get()) - .containsExactly("status", "done"); // Last FR wins (frEvent3) - assertThat(result.get(3)).isEqualTo(inputEvents.get(2).content().get()); // u2 - assertThat(result.get(4)).isEqualTo(inputEvents.get(4).content().get()); // u3 - assertThat(result.get(5)).isEqualTo(inputEvents.get(7).content().get()); // u4 - } - - @Test - public void rearrangeHistory_multipleFRsForMultipleFC_returnsMergedFR() { - Event fcEvent1 = createFunctionCallEvent("fc1", "tool1", "call1"); - Event fcEvent2 = createFunctionCallEvent("fc2", "tool1", "call2"); - - Event frEvent1 = - createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "pending")); - Event frEvent2 = - createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "done")); - - Event frEvent3 = - createFunctionResponseEvent("fr3", "tool1", "call2", ImmutableMap.of("status", "pending")); - Event frEvent4 = - createFunctionResponseEvent("fr4", "tool1", "call2", ImmutableMap.of("status", "done")); - - ImmutableList inputEvents = - ImmutableList.of( - createUserEvent("u1", "I"), - fcEvent1, - createUserEvent("u2", "am"), - frEvent1, - createUserEvent("u3", "waiting"), - frEvent2, - createUserEvent("u4", "for"), - fcEvent2, - createUserEvent("u5", "you"), - frEvent3, - createUserEvent("u6", "to"), - frEvent4, - createUserEvent("u7", "Follow up query")); - - List result = runContentsProcessor(inputEvents); - - assertThat(result).hasSize(11); // u1, fc1, frEvent2, u2, u3, u4, fc2, frEvent4, u5, u6, u7 - assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); // u1 - assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // fc1 + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // Check merged event Content mergedContent = result.get(2); assertThat(mergedContent.parts().get()).hasSize(1); assertThat(mergedContent.parts().get().get(0).functionResponse().get().response().get()) - .containsExactly("status", "done"); // Last FR wins (frEvent2) - assertThat(result.get(3)).isEqualTo(inputEvents.get(2).content().get()); // u2 - assertThat(result.get(4)).isEqualTo(inputEvents.get(4).content().get()); // u3 - assertThat(result.get(5)).isEqualTo(inputEvents.get(6).content().get()); // u4 - assertThat(result.get(6)).isEqualTo(inputEvents.get(7).content().get()); // fc2 - Content mergedContent2 = result.get(7); - assertThat(mergedContent2.parts().get()).hasSize(1); - assertThat(mergedContent2.parts().get().get(0).functionResponse().get().response().get()) - .containsExactly("status", "done"); // Last FR wins (frEvent4) - assertThat(result.get(8)).isEqualTo(inputEvents.get(8).content().get()); // u5 - assertThat(result.get(9)).isEqualTo(inputEvents.get(10).content().get()); // u6 - assertThat(result.get(10)).isEqualTo(inputEvents.get(12).content().get()); // u7 + .containsExactly("status", "done"); // Last FR wins } @Test From ded5a4e760055d3d2bcd74d3bd8f21517821e7d0 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 6 Feb 2026 08:59:07 -0800 Subject: [PATCH 33/63] Fix: Mutate EventActions in-place in AgentTool PiperOrigin-RevId: 866482983 --- .../java/com/google/adk/tools/AgentTool.java | 3 +- .../com/google/adk/tools/AgentToolTest.java | 33 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index 2a50605a1..7eabc48c4 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -138,7 +138,8 @@ public Optional declaration() { public Single> runAsync(Map args, ToolContext toolContext) { if (this.skipSummarization) { - toolContext.setActions(toolContext.actions().toBuilder().skipSummarization(true).build()); + // Mutate EventActions in-place to ensure object references are maintained. + toolContext.actions().setSkipSummarization(true); } Optional agentInputSchema = getInputSchema(agent); diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index c961e654a..2fc17b94d 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -453,6 +453,39 @@ public void call_withStateDeltaInResponse_propagatesStateDelta() throws Exceptio assertThat(toolContext.state()).containsEntry("test_key", "test_value"); } + @Test + public void call_withSkipSummarizationAndStateDelta_propagatesStateAndSetsSkipSummarization() + throws Exception { + AfterAgentCallback afterAgentCallback = + (callbackContext) -> { + callbackContext.state().put("test_key", "test_value"); + return Maybe.empty(); + }; + TestLlm testLlm = + createTestLlm( + LlmResponse.builder() + .content(Content.fromParts(Part.fromText("test response"))) + .build()); + LlmAgent testAgent = + createTestAgentBuilder(testLlm) + .name("agent name") + .description("agent description") + .afterAgentCallback(afterAgentCallback) + .build(); + AgentTool agentTool = AgentTool.create(testAgent, /* skipSummarization= */ true); + ToolContext toolContext = createToolContext(testAgent); + + assertThat(toolContext.state()).doesNotContainKey("test_key"); + + Map unused = + agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet(); + + // Verify that stateDelta is propagated to the ToolContext's EventActions, otherwise + // Function.buildResponseEvent() will not include it in the response event. + assertThat(toolContext.actions().stateDelta()).containsEntry("test_key", "test_value"); + assertThat(toolContext.actions().skipSummarization()).hasValue(true); + } + @Test public void declaration_sequentialAgentWithFirstSubAgentInputSchema_returnsDeclarationWithSchema() { From ed736cdf84d8db92dfde947b5ee84e7430f3ae6d Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 6 Feb 2026 10:22:54 -0800 Subject: [PATCH 34/63] feat: adding a new temporary store of context for callbacks Also, expanding testing on InvocationContext. PiperOrigin-RevId: 866515259 --- .../google/adk/agents/InvocationContext.java | 30 +- .../adk/agents/InvocationContextTest.java | 329 ++++++++++++++++++ .../adk/flows/llmflows/BaseLlmFlowTest.java | 78 +++++ 3 files changed, 435 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index ed9b21062..d197dbaa5 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -56,6 +56,7 @@ public class InvocationContext { private final ResumabilityConfig resumabilityConfig; @Nullable private final EventsCompactionConfig eventsCompactionConfig; private final InvocationCostManager invocationCostManager; + private final Map callbackContextData; private Optional branch; private BaseAgent agent; @@ -80,6 +81,7 @@ protected InvocationContext(Builder builder) { this.resumabilityConfig = builder.resumabilityConfig; this.eventsCompactionConfig = builder.eventsCompactionConfig; this.invocationCostManager = builder.invocationCostManager; + this.callbackContextData = builder.callbackContextData; } /** @@ -306,6 +308,14 @@ public RunConfig runConfig() { return runConfig; } + /** + * Returns a map for storing temporary context data that can be shared between different parts of + * the invocation (e.g., before/on/after model callbacks). + */ + public Map callbackContextData() { + return callbackContextData; + } + /** Returns agent-specific state saved within this invocation. */ public Map agentStates() { return agentStates; @@ -437,6 +447,7 @@ private Builder(InvocationContext context) { this.resumabilityConfig = context.resumabilityConfig; this.eventsCompactionConfig = context.eventsCompactionConfig; this.invocationCostManager = context.invocationCostManager; + this.callbackContextData = context.callbackContextData; } private BaseSessionService sessionService; @@ -457,6 +468,7 @@ private Builder(InvocationContext context) { private ResumabilityConfig resumabilityConfig = new ResumabilityConfig(); @Nullable private EventsCompactionConfig eventsCompactionConfig; private InvocationCostManager invocationCostManager = new InvocationCostManager(); + private Map callbackContextData = new ConcurrentHashMap<>(); /** * Sets the session service for managing session state. @@ -692,6 +704,18 @@ public Builder eventsCompactionConfig(@Nullable EventsCompactionConfig eventsCom return this; } + /** + * Sets the callback context data for the invocation. + * + * @param callbackContextData the callback context data. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder callbackContextData(Map callbackContextData) { + this.callbackContextData = callbackContextData; + return this; + } + /** * Builds the {@link InvocationContext} instance. * @@ -728,7 +752,8 @@ public boolean equals(Object o) { && Objects.equals(endOfAgents, that.endOfAgents) && Objects.equals(resumabilityConfig, that.resumabilityConfig) && Objects.equals(eventsCompactionConfig, that.eventsCompactionConfig) - && Objects.equals(invocationCostManager, that.invocationCostManager); + && Objects.equals(invocationCostManager, that.invocationCostManager) + && Objects.equals(callbackContextData, that.callbackContextData); } @Override @@ -751,6 +776,7 @@ public int hashCode() { endOfAgents, resumabilityConfig, eventsCompactionConfig, - invocationCostManager); + invocationCostManager, + callbackContextData); } } diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index 64d2f5bf6..61135c78e 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -23,10 +23,13 @@ import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.memory.BaseMemoryService; +import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; +import com.google.adk.summarizer.EventsCompactionConfig; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; @@ -34,6 +37,8 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -178,6 +183,25 @@ public void testCopyOf() { assertThat(copiedContext.endInvocation()).isEqualTo(originalContext.endInvocation()); assertThat(copiedContext.activeStreamingTools()) .isEqualTo(originalContext.activeStreamingTools()); + assertThat(copiedContext.callbackContextData()) + .isSameInstanceAs(originalContext.callbackContextData()); + } + + @Test + public void testBuildWithCallbackContextData() { + Map data = new ConcurrentHashMap<>(); + data.put("key", "value"); + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .callbackContextData(data) + .build(); + + assertThat(context.callbackContextData()).isEqualTo(data); + assertThat(context.callbackContextData()).isSameInstanceAs(data); } @Test @@ -404,6 +428,22 @@ public void testEquals_differentValues() { assertThat(context.equals(contextWithDiffAgent)).isFalse(); assertThat(context.equals(contextWithUserContentEmpty)).isFalse(); assertThat(context.equals(contextWithLiveQueuePresent)).isFalse(); + + InvocationContext contextWithDiffCallbackContextData = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(userContent) + .runConfig(runConfig) + .endInvocation(false) + .callbackContextData(ImmutableMap.of("key", "value")) + .build(); + assertThat(context.equals(contextWithDiffCallbackContextData)).isFalse(); } @Test @@ -453,6 +493,22 @@ public void testHashCode_differentValues() { assertThat(context).isNotEqualTo(contextWithDiffSessionService); assertThat(context).isNotEqualTo(contextWithDiffInvocationId); + + InvocationContext contextWithDiffCallbackContextData = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(userContent) + .runConfig(runConfig) + .endInvocation(false) + .callbackContextData(ImmutableMap.of("key", "value")) + .build(); + assertThat(context.hashCode()).isNotEqualTo(contextWithDiffCallbackContextData.hashCode()); } @Test @@ -604,4 +660,277 @@ public void shouldPauseInvocation_whenResumableAndMatchingFunctionCallId_isTrue( .build(); assertThat(context.shouldPauseInvocation(event)).isTrue(); } + + @Test + public void incrementLlmCallsCount_whenLimitNotExceeded_doesNotThrow() throws Exception { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .runConfig(RunConfig.builder().setMaxLlmCalls(2).build()) + .build(); + + context.incrementLlmCallsCount(); + context.incrementLlmCallsCount(); + // No exception thrown + } + + @Test + public void incrementLlmCallsCount_whenLimitExceeded_throwsException() throws Exception { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .runConfig(RunConfig.builder().setMaxLlmCalls(1).build()) + .build(); + + context.incrementLlmCallsCount(); + LlmCallsLimitExceededException thrown = + Assert.assertThrows( + LlmCallsLimitExceededException.class, () -> context.incrementLlmCallsCount()); + assertThat(thrown).hasMessageThat().contains("limit of 1 exceeded"); + } + + @Test + public void incrementLlmCallsCount_whenNoLimit_doesNotThrow() throws Exception { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .runConfig(RunConfig.builder().setMaxLlmCalls(0).build()) + .build(); + + for (int i = 0; i < 100; i++) { + context.incrementLlmCallsCount(); + } + } + + @Test + public void testSessionGetters() { + Session sessionWithDetails = + Session.builder("test-id").appName("test-app").userId("test-user").build(); + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(sessionWithDetails) + .build(); + + assertThat(context.appName()).isEqualTo("test-app"); + assertThat(context.userId()).isEqualTo("test-user"); + } + + @Test + public void testAgentStatesAndEndOfAgents() { + BaseAgentState mockState = mock(BaseAgentState.class); + ImmutableMap states = ImmutableMap.of("agent1", mockState); + ImmutableMap endOfAgents = ImmutableMap.of("agent1", true); + + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .agentStates(states) + .endOfAgents(endOfAgents) + .build(); + + assertThat(context.agentStates()).isEqualTo(states); + assertThat(context.endOfAgents()).isEqualTo(endOfAgents); + } + + @Test + public void testSetEndInvocation() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .build(); + + assertThat(context.endInvocation()).isFalse(); + context.setEndInvocation(true); + assertThat(context.endInvocation()).isTrue(); + } + + @Test + @SuppressWarnings("deprecation") // Testing deprecated methods. + public void testBranch() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .branch("test-branch") + .build(); + + assertThat(context.branch()).hasValue("test-branch"); + + context.branch("new-branch"); + assertThat(context.branch()).hasValue("new-branch"); + + context.branch(null); + assertThat(context.branch()).isEmpty(); + } + + @Test + @SuppressWarnings("deprecation") // Testing deprecated methods. + public void testDeprecatedCreateMethods() { + InvocationContext context1 = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.ofNullable(userContent)) + .runConfig(runConfig) + .build(); + + assertThat(context1.sessionService()).isEqualTo(mockSessionService); + assertThat(context1.artifactService()).isEqualTo(mockArtifactService); + assertThat(context1.invocationId()).isEqualTo(testInvocationId); + assertThat(context1.agent()).isEqualTo(mockAgent); + assertThat(context1.session()).isEqualTo(session); + assertThat(context1.userContent()).hasValue(userContent); + assertThat(context1.runConfig()).isEqualTo(runConfig); + + InvocationContext context2 = + InvocationContext.create( + mockSessionService, + mockArtifactService, + mockAgent, + session, + liveRequestQueue, + runConfig); + + assertThat(context2.sessionService()).isEqualTo(mockSessionService); + assertThat(context2.artifactService()).isEqualTo(mockArtifactService); + assertThat(context2.agent()).isEqualTo(mockAgent); + assertThat(context2.session()).isEqualTo(session); + assertThat(context2.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context2.runConfig()).isEqualTo(runConfig); + } + + @Test + public void testActiveStreamingTools() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .build(); + + assertThat(context.activeStreamingTools()).isEmpty(); + ActiveStreamingTool tool = new ActiveStreamingTool(new LiveRequestQueue()); + context.activeStreamingTools().put("tool1", tool); + assertThat(context.activeStreamingTools()).containsEntry("tool1", tool); + } + + @Test + public void testEventsCompactionConfig() { + EventsCompactionConfig config = new EventsCompactionConfig(5, 2); + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .eventsCompactionConfig(config) + .build(); + + assertThat(context.eventsCompactionConfig()).hasValue(config); + } + + @Test + @SuppressWarnings("deprecation") // Testing deprecated methods. + public void testBuilderOptionalParameters() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .liveRequestQueue(Optional.of(liveRequestQueue)) + .branch(Optional.of("test-branch")) + .userContent(Optional.of(userContent)) + .build(); + + assertThat(context.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context.branch()).hasValue("test-branch"); + assertThat(context.userContent()).hasValue(userContent); + } + + @Test + @SuppressWarnings("deprecation") // Testing deprecated methods. + public void testDeprecatedConstructor() { + InvocationContext context = + new InvocationContext( + mockSessionService, + mockArtifactService, + mockMemoryService, + pluginManager, + Optional.of(liveRequestQueue), + Optional.of("test-branch"), + testInvocationId, + mockAgent, + session, + Optional.of(userContent), + runConfig, + true); + + assertThat(context.sessionService()).isEqualTo(mockSessionService); + assertThat(context.artifactService()).isEqualTo(mockArtifactService); + assertThat(context.memoryService()).isEqualTo(mockMemoryService); + assertThat(context.pluginManager()).isEqualTo(pluginManager); + assertThat(context.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context.branch()).hasValue("test-branch"); + assertThat(context.invocationId()).isEqualTo(testInvocationId); + assertThat(context.agent()).isEqualTo(mockAgent); + assertThat(context.session()).isEqualTo(session); + assertThat(context.userContent()).hasValue(userContent); + assertThat(context.runConfig()).isEqualTo(runConfig); + assertThat(context.endInvocation()).isTrue(); + } + + @Test + @SuppressWarnings("deprecation") // Testing deprecated methods. + public void testDeprecatedConstructor_11params() { + InvocationContext context = + new InvocationContext( + mockSessionService, + mockArtifactService, + mockMemoryService, + Optional.of(liveRequestQueue), + Optional.of("test-branch"), + testInvocationId, + mockAgent, + session, + Optional.of(userContent), + runConfig, + true); + + assertThat(context.sessionService()).isEqualTo(mockSessionService); + assertThat(context.artifactService()).isEqualTo(mockArtifactService); + assertThat(context.memoryService()).isEqualTo(mockMemoryService); + assertThat(context.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context.branch()).hasValue("test-branch"); + assertThat(context.invocationId()).isEqualTo(testInvocationId); + assertThat(context.agent()).isEqualTo(mockAgent); + assertThat(context.session()).isEqualTo(session); + assertThat(context.userContent()).hasValue(userContent); + assertThat(context.runConfig()).isEqualTo(runConfig); + assertThat(context.endInvocation()).isTrue(); + } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 5f4932a89..657d1c670 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -25,6 +25,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.truth.Truth.assertThat; +import com.google.adk.agents.Callbacks; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; @@ -42,6 +43,7 @@ import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; import java.util.Map; @@ -414,6 +416,82 @@ public void run_requestProcessorsAreCalledExactlyOnce() { assertThat(processor2CallCount.get()).isEqualTo(1); } + @Test + public void run_sharingcallbackContextDataBetweenCallbacks() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + Callbacks.BeforeModelCallback beforeCallback = + (ctx, req) -> { + ctx.invocationContext().callbackContextData().put("key", "value_from_before"); + return Maybe.empty(); + }; + + Callbacks.AfterModelCallback afterCallback = + (ctx, resp) -> { + String value = (String) ctx.invocationContext().callbackContextData().get("key"); + LlmResponse modifiedResp = + resp.toBuilder().content(Content.fromParts(Part.fromText("Saw: " + value))).build(); + return Maybe.just(modifiedResp); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .beforeModelCallback(beforeCallback) + .afterModelCallback(afterCallback) + .build()); + + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertThat(events.get(0).stringifyContent()).isEqualTo("Saw: value_from_before"); + } + + @Test + public void run_sharingcallbackContextDataAcrossContextCopies() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + Callbacks.BeforeModelCallback beforeCallback = + (ctx, req) -> { + ctx.invocationContext().callbackContextData().put("key", "value_from_before"); + return Maybe.empty(); + }; + + Callbacks.AfterModelCallback afterCallback = + (ctx, resp) -> { + String value = (String) ctx.invocationContext().callbackContextData().get("key"); + LlmResponse modifiedResp = + resp.toBuilder().content(Content.fromParts(Part.fromText("Saw: " + value))).build(); + return Maybe.just(modifiedResp); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .beforeModelCallback(beforeCallback) + .afterModelCallback(afterCallback) + .build()); + + BaseLlmFlow baseLlmFlow = + new BaseLlmFlow(ImmutableList.of(), ImmutableList.of()) { + @Override + public Flowable run(InvocationContext context) { + // Force a context copy + InvocationContext copiedContext = context.toBuilder().build(); + return super.run(copiedContext); + } + }; + + List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertThat(events.get(0).stringifyContent()).isEqualTo("Saw: value_from_before"); + } + private static BaseLlmFlow createBaseLlmFlowWithoutProcessors() { return createBaseLlmFlow(ImmutableList.of(), ImmutableList.of()); } From 2de03a86f97eb602dee55270b910d0d425ae75e9 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 6 Feb 2026 14:16:40 -0800 Subject: [PATCH 35/63] feat: adding resume / event management primitives This is a step towards implementing pause/resume/rewind. This change introduce several features related to resumability and event management within the Google ADK core. Here's a summary of the changes: 1. **`InvocationContext.java`**: * A new public method `resumabilityConfig()` is added to provide access to the invocation's `ResumabilityConfig`. * A method `populateAgentStates(ImmutableList events)` is introduced to initialize or update the `agentStates` and `endOfAgents` maps within the `InvocationContext` by processing events associated with the current invocation ID. 2. **`EventActions.java`**: * The `EventActions` class now extends `JsonBaseModel`. * A new field `deletedArtifactIds` (a `Set`) is added to track artifacts that should be deleted. This field is included in JSON serialization/deserialization, equality checks, and the `EventActions.Builder`'s merge logic. 3. **`Event.java`**: * The `finalResponse()` logic is updated. Previously, an event with `longRunningToolIds` was always considered a final response. This check has been removed, meaning the presence of `longRunningToolIds` alone no longer makes an event a `finalResponse`. PiperOrigin-RevId: 866611243 --- .../google/adk/agents/InvocationContext.java | 26 ++++++++ .../java/com/google/adk/events/Event.java | 3 +- .../com/google/adk/events/EventActions.java | 31 ++++++++- .../adk/agents/InvocationContextTest.java | 55 +++++++++++++++- .../google/adk/events/EventActionsTest.java | 25 +++++++- .../java/com/google/adk/events/EventTest.java | 63 +++++++++++++++++++ 6 files changed, 198 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index d197dbaa5..3b460b073 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -31,6 +31,7 @@ import com.google.errorprone.annotations.InlineMe; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -369,6 +370,31 @@ public boolean isResumable() { return resumabilityConfig.isResumable(); } + /** Returns ResumabilityConfig for this invocation. */ + public ResumabilityConfig resumabilityConfig() { + return resumabilityConfig; + } + + /** + * Populates agentStates and endOfAgents maps by reading session events for this invocation id. + */ + public void populateAgentStates(List events) { + events.stream() + .filter(event -> invocationId().equals(event.invocationId())) + .forEach( + event -> { + if (event.actions() != null) { + if (event.actions().agentState() != null + && !event.actions().agentState().isEmpty()) { + agentStates.putAll(event.actions().agentState()); + } + if (event.actions().endOfAgent()) { + endOfAgents.put(event.author(), true); + } + } + }); + } + /** Returns the events compaction configuration for the current agent run. */ public Optional eventsCompactionConfig() { return Optional.ofNullable(eventsCompactionConfig); diff --git a/core/src/main/java/com/google/adk/events/Event.java b/core/src/main/java/com/google/adk/events/Event.java index 9e05918be..d968efa53 100644 --- a/core/src/main/java/com/google/adk/events/Event.java +++ b/core/src/main/java/com/google/adk/events/Event.java @@ -294,8 +294,7 @@ public final boolean hasTrailingCodeExecutionResult() { /** Returns true if this is a final response. */ @JsonIgnore public final boolean finalResponse() { - if (actions().skipSummarization().orElse(false) - || (longRunningToolIds().isPresent() && !longRunningToolIds().get().isEmpty())) { + if (actions().skipSummarization().orElse(false)) { return true; } return functionCalls().isEmpty() diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 493fa4b27..6543ec823 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -18,12 +18,15 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.google.adk.JsonBaseModel; import com.google.adk.agents.BaseAgentState; import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Part; +import java.util.HashSet; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import javax.annotation.Nullable; @@ -31,11 +34,12 @@ /** Represents the actions attached to an event. */ // TODO - b/414081262 make json wire camelCase @JsonDeserialize(builder = EventActions.Builder.class) -public class EventActions { +public class EventActions extends JsonBaseModel { private Optional skipSummarization; private ConcurrentMap stateDelta; private ConcurrentMap artifactDelta; + private Set deletedArtifactIds; private Optional transferToAgent; private Optional escalate; private ConcurrentMap> requestedAuthConfigs; @@ -51,6 +55,7 @@ public EventActions() { this.skipSummarization = Optional.empty(); this.stateDelta = new ConcurrentHashMap<>(); this.artifactDelta = new ConcurrentHashMap<>(); + this.deletedArtifactIds = new HashSet<>(); this.transferToAgent = Optional.empty(); this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); @@ -66,6 +71,7 @@ private EventActions(Builder builder) { this.skipSummarization = builder.skipSummarization; this.stateDelta = builder.stateDelta; this.artifactDelta = builder.artifactDelta; + this.deletedArtifactIds = builder.deletedArtifactIds; this.transferToAgent = builder.transferToAgent; this.escalate = builder.escalate; this.requestedAuthConfigs = builder.requestedAuthConfigs; @@ -122,6 +128,16 @@ public void setArtifactDelta(ConcurrentMap artifactDelta) { this.artifactDelta = artifactDelta; } + @JsonProperty("deletedArtifactIds") + @JsonInclude(JsonInclude.Include.NON_EMPTY) + public Set deletedArtifactIds() { + return deletedArtifactIds; + } + + public void setDeletedArtifactIds(Set deletedArtifactIds) { + this.deletedArtifactIds = deletedArtifactIds; + } + @JsonProperty("transferToAgent") public Optional transferToAgent() { return transferToAgent; @@ -238,6 +254,7 @@ public boolean equals(Object o) { return Objects.equals(skipSummarization, that.skipSummarization) && Objects.equals(stateDelta, that.stateDelta) && Objects.equals(artifactDelta, that.artifactDelta) + && Objects.equals(deletedArtifactIds, that.deletedArtifactIds) && Objects.equals(transferToAgent, that.transferToAgent) && Objects.equals(escalate, that.escalate) && Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs) @@ -255,6 +272,7 @@ public int hashCode() { skipSummarization, stateDelta, artifactDelta, + deletedArtifactIds, transferToAgent, escalate, requestedAuthConfigs, @@ -271,6 +289,7 @@ public static class Builder { private Optional skipSummarization; private ConcurrentMap stateDelta; private ConcurrentMap artifactDelta; + private Set deletedArtifactIds; private Optional transferToAgent; private Optional escalate; private ConcurrentMap> requestedAuthConfigs; @@ -285,6 +304,7 @@ public Builder() { this.skipSummarization = Optional.empty(); this.stateDelta = new ConcurrentHashMap<>(); this.artifactDelta = new ConcurrentHashMap<>(); + this.deletedArtifactIds = new HashSet<>(); this.transferToAgent = Optional.empty(); this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); @@ -299,6 +319,7 @@ private Builder(EventActions eventActions) { this.skipSummarization = eventActions.skipSummarization(); this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta()); this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta()); + this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds()); this.transferToAgent = eventActions.transferToAgent(); this.escalate = eventActions.escalate(); this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs()); @@ -332,6 +353,13 @@ public Builder artifactDelta(ConcurrentMap value) { return this; } + @CanIgnoreReturnValue + @JsonProperty("deletedArtifactIds") + public Builder deletedArtifactIds(Set value) { + this.deletedArtifactIds = value; + return this; + } + @CanIgnoreReturnValue @JsonProperty("transferToAgent") public Builder transferToAgent(String agentId) { @@ -401,6 +429,7 @@ public Builder merge(EventActions other) { other.skipSummarization().ifPresent(this::skipSummarization); this.stateDelta.putAll(other.stateDelta()); this.artifactDelta.putAll(other.artifactDelta()); + this.deletedArtifactIds.addAll(other.deletedArtifactIds()); other.transferToAgent().ifPresent(this::transferToAgent); other.escalate().ifPresent(this::escalate); this.requestedAuthConfigs.putAll(other.requestedAuthConfigs()); diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index 61135c78e..c1cb30180 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -22,6 +22,7 @@ import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; +import com.google.adk.events.EventActions; import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.PluginManager; @@ -150,7 +151,7 @@ public void testBuildWithLiveRequestQueue() { } @Test - public void testCopyOf() { + public void testToBuilder() { InvocationContext originalContext = InvocationContext.builder() .sessionService(mockSessionService) @@ -933,4 +934,56 @@ public void testDeprecatedConstructor_11params() { assertThat(context.runConfig()).isEqualTo(runConfig); assertThat(context.endInvocation()).isTrue(); } + + @Test + public void populateAgentStates_populatesAgentStatesAndEndOfAgents() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .invocationId(testInvocationId) + .build(); + + BaseAgentState agent1State = mock(BaseAgentState.class); + ConcurrentHashMap agent1StateMap = new ConcurrentHashMap<>(); + agent1StateMap.put("agent1", agent1State); + Event event1 = + Event.builder() + .invocationId(testInvocationId) + .author("agent1") + .actions(EventActions.builder().agentState(agent1StateMap).endOfAgent(true).build()) + .build(); + Event event2 = + Event.builder() + .invocationId("other-invocation-id") + .author("agent2") + .actions(EventActions.builder().endOfAgent(true).build()) + .build(); + Event event3 = + Event.builder() + .invocationId(testInvocationId) + .author("agent3") + .actions(EventActions.builder().endOfAgent(false).build()) + .build(); + BaseAgentState agent4State = mock(BaseAgentState.class); + ConcurrentHashMap agent4StateMap = new ConcurrentHashMap<>(); + agent4StateMap.put("agent4", agent4State); + Event event4 = + Event.builder() + .invocationId(testInvocationId) + .author("agent4") + .actions(EventActions.builder().agentState(agent4StateMap).endOfAgent(false).build()) + .build(); + Event event5 = Event.builder().invocationId(testInvocationId).author("agent5").build(); + + context.populateAgentStates(ImmutableList.of(event1, event2, event3, event4, event5)); + + assertThat(context.agentStates()).hasSize(2); + assertThat(context.agentStates()).containsEntry("agent1", agent1State); + assertThat(context.agentStates()).containsEntry("agent4", agent4State); + assertThat(context.endOfAgents()).hasSize(1); + assertThat(context.endOfAgents()).containsEntry("agent1", true); + } } diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 18870ad44..9ea88b40a 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -20,6 +20,7 @@ import com.google.adk.sessions.State; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.Part; import java.util.concurrent.ConcurrentHashMap; @@ -44,7 +45,11 @@ public final class EventActionsTest { @Test public void toBuilder_createsBuilderWithSameValues() { EventActions eventActionsWithSkipSummarization = - EventActions.builder().skipSummarization(true).compaction(COMPACTION).build(); + EventActions.builder() + .skipSummarization(true) + .compaction(COMPACTION) + .deletedArtifactIds(ImmutableSet.of("d1")) + .build(); EventActions eventActionsAfterRebuild = eventActionsWithSkipSummarization.toBuilder().build(); @@ -59,6 +64,7 @@ public void merge_mergesAllFields() { .skipSummarization(true) .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1"))) .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", PART))) + .deletedArtifactIds(ImmutableSet.of("deleted1")) .requestedAuthConfigs( new ConcurrentHashMap<>( ImmutableMap.of("config1", new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))))) @@ -70,6 +76,7 @@ public void merge_mergesAllFields() { EventActions.builder() .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key2", "value2"))) .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", PART))) + .deletedArtifactIds(ImmutableSet.of("deleted2")) .transferToAgent("agentId") .escalate(true) .requestedAuthConfigs( @@ -85,6 +92,7 @@ public void merge_mergesAllFields() { assertThat(merged.skipSummarization()).hasValue(true); assertThat(merged.stateDelta()).containsExactly("key1", "value1", "key2", "value2"); assertThat(merged.artifactDelta()).containsExactly("artifact1", PART, "artifact2", PART); + assertThat(merged.deletedArtifactIds()).containsExactly("deleted1", "deleted2"); assertThat(merged.transferToAgent()).hasValue("agentId"); assertThat(merged.escalate()).hasValue(true); assertThat(merged.requestedAuthConfigs()) @@ -107,4 +115,19 @@ public void removeStateByKey_marksKeyAsRemoved() { assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED); } + + @Test + public void jsonSerialization_works() throws Exception { + EventActions eventActions = + EventActions.builder() + .deletedArtifactIds(ImmutableSet.of("d1", "d2")) + .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))) + .build(); + + String json = eventActions.toJson(); + EventActions deserialized = EventActions.fromJsonString(json, EventActions.class); + + assertThat(deserialized).isEqualTo(eventActions); + assertThat(deserialized.deletedArtifactIds()).containsExactly("d1", "d2"); + } } diff --git a/core/src/test/java/com/google/adk/events/EventTest.java b/core/src/test/java/com/google/adk/events/EventTest.java index f443abee5..d6de97f7f 100644 --- a/core/src/test/java/com/google/adk/events/EventTest.java +++ b/core/src/test/java/com/google/adk/events/EventTest.java @@ -191,4 +191,67 @@ public void event_json_serialization_works() throws Exception { Event deserializedEvent = Event.fromJson(json); assertThat(deserializedEvent).isEqualTo(EVENT); } + + @Test + public void finalResponse_returnsTrueIfNoToolCalls() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromText("hello"))) + .build(); + assertThat(event.finalResponse()).isTrue(); + } + + @Test + public void finalResponse_returnsFalseIfToolCalls() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v")))) + .build(); + assertThat(event.finalResponse()).isFalse(); + } + + @Test + public void finalResponse_isTrueForEventWithTextContent() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromText("hello"))) + .longRunningToolIds(ImmutableSet.of("tool1")) + .build(); + assertThat(event.finalResponse()).isTrue(); + } + + @Test + public void finalResponse_isFalseForEventWithToolCallAndLongRunningToolId() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v")))) + .longRunningToolIds(ImmutableSet.of("tool1")) + .build(); + assertThat(event.finalResponse()).isFalse(); + } + + @Test + public void finalResponse_returnsTrueIfSkipSummarization() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v")))) + .actions(EventActions.builder().skipSummarization(true).build()) + .build(); + assertThat(event.finalResponse()).isTrue(); + } } From 7bf55f1be6381ae5319bb0532f32c0287461546d Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Mon, 9 Feb 2026 03:25:48 -0800 Subject: [PATCH 36/63] fix: emit multiple LlmResponses in GeminiLlmConnection A single LiveServerMessage is now converted to a series of LlmResponse messages each corresponding to a different part of the LiveServerMessage, notably the UsageMetadata field is now converted to a GenerateResponseUsageMetadata and emitted downstream. PiperOrigin-RevId: 867516050 --- .../adk/models/GeminiLlmConnection.java | 139 ++++++++++++----- .../com/google/adk/models/GeminiUtil.java | 20 +++ .../adk/models/GeminiLlmConnectionTest.java | 144 ++++++++++++++++-- 3 files changed, 248 insertions(+), 55 deletions(-) diff --git a/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java b/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java index e8ae485d7..2e1229d0b 100644 --- a/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java +++ b/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java @@ -34,8 +34,11 @@ import com.google.genai.types.LiveServerMessage; import com.google.genai.types.LiveServerToolCall; import com.google.genai.types.Part; +import com.google.genai.types.UsageMetadata; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.disposables.CompositeDisposable; import io.reactivex.rxjava3.processors.PublishProcessor; import java.net.SocketException; import java.util.List; @@ -65,6 +68,7 @@ public final class GeminiLlmConnection implements BaseLlmConnection { private final CompletableFuture sessionFuture; private final PublishProcessor responseProcessor = PublishProcessor.create(); private final Flowable responseFlowable = responseProcessor.serialize(); + private final CompositeDisposable disposables = new CompositeDisposable(); private final AtomicBoolean closed = new AtomicBoolean(false); /** @@ -120,53 +124,104 @@ private void handleServerMessage(LiveServerMessage message) { logger.debug("Received server message: {}", message.toJson()); - Optional llmResponse = convertToServerResponse(message); - llmResponse.ifPresent(responseProcessor::onNext); + Observable llmResponse = convertToServerResponse(message); + if (!disposables.add( + llmResponse.subscribe(responseProcessor::onNext, responseProcessor::onError))) { + logger.warn( + "disposables container already disposed, the subscription will be disposed immediately"); + } } /** Converts a server message into the standardized LlmResponse format. */ - static Optional convertToServerResponse(LiveServerMessage message) { + static Observable convertToServerResponse(LiveServerMessage message) { + return Observable.create( + emitter -> { + // AtomicBoolean is used to modify state from within lambdas, which + // require captured variables to be effectively final. + final AtomicBoolean handled = new AtomicBoolean(false); + message + .serverContent() + .ifPresent( + serverContent -> { + emitter.onNext(createServerContentResponse(serverContent)); + handled.set(true); + }); + message + .toolCall() + .ifPresent( + toolCall -> { + emitter.onNext(createToolCallResponse(toolCall)); + handled.set(true); + }); + message + .usageMetadata() + .ifPresent( + usageMetadata -> { + logger.debug("Received usage metadata: {}", usageMetadata); + emitter.onNext(createUsageMetadataResponse(usageMetadata)); + handled.set(true); + }); + message + .toolCallCancellation() + .ifPresent( + toolCallCancellation -> { + logger.debug("Received tool call cancellation: {}", toolCallCancellation); + // TODO: implement proper CFC and thus tool call cancellation handling. + handled.set(true); + }); + message + .setupComplete() + .ifPresent( + setupComplete -> { + logger.debug("Received setup complete."); + handled.set(true); + }); + + if (!handled.get()) { + logger.warn("Received unknown or empty server message: {}", message.toJson()); + emitter.onNext(createUnknownMessageResponse()); + } + emitter.onComplete(); + }); + } + + private static LlmResponse createServerContentResponse(LiveServerContent serverContent) { LlmResponse.Builder builder = LlmResponse.builder(); + serverContent.modelTurn().ifPresent(builder::content); + return builder + .partial(serverContent.turnComplete().map(completed -> !completed).orElse(false)) + .turnComplete(serverContent.turnComplete().orElse(false)) + .interrupted(serverContent.interrupted()) + .build(); + } - if (message.serverContent().isPresent()) { - LiveServerContent serverContent = message.serverContent().get(); - serverContent.modelTurn().ifPresent(builder::content); - builder - .partial(serverContent.turnComplete().map(completed -> !completed).orElse(false)) - .turnComplete(serverContent.turnComplete().orElse(false)) - .interrupted(serverContent.interrupted()); - } else if (message.toolCall().isPresent()) { - LiveServerToolCall toolCall = message.toolCall().get(); - toolCall - .functionCalls() - .ifPresent( - calls -> { - for (FunctionCall call : calls) { - builder.content( - Content.builder() - .parts(ImmutableList.of(Part.builder().functionCall(call).build())) - .build()); - } - }); - builder.partial(false).turnComplete(false); - } else if (message.usageMetadata().isPresent()) { - logger.debug("Received usage metadata: {}", message.usageMetadata().get()); - return Optional.empty(); - } else if (message.toolCallCancellation().isPresent()) { - logger.debug("Received tool call cancellation: {}", message.toolCallCancellation().get()); - // TODO: implement proper CFC and thus tool call cancellation handling. - return Optional.empty(); - } else if (message.setupComplete().isPresent()) { - logger.debug("Received setup complete."); - return Optional.empty(); - } else { - logger.warn("Received unknown or empty server message: {}", message.toJson()); - builder - .errorCode(new FinishReason("Unknown server message.")) - .errorMessage("Received unknown server message."); - } + private static LlmResponse createToolCallResponse(LiveServerToolCall toolCall) { + LlmResponse.Builder builder = LlmResponse.builder(); + toolCall + .functionCalls() + .ifPresent( + calls -> { + for (FunctionCall call : calls) { + builder.content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionCall(call).build())) + .build()); + } + }); + return builder.partial(false).turnComplete(false).build(); + } - return Optional.of(builder.build()); + private static LlmResponse createUsageMetadataResponse(UsageMetadata usageMetadata) { + return LlmResponse.builder() + .usageMetadata(GeminiUtil.toGenerateContentResponseUsageMetadata(usageMetadata)) + .build(); + } + + private static LlmResponse createUnknownMessageResponse() { + return LlmResponse.builder() + .errorCode(new FinishReason("Unknown server message.")) + .errorMessage("Received unknown server message.") + .build(); } /** Handles errors that occur *during* the initial connection attempt. */ @@ -281,6 +336,8 @@ private void closeInternal(Throwable throwable) { } else { sessionFuture.cancel(false); } + + disposables.dispose(); } } diff --git a/core/src/main/java/com/google/adk/models/GeminiUtil.java b/core/src/main/java/com/google/adk/models/GeminiUtil.java index 2b95c0ab2..319226d69 100644 --- a/core/src/main/java/com/google/adk/models/GeminiUtil.java +++ b/core/src/main/java/com/google/adk/models/GeminiUtil.java @@ -24,7 +24,9 @@ import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FileData; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import com.google.genai.types.UsageMetadata; import java.util.List; import java.util.Optional; import java.util.stream.Stream; @@ -224,4 +226,22 @@ public static List stripThoughts(List originalContents) { }) .collect(toImmutableList()); } + + public static GenerateContentResponseUsageMetadata toGenerateContentResponseUsageMetadata( + UsageMetadata usageMetadata) { + GenerateContentResponseUsageMetadata.Builder builder = + GenerateContentResponseUsageMetadata.builder(); + usageMetadata.promptTokenCount().ifPresent(builder::promptTokenCount); + usageMetadata.cachedContentTokenCount().ifPresent(builder::cachedContentTokenCount); + usageMetadata.responseTokenCount().ifPresent(builder::candidatesTokenCount); + usageMetadata.toolUsePromptTokenCount().ifPresent(builder::toolUsePromptTokenCount); + usageMetadata.thoughtsTokenCount().ifPresent(builder::thoughtsTokenCount); + usageMetadata.totalTokenCount().ifPresent(builder::totalTokenCount); + usageMetadata.promptTokensDetails().ifPresent(builder::promptTokensDetails); + usageMetadata.cacheTokensDetails().ifPresent(builder::cacheTokensDetails); + usageMetadata.responseTokensDetails().ifPresent(builder::candidatesTokensDetails); + usageMetadata.toolUsePromptTokensDetails().ifPresent(builder::toolUsePromptTokensDetails); + usageMetadata.trafficType().ifPresent(builder::trafficType); + return builder.build(); + } } diff --git a/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java b/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java index 5d70cc449..a3ac09fe5 100644 --- a/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.LiveServerContent; import com.google.genai.types.LiveServerMessage; import com.google.genai.types.LiveServerSetupComplete; @@ -28,6 +29,8 @@ import com.google.genai.types.LiveServerToolCallCancellation; import com.google.genai.types.Part; import com.google.genai.types.UsageMetadata; +import io.reactivex.rxjava3.observers.TestObserver; +import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -45,8 +48,13 @@ public void convertToServerResponse_withInterruptedTrue_mapsInterruptedField() { .build(); LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); + TestObserver testObserver = new TestObserver<>(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.content()).isPresent(); assertThat(response.content().get().text()).isEqualTo("Model response"); @@ -66,8 +74,13 @@ public void convertToServerResponse_withInterruptedFalse_mapsInterruptedField() LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.interrupted()).hasValue(false); assertThat(response.turnComplete()).hasValue(false); } @@ -82,8 +95,13 @@ public void convertToServerResponse_withoutInterruptedField_mapsEmptyOptional() LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.interrupted()).isEmpty(); assertThat(response.turnComplete()).hasValue(true); } @@ -98,8 +116,13 @@ public void convertToServerResponse_withTurnCompleteTrue_mapsPartialFalse() { LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.partial()).hasValue(false); assertThat(response.turnComplete()).hasValue(true); } @@ -114,8 +137,13 @@ public void convertToServerResponse_withTurnCompleteFalse_mapsPartialTrue() { LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.partial()).hasValue(true); assertThat(response.turnComplete()).hasValue(false); } @@ -128,8 +156,13 @@ public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() { LiveServerMessage message = LiveServerMessage.builder().toolCall(toolCall).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.content()).isPresent(); assertThat(response.content().get().parts()).isPresent(); assertThat(response.content().get().parts().get()).hasSize(1); @@ -139,40 +172,123 @@ public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() { } @Test - public void convertToServerResponse_withUsageMetadata_returnsEmptyOptional() { + public void convertToServerResponse_withUsageMetadata_mapsGenerateResponseUsageMetadata() { LiveServerMessage message = - LiveServerMessage.builder().usageMetadata(UsageMetadata.builder().build()).build(); + LiveServerMessage.builder() + .usageMetadata( + UsageMetadata.builder() + .promptTokenCount(10) + .responseTokenCount(20) + .totalTokenCount(30) + .build()) + .build(); - assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); + assertThat(response.usageMetadata()).isPresent(); + GenerateContentResponseUsageMetadata expectedUsageMetadata = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + assertThat(response.usageMetadata()).hasValue(expectedUsageMetadata); } @Test - public void convertToServerResponse_withToolCallCancellation_returnsEmptyOptional() { + public void convertToServerResponse_withToolCallCancellation_returnsNoValues() { LiveServerMessage message = LiveServerMessage.builder() .toolCallCancellation(LiveServerToolCallCancellation.builder().build()) .build(); - assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertNoValues(); + testObserver.assertComplete(); } @Test - public void convertToServerResponse_withSetupComplete_returnsEmptyOptional() { + public void convertToServerResponse_withSetupComplete_returnsNoValues() { LiveServerMessage message = LiveServerMessage.builder() .setupComplete(LiveServerSetupComplete.builder().build()) .build(); - assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertNoValues(); + testObserver.assertComplete(); } @Test public void convertToServerResponse_withUnknownMessage_returnsErrorResponse() { LiveServerMessage message = LiveServerMessage.builder().build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.errorCode()).isPresent(); assertThat(response.errorMessage()).hasValue("Received unknown server message."); } + + @Test + public void convertToServerResponse_withContentAndUsageMetadata_emitsMultiple() { + LiveServerContent serverContent = + LiveServerContent.builder() + .modelTurn(Content.fromParts(Part.fromText("Model response"))) + .turnComplete(true) + .build(); + + UsageMetadata usageMetadata = + UsageMetadata.builder() + .promptTokenCount(10) + .responseTokenCount(20) + .totalTokenCount(30) + .build(); + + LiveServerMessage message = + LiveServerMessage.builder() + .serverContent(serverContent) + .usageMetadata(usageMetadata) + .build(); + + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(2); + testObserver.assertComplete(); + + List responses = testObserver.values(); + + // Check for ServerContent response + LlmResponse contentResponse = responses.get(0); + assertThat(contentResponse.content()).isPresent(); + assertThat(contentResponse.content().get().text()).isEqualTo("Model response"); + assertThat(contentResponse.usageMetadata()).isEmpty(); + + // Check for UsageMetadata response + LlmResponse usageResponse = responses.get(1); + assertThat(usageResponse.content()).isEmpty(); + assertThat(usageResponse.usageMetadata()).isPresent(); + GenerateContentResponseUsageMetadata expectedUsageMetadata = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + assertThat(usageResponse.usageMetadata()).hasValue(expectedUsageMetadata); + } } From 2fd5ac9c22324a4c430498e0f7d94f53301b39d8 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Feb 2026 01:41:15 -0800 Subject: [PATCH 37/63] ADK changes PiperOrigin-RevId: 868014417 --- .../adk/a2a/A2ASendMessageExecutor.java | 307 ------------------ a2a/webservice/pom.xml | 67 ---- .../adk/webservice/A2ARemoteApplication.java | 20 -- .../webservice/A2ARemoteConfiguration.java | 49 --- .../adk/webservice/A2ARemoteController.java | 40 --- .../adk/webservice/A2ARemoteService.java | 93 ------ contrib/samples/a2a_remote/README.md | 70 ---- contrib/samples/a2a_remote/pom.xml | 139 -------- .../a2a_remote/remote_prime_agent/Agent.java | 101 ------ .../a2a_remote/remote_prime_agent/agent.json | 17 - .../a2a_remote/RemoteA2AApplication.java | 24 -- contrib/samples/pom.xml | 1 - pom.xml | 1 - 13 files changed, 929 deletions(-) delete mode 100644 a2a/src/main/java/com/google/adk/a2a/A2ASendMessageExecutor.java delete mode 100644 a2a/webservice/pom.xml delete mode 100644 a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteApplication.java delete mode 100644 a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteConfiguration.java delete mode 100644 a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteController.java delete mode 100644 a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteService.java delete mode 100644 contrib/samples/a2a_remote/README.md delete mode 100644 contrib/samples/a2a_remote/pom.xml delete mode 100644 contrib/samples/a2a_remote/remote_prime_agent/Agent.java delete mode 100644 contrib/samples/a2a_remote/remote_prime_agent/agent.json delete mode 100644 contrib/samples/a2a_remote/src/main/java/com/google/adk/samples/a2a_remote/RemoteA2AApplication.java diff --git a/a2a/src/main/java/com/google/adk/a2a/A2ASendMessageExecutor.java b/a2a/src/main/java/com/google/adk/a2a/A2ASendMessageExecutor.java deleted file mode 100644 index bd345ab22..000000000 --- a/a2a/src/main/java/com/google/adk/a2a/A2ASendMessageExecutor.java +++ /dev/null @@ -1,307 +0,0 @@ -package com.google.adk.a2a; - -import static com.google.common.base.Strings.isNullOrEmpty; -import static java.util.concurrent.TimeUnit.MILLISECONDS; - -import com.google.adk.a2a.converters.ConversationPreprocessor; -import com.google.adk.a2a.converters.RequestConverter; -import com.google.adk.a2a.converters.ResponseConverter; -import com.google.adk.agents.BaseAgent; -import com.google.adk.agents.RunConfig; -import com.google.adk.artifacts.InMemoryArtifactService; -import com.google.adk.events.Event; -import com.google.adk.memory.InMemoryMemoryService; -import com.google.adk.runner.Runner; -import com.google.adk.sessions.InMemorySessionService; -import com.google.adk.sessions.Session; -import com.google.common.collect.ImmutableList; -import com.google.genai.types.Content; -import io.a2a.spec.Message; -import io.a2a.spec.TextPart; -import io.reactivex.rxjava3.core.Completable; -import io.reactivex.rxjava3.core.Single; -import java.time.Duration; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; -import org.jspecify.annotations.Nullable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Shared SendMessage execution between HTTP service and other integrations. - * - *

    **EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ -public final class A2ASendMessageExecutor { - private static final Logger logger = LoggerFactory.getLogger(A2ASendMessageExecutor.class); - - @FunctionalInterface - public interface AgentExecutionStrategy { - Single> execute( - String userId, - String sessionId, - Content userContent, - RunConfig runConfig, - String invocationId); - } - - private final InMemorySessionService sessionService; - private final String appName; - @Nullable private final Runner runner; - @Nullable private final Duration agentTimeout; - private static final RunConfig DEFAULT_RUN_CONFIG = - RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build(); - - public A2ASendMessageExecutor(InMemorySessionService sessionService, String appName) { - this.sessionService = sessionService; - this.appName = appName; - this.runner = null; - this.agentTimeout = null; - } - - public A2ASendMessageExecutor(BaseAgent agent, String appName, Duration agentTimeout) { - InMemorySessionService sessionService = new InMemorySessionService(); - Runner runnerInstance = - new Runner( - agent, - appName, - new InMemoryArtifactService(), - sessionService, - new InMemoryMemoryService()); - this.sessionService = sessionService; - this.appName = appName; - this.runner = runnerInstance; - this.agentTimeout = agentTimeout; - } - - public Single execute( - @Nullable Message request, AgentExecutionStrategy agentExecutionStrategy) { - final String invocationId = UUID.randomUUID().toString(); - final String contextId = resolveContextId(request); - final ImmutableList inputEvents = buildInputEvents(request, invocationId); - - ConversationPreprocessor.PreparedInput prepared = - ConversationPreprocessor.extractHistoryAndUserContent(inputEvents); - - String userId = buildUserId(contextId); - String sessionId = contextId; - - return ensureSessionExistsSingle(userId, sessionId, contextId) - .flatMap( - session -> - processEventsSingle( - session, prepared, userId, sessionId, invocationId, agentExecutionStrategy)) - .map( - resultEvents -> { - final String taskId = resolveTaskId(request); - return ResponseConverter.eventsToMessage(resultEvents, contextId, taskId); - }) - .onErrorReturn( - throwable -> { - logger.error("Error processing A2A request", throwable); - return errorResponse("Internal error: " + throwable.getMessage(), contextId); - }); - } - - public Single execute(@Nullable Message request) { - if (runner == null || agentTimeout == null) { - throw new IllegalStateException( - "Runner-based handle invoked without configured runner or timeout"); - } - return execute(request, this::executeAgentWithTimeout); - } - - private Single ensureSessionExistsSingle( - String userId, String sessionId, String contextId) { - return sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - ConcurrentHashMap initialState = new ConcurrentHashMap<>(); - return sessionService.createSession(appName, userId, initialState, sessionId); - })); - } - - private Completable appendHistoryEvents( - Session session, ConversationPreprocessor.PreparedInput prepared, String invocationId) { - ImmutableList eventsToAppend = - filterNewHistoryEvents(session, prepared.historyEvents, invocationId); - return appendEvents(session, eventsToAppend); - } - - private ImmutableList filterNewHistoryEvents( - Session session, List historyEvents, String invocationId) { - Set existingEventIds = new HashSet<>(); - for (Event existing : session.events()) { - if (existing.id() != null) { - existingEventIds.add(existing.id()); - } - } - - ImmutableList.Builder eventsToAppend = ImmutableList.builder(); - for (Event historyEvent : historyEvents) { - ensureIdentifiers(historyEvent, invocationId); - if (existingEventIds.add(historyEvent.id())) { - eventsToAppend.add(historyEvent); - } - } - return eventsToAppend.build(); - } - - private Completable appendEvents(Session session, ImmutableList events) { - Completable chain = Completable.complete(); - for (Event event : events) { - chain = chain.andThen(sessionService.appendEvent(session, event).ignoreElement()); - } - return chain; - } - - private Single> processEventsSingle( - Session session, - ConversationPreprocessor.PreparedInput prepared, - String userId, - String sessionId, - String invocationId, - AgentExecutionStrategy agentExecutionStrategy) { - Content userContent = - prepared.userContent.orElseGet(A2ASendMessageExecutor::defaultUserContent); - return appendHistoryEvents(session, prepared, invocationId) - .andThen( - agentExecutionStrategy.execute( - userId, sessionId, userContent, DEFAULT_RUN_CONFIG, invocationId)); - } - - private static ImmutableList defaultHelloEvent(String invocationId) { - Event e = - Event.builder() - .id(UUID.randomUUID().toString()) - .invocationId(invocationId) - .author("user") - .content(defaultUserContent()) - .build(); - return ImmutableList.of(e); - } - - private static Content defaultUserContent() { - return Content.builder() - .role("user") - .parts(ImmutableList.of(com.google.genai.types.Part.builder().text("Hello").build())) - .build(); - } - - private static Message errorResponse(String msg, String contextId) { - Message error = - new Message.Builder() - .messageId(UUID.randomUUID().toString()) - .role(Message.Role.AGENT) - .parts(ImmutableList.of(new TextPart("Error: " + msg))) - .build(); - if (contextId != null && !contextId.isEmpty()) { - error.setContextId(contextId); - } - return error; - } - - private Single> executeAgentWithTimeout( - String userId, - String sessionId, - Content userContent, - RunConfig runConfig, - String invocationId) { - if (runner == null || agentTimeout == null) { - throw new IllegalStateException("Runner-based execution invoked without configuration"); - } - - Single> agentResultSingle = - runner - .runAsync(userId, sessionId, userContent, runConfig) - .toList() - .map(events -> ImmutableList.copyOf(events)); - - return agentResultSingle - .timeout(agentTimeout.toMillis(), MILLISECONDS) - .onErrorResumeNext( - throwable -> { - if (isTimeout(throwable)) { - logger.warn( - "Agent execution exceeded {}; returning timeout event", - agentTimeout, - throwable); - return Single.just(ImmutableList.of(createTimeoutEvent(invocationId))); - } - return Single.error(throwable); - }); - } - - private static String resolveContextId(@Nullable Message inbound) { - if (inbound == null || inbound.getContextId() == null || inbound.getContextId().isEmpty()) { - return UUID.randomUUID().toString(); - } - return inbound.getContextId(); - } - - private static String resolveTaskId(@Nullable Message inbound) { - if (inbound != null && inbound.getTaskId() != null && !inbound.getTaskId().isEmpty()) { - return inbound.getTaskId(); - } - return UUID.randomUUID().toString(); - } - - private static ImmutableList buildInputEvents( - @Nullable Message inbound, String invocationId) { - if (inbound == null) { - return defaultHelloEvent(invocationId); - } - return RequestConverter.convertAggregatedA2aMessageToAdkEvents(inbound, invocationId); - } - - private static String buildUserId(String contextId) { - return "user-" + contextId; - } - - private static void ensureIdentifiers(Event event, String invocationId) { - if (isNullOrEmpty(event.id())) { - event.setId(Event.generateEventId()); - } - if (isNullOrEmpty(event.invocationId())) { - event.setInvocationId(invocationId); - } - } - - private static Event createTimeoutEvent(String invocationId) { - return Event.builder() - .id(UUID.randomUUID().toString()) - .invocationId(invocationId) - .author("agent") - .content( - Content.builder() - .role("model") - .parts( - ImmutableList.of( - com.google.genai.types.Part.builder() - .text("Agent execution timed out.") - .build())) - .build()) - .build(); - } - - private static boolean isTimeout(@Nullable Throwable throwable) { - while (throwable != null) { - if (throwable instanceof TimeoutException) { - return true; - } - if (throwable.getClass().getName().endsWith("TimeoutException")) { - return true; - } - throwable = throwable.getCause(); - } - return false; - } -} diff --git a/a2a/webservice/pom.xml b/a2a/webservice/pom.xml deleted file mode 100644 index deb03fd27..000000000 --- a/a2a/webservice/pom.xml +++ /dev/null @@ -1,67 +0,0 @@ - - - 4.0.0 - - - com.google.adk - google-adk-parent - 0.5.1-SNAPSHOT - ../../pom.xml - - - google-adk-a2a-webservice - jar - - Google ADK A2A Webservice - - - 17 - ${java.version} - - - - - com.google.adk - google-adk-a2a - ${project.version} - - - org.springframework.boot - spring-boot-starter-web - - - org.slf4j - slf4j-api - - - - - - - org.apache.maven.plugins - maven-compiler-plugin - 3.13.0 - - ${java.version} - - - - org.springframework.boot - spring-boot-maven-plugin - ${spring-boot.version} - - - - repackage - - - exec - - - - - - - \ No newline at end of file diff --git a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteApplication.java b/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteApplication.java deleted file mode 100644 index 93e321eb1..000000000 --- a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteApplication.java +++ /dev/null @@ -1,20 +0,0 @@ -package com.google.adk.webservice; - -import org.springframework.boot.SpringApplication; -import org.springframework.boot.autoconfigure.SpringBootApplication; -import org.springframework.context.annotation.Import; - -/** - * Entry point for the standalone Spring Boot A2A service. - * - *

    **EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ -@SpringBootApplication -@Import(A2ARemoteConfiguration.class) -public class A2ARemoteApplication { - - public static void main(String[] args) { - SpringApplication.run(A2ARemoteApplication.class, args); - } -} diff --git a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteConfiguration.java b/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteConfiguration.java deleted file mode 100644 index a3f9b48ac..000000000 --- a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteConfiguration.java +++ /dev/null @@ -1,49 +0,0 @@ -package com.google.adk.webservice; - -import com.google.adk.a2a.A2ASendMessageExecutor; -import com.google.adk.agents.BaseAgent; -import java.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.ComponentScan; -import org.springframework.context.annotation.Configuration; - -/** - * Registers the transport-only A2A webservice stack. - * - *

    Importers must supply a {@link BaseAgent} bean. The agent remains opaque to this module so the - * transport can be reused across applications. - * - *

    TODO: - * - *

      - *
    • Expose discovery endpoints (agent card / extended card) so clients can fetch metadata - * directly. - *
    • Add optional remote-proxy wiring for cases where no local agent bean is available. - *
    - * - *

    **EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ -@Configuration -@ComponentScan(basePackages = "com.google.adk.webservice") -public class A2ARemoteConfiguration { - - private static final Logger logger = LoggerFactory.getLogger(A2ARemoteConfiguration.class); - private static final String DEFAULT_APP_NAME = "a2a-remote-service"; - private static final long DEFAULT_TIMEOUT_SECONDS = 15L; - - @Bean - public A2ASendMessageExecutor a2aSendMessageExecutor( - BaseAgent agent, - @Value("${a2a.remote.appName:" + DEFAULT_APP_NAME + "}") String appName, - @Value("${a2a.remote.timeoutSeconds:" + DEFAULT_TIMEOUT_SECONDS + "}") long timeoutSeconds) { - logger.info( - "Initializing A2A send message executor for appName {} with timeout {}s", - appName, - timeoutSeconds); - return new A2ASendMessageExecutor(agent, appName, Duration.ofSeconds(timeoutSeconds)); - } -} diff --git a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteController.java b/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteController.java deleted file mode 100644 index a0fe5b0cc..000000000 --- a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteController.java +++ /dev/null @@ -1,40 +0,0 @@ -package com.google.adk.webservice; - -import io.a2a.spec.SendMessageRequest; -import io.a2a.spec.SendMessageResponse; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; - -/** - * REST controller exposing an A2A-compliant JSON-RPC endpoint backed by a local ADK runner. - * - *

    **EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ -@RestController -@RequestMapping("/a2a/remote") -public class A2ARemoteController { - - private static final Logger logger = LoggerFactory.getLogger(A2ARemoteController.class); - - private final A2ARemoteService service; - - public A2ARemoteController(A2ARemoteService service) { - this.service = service; - } - - @PostMapping( - path = "/v1/message:send", - consumes = "application/json", - produces = "application/json") - public SendMessageResponse sendMessage(@RequestBody SendMessageRequest request) { - logger.debug("Received remote A2A request: {}", request); - SendMessageResponse response = service.handle(request); - logger.debug("Responding with remote A2A payload: {}", response); - return response; - } -} diff --git a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteService.java b/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteService.java deleted file mode 100644 index 803774568..000000000 --- a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteService.java +++ /dev/null @@ -1,93 +0,0 @@ -package com.google.adk.webservice; - -import com.google.adk.a2a.A2ASendMessageExecutor; -import com.google.adk.a2a.converters.ResponseConverter; -import io.a2a.spec.JSONRPCError; -import io.a2a.spec.Message; -import io.a2a.spec.MessageSendParams; -import io.a2a.spec.SendMessageRequest; -import io.a2a.spec.SendMessageResponse; -import java.util.List; -import java.util.UUID; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.stereotype.Service; - -/** - * Core service that bridges the A2A JSON-RPC sendMessage API to a local ADK runner. - * - *

    **EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ -@Service -public class A2ARemoteService { - - private static final Logger logger = LoggerFactory.getLogger(A2ARemoteService.class); - private static final int ERROR_CODE_INVALID_PARAMS = -32602; - private static final int ERROR_CODE_INTERNAL_ERROR = -32603; - - private final A2ASendMessageExecutor executor; - - public A2ARemoteService(A2ASendMessageExecutor executor) { - this.executor = executor; - } - - public SendMessageResponse handle(SendMessageRequest request) { - if (request == null) { - logger.warn("Received null SendMessageRequest"); - return invalidParamsResponse(null, "Request body is missing"); - } - - MessageSendParams params = request.getParams(); - if (params == null) { - logger.warn("SendMessageRequest {} missing params", request.getId()); - return invalidParamsResponse(request, "Request params are missing"); - } - - Message inbound = params.message(); - if (inbound == null) { - logger.warn("SendMessageRequest {} missing message payload", request.getId()); - return invalidParamsResponse(request, "Request message payload is missing"); - } - - boolean generatedContext = inbound.getContextId() == null || inbound.getContextId().isEmpty(); - Message normalized = ensureContextId(inbound); - if (generatedContext) { - logger.debug("Incoming request lacked contextId; generated {}", normalized.getContextId()); - } - - try { - Message result = executor.execute(normalized).blockingGet(); - if (result == null) { - result = - ResponseConverter.eventsToMessage( - List.of(), normalized.getContextId(), normalized.getTaskId()); - } - - logger.debug("Returning A2A response for context {}", normalized.getContextId()); - return new SendMessageResponse(request.getId(), result); - } catch (RuntimeException e) { - logger.error("Failed to process remote A2A request", e); - return errorResponse(request, e); - } - } - - private static Message ensureContextId(Message message) { - if (message.getContextId() != null && !message.getContextId().isEmpty()) { - return message; - } - return new Message.Builder(message).contextId(UUID.randomUUID().toString()).build(); - } - - private static SendMessageResponse invalidParamsResponse( - SendMessageRequest request, String reason) { - JSONRPCError error = new JSONRPCError(ERROR_CODE_INVALID_PARAMS, reason, null); - return new SendMessageResponse(request != null ? request.getId() : null, error); - } - - private static SendMessageResponse errorResponse(SendMessageRequest request, Throwable error) { - String message = "Internal error processing sendMessage request"; - JSONRPCError jsonrpcError = new JSONRPCError(ERROR_CODE_INTERNAL_ERROR, message, null); - return new SendMessageResponse(request != null ? request.getId() : null, jsonrpcError); - } -} diff --git a/contrib/samples/a2a_remote/README.md b/contrib/samples/a2a_remote/README.md deleted file mode 100644 index d1d2601ca..000000000 --- a/contrib/samples/a2a_remote/README.md +++ /dev/null @@ -1,70 +0,0 @@ -# A2A Remote Prime Service Sample - -This sample starts a standalone Spring Boot service that exposes the -`remote_prime_agent` via the shared A2A webservice module -(`google-adk-a2a-webservice`). It behaves like a third‑party service that -implements the A2A JSON‑RPC contract and can be used by the ADK client (for -example, the `a2a_basic` demo) as its remote endpoint. - -## Running the service - -```bash -cd google_adk -mvn -f contrib/samples/a2a_remote/pom.xml package - -GOOGLE_GENAI_USE_VERTEXAI=FALSE \ -GOOGLE_API_KEY= \ -mvn -f contrib/samples/a2a_remote/pom.xml exec:java -``` - -`RemoteA2AApplication` imports the reusable controller/service from -`google-adk-a2a-webservice`, so the server listens on -`http://localhost:8080/a2a/remote/v1/message:send` by default. Override the -port with `-Dspring-boot.run.arguments=--server.port=` when running via -`spring-boot:run` if you need to avoid collisions. - -``` -POST /a2a/remote/v1/message:send -Content-Type: application/json -``` - -and accepts standard A2A JSON‑RPC payloads (`SendMessageRequest`). The -response is a `SendMessageResponse` that contains either a `Message` or a -`Task` in the `result` field. Spring Boot logs the request/response lifecycle -to the console; add your preferred logging configuration if you need -persistent logs. - -## Agent implementation - -- `remote_prime_agent/Agent.java` hosts the LLM agent that checks whether - numbers are prime (lifted from the Stubby demo). The model name defaults - to `gemini-2.5-pro`; set `GOOGLE_API_KEY` before running. -- `RemoteA2AApplication` bootstraps the service by importing - `A2ARemoteConfiguration` and publishing the prime `BaseAgent` bean. The shared - configuration consumes that bean to create the `A2ASendMessageExecutor`. - -## Sample request - -```bash -curl -X POST http://localhost:8080/a2a/remote/v1/message:send \ - -H 'Content-Type: application/json' \ - -d '{ - "jsonrpc": "2.0", - "id": "demo-123", - "method": "message/send", - "params": { - "message": { - "role": "user", - "messageId": "msg-1", - "contextId": "ctx-1", - "parts": [ - {"kind": "text", "text": "Check if 17 is prime"} - ] - }, - "metadata": {} - } - }' -``` - -The response contains the prime check result, and the interaction is logged in -the application console. diff --git a/contrib/samples/a2a_remote/pom.xml b/contrib/samples/a2a_remote/pom.xml deleted file mode 100644 index 59d9cf01e..000000000 --- a/contrib/samples/a2a_remote/pom.xml +++ /dev/null @@ -1,139 +0,0 @@ - - - 4.0.0 - - - com.google.adk - google-adk-parent - 0.5.1-SNAPSHOT - ../../../pom.xml - - - google-adk-sample-a2a-remote - Google ADK - Sample - A2A Remote Prime Service - Spring Boot service that exposes the remote prime-check agent over the A2A REST interface. - jar - - - 3.3.4 - 17 - 0.8 - com.google.adk.samples.a2a_remote.RemoteA2AApplication - - - - - - org.springframework.boot - spring-boot-dependencies - ${spring-boot.version} - pom - import - - - - - - - org.springframework.boot - spring-boot-starter-web - - - - com.google.adk - google-adk - ${project.version} - - - - com.google.adk - google-adk-a2a - ${project.version} - - - - com.google.adk - google-adk-a2a-webservice - ${project.version} - - - - com.google.flogger - flogger - ${flogger.version} - - - com.google.flogger - google-extensions - ${flogger.version} - - - com.google.flogger - flogger-system-backend - ${flogger.version} - - - - org.springframework.boot - spring-boot-starter-test - test - - - - com.google.truth - truth - ${truth.version} - test - - - - - - - org.springframework.boot - spring-boot-maven-plugin - ${spring-boot.version} - - - org.codehaus.mojo - build-helper-maven-plugin - 3.6.0 - - - add-source - generate-sources - - add-source - - - - . - - - - - - - org.apache.maven.plugins - maven-source-plugin - - - **/*.jar - target/** - - - - - org.codehaus.mojo - exec-maven-plugin - 3.2.0 - - ${exec.mainClass} - runtime - - - - - \ No newline at end of file diff --git a/contrib/samples/a2a_remote/remote_prime_agent/Agent.java b/contrib/samples/a2a_remote/remote_prime_agent/Agent.java deleted file mode 100644 index a0072e8e3..000000000 --- a/contrib/samples/a2a_remote/remote_prime_agent/Agent.java +++ /dev/null @@ -1,101 +0,0 @@ -package com.google.adk.samples.a2a_remote.remote_prime_agent; - -import static java.util.stream.Collectors.joining; - -import com.google.adk.agents.LlmAgent; -import com.google.adk.tools.FunctionTool; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.flogger.GoogleLogger; -import io.reactivex.rxjava3.core.Maybe; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -/** Agent that can check whether numbers are prime. */ -public final class Agent { - - private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); - - public static ImmutableMap checkPrime(List nums) { - logger.atInfo().log("checkPrime called with nums=%s", nums); - Set primes = new HashSet<>(); - for (int num : nums) { - if (num <= 1) { - continue; - } - boolean isPrime = true; - for (int i = 2; i <= Math.sqrt(num); i++) { - if (num % i == 0) { - isPrime = false; - break; - } - } - if (isPrime) { - primes.add(num); - } - } - String result; - if (primes.isEmpty()) { - result = "No prime numbers found."; - } else if (primes.size() == 1) { - int only = primes.iterator().next(); - // Per request: singular phrasing without article - result = only + " is prime number."; - } else { - result = primes.stream().map(String::valueOf).collect(joining(", ")) + " are prime numbers."; - } - logger.atInfo().log("checkPrime result=%s", result); - return ImmutableMap.of("result", result); - } - - public static final LlmAgent ROOT_AGENT = - LlmAgent.builder() - .model("gemini-2.5-pro") - .name("check_prime_agent") - .description("check prime agent that can check whether numbers are prime.") - .instruction( - """ - You check whether numbers are prime. - - If the last user message contains numbers, call checkPrime exactly once with exactly - those integers as a list (e.g., [2]). Never add other numbers. Do not ask for - clarification. Return only the tool's result. - - Always pass a list of integers to the tool (use a single-element list for one - number). Never pass strings. - """) - // Log the exact contents passed to the LLM request for verification - .beforeModelCallback( - (callbackContext, llmRequest) -> { - try { - logger.atInfo().log( - "Invocation events (count=%d): %s", - callbackContext.events().size(), callbackContext.events()); - } catch (Throwable t) { - logger.atWarning().withCause(t).log("BeforeModel logging error"); - } - return Maybe.empty(); - }) - .afterModelCallback( - (callbackContext, llmResponse) -> { - try { - String content = - llmResponse.content().map(Object::toString).orElse(""); - logger.atInfo().log("AfterModel content=%s", content); - llmResponse - .errorMessage() - .ifPresent( - error -> - logger.atInfo().log( - "AfterModel errorMessage=%s", error.replace("\n", "\\n"))); - } catch (Throwable t) { - logger.atWarning().withCause(t).log("AfterModel logging error"); - } - return Maybe.empty(); - }) - .tools(ImmutableList.of(FunctionTool.create(Agent.class, "checkPrime"))) - .build(); - - private Agent() {} -} diff --git a/contrib/samples/a2a_remote/remote_prime_agent/agent.json b/contrib/samples/a2a_remote/remote_prime_agent/agent.json deleted file mode 100644 index 87f2d9ecc..000000000 --- a/contrib/samples/a2a_remote/remote_prime_agent/agent.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "capabilities": {}, - "defaultInputModes": ["text/plain"], - "defaultOutputModes": ["application/json"], - "description": "An agent specialized in checking whether numbers are prime. It can efficiently determine the primality of individual numbers or lists of numbers.", - "name": "check_prime_agent", - "skills": [ - { - "id": "prime_checking", - "name": "Prime Number Checking", - "description": "Check if numbers in a list are prime using efficient mathematical algorithms", - "tags": ["mathematical", "computation", "prime", "numbers"] - } - ], - "url": "http://localhost:8080/a2a/prime_agent", - "version": "1.0.0" -} diff --git a/contrib/samples/a2a_remote/src/main/java/com/google/adk/samples/a2a_remote/RemoteA2AApplication.java b/contrib/samples/a2a_remote/src/main/java/com/google/adk/samples/a2a_remote/RemoteA2AApplication.java deleted file mode 100644 index 53be8d1d0..000000000 --- a/contrib/samples/a2a_remote/src/main/java/com/google/adk/samples/a2a_remote/RemoteA2AApplication.java +++ /dev/null @@ -1,24 +0,0 @@ -package com.google.adk.samples.a2a_remote; - -import com.google.adk.agents.BaseAgent; -import com.google.adk.samples.a2a_remote.remote_prime_agent.Agent; -import com.google.adk.webservice.A2ARemoteConfiguration; -import org.springframework.boot.SpringApplication; -import org.springframework.boot.autoconfigure.SpringBootApplication; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Import; - -/** Spring Boot entry point that wires the shared A2A webservice with the prime demo agent. */ -@SpringBootApplication -@Import(A2ARemoteConfiguration.class) -public class RemoteA2AApplication { - - public static void main(String[] args) { - SpringApplication.run(RemoteA2AApplication.class, args); - } - - @Bean - public BaseAgent primeAgent() { - return Agent.ROOT_AGENT; - } -} diff --git a/contrib/samples/pom.xml b/contrib/samples/pom.xml index fa5d6dfae..580b10de9 100644 --- a/contrib/samples/pom.xml +++ b/contrib/samples/pom.xml @@ -17,7 +17,6 @@ a2a_basic - a2a_remote configagent helloworld mcpfilesystem diff --git a/pom.xml b/pom.xml index 6a1aa5af5..89f0d2c0f 100644 --- a/pom.xml +++ b/pom.xml @@ -37,7 +37,6 @@ tutorials/city-time-weather tutorials/live-audio-single-agent a2a - a2a/webservice From 67c29e3a33bda22d8a18a17c99e5abc891bf19f8 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Tue, 10 Feb 2026 02:53:44 -0800 Subject: [PATCH 38/63] fix: Merging of events in rearrangeEventsForAsyncFunctionResponsesInHistory reinstating the fix in #682 which was later reverted in #827 PiperOrigin-RevId: 868043103 --- .../google/adk/flows/llmflows/Contents.java | 3 +- .../adk/flows/llmflows/ContentsTest.java | 72 +++++++++++++++++-- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 171dab972..f45461626 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -564,8 +564,7 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( for (int i = 0; i < events.size(); i++) { Event event = events.get(i); - // Skip response events that will be processed via responseEventsBuffer - if (processedResponseIndices.contains(i)) { + if (!event.functionResponses().isEmpty()) { continue; } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 82a57ed4f..a8a862b51 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -203,9 +203,11 @@ public void rearrangeHistory_asyncFR_returnsRearrangedList() { public void rearrangeHistory_multipleFRsForSameFC_returnsMergedFR() { Event fcEvent = createFunctionCallEvent("fc1", "tool1", "call1"); Event frEvent1 = - createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "running")); + createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "pending")); Event frEvent2 = - createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "done")); + createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "running")); + Event frEvent3 = + createFunctionResponseEvent("fr3", "tool1", "call1", ImmutableMap.of("status", "done")); ImmutableList inputEvents = ImmutableList.of( createUserEvent("u1", "Query"), @@ -213,17 +215,75 @@ public void rearrangeHistory_multipleFRsForSameFC_returnsMergedFR() { createUserEvent("u2", "Wait"), frEvent1, createUserEvent("u3", "Done?"), - frEvent2); + frEvent2, + frEvent3, + createUserEvent("u4", "Follow up query")); List result = runContentsProcessor(inputEvents); - assertThat(result).hasSize(3); // u1, fc1, merged_fr + assertThat(result).hasSize(6); // u1, fc1, merged_fr, u2, u3, u4 assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); - assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // Check merged event + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // Check fcEvent Content mergedContent = result.get(2); assertThat(mergedContent.parts().get()).hasSize(1); assertThat(mergedContent.parts().get().get(0).functionResponse().get().response().get()) - .containsExactly("status", "done"); // Last FR wins + .containsExactly("status", "done"); // Last FR wins (frEvent3) + assertThat(result.get(3)).isEqualTo(inputEvents.get(2).content().get()); // u2 + assertThat(result.get(4)).isEqualTo(inputEvents.get(4).content().get()); // u3 + assertThat(result.get(5)).isEqualTo(inputEvents.get(7).content().get()); // u4 + } + + @Test + public void rearrangeHistory_multipleFRsForMultipleFC_returnsMergedFR() { + Event fcEvent1 = createFunctionCallEvent("fc1", "tool1", "call1"); + Event fcEvent2 = createFunctionCallEvent("fc2", "tool1", "call2"); + + Event frEvent1 = + createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "pending")); + Event frEvent2 = + createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "done")); + + Event frEvent3 = + createFunctionResponseEvent("fr3", "tool1", "call2", ImmutableMap.of("status", "pending")); + Event frEvent4 = + createFunctionResponseEvent("fr4", "tool1", "call2", ImmutableMap.of("status", "done")); + + ImmutableList inputEvents = + ImmutableList.of( + createUserEvent("u1", "I"), + fcEvent1, + createUserEvent("u2", "am"), + frEvent1, + createUserEvent("u3", "waiting"), + frEvent2, + createUserEvent("u4", "for"), + fcEvent2, + createUserEvent("u5", "you"), + frEvent3, + createUserEvent("u6", "to"), + frEvent4, + createUserEvent("u7", "Follow up query")); + + List result = runContentsProcessor(inputEvents); + + assertThat(result).hasSize(11); // u1, fc1, frEvent2, u2, u3, u4, fc2, frEvent4, u5, u6, u7 + assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); // u1 + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // fc1 + Content mergedContent = result.get(2); + assertThat(mergedContent.parts().get()).hasSize(1); + assertThat(mergedContent.parts().get().get(0).functionResponse().get().response().get()) + .containsExactly("status", "done"); // Last FR wins (frEvent2) + assertThat(result.get(3)).isEqualTo(inputEvents.get(2).content().get()); // u2 + assertThat(result.get(4)).isEqualTo(inputEvents.get(4).content().get()); // u3 + assertThat(result.get(5)).isEqualTo(inputEvents.get(6).content().get()); // u4 + assertThat(result.get(6)).isEqualTo(inputEvents.get(7).content().get()); // fc2 + Content mergedContent2 = result.get(7); + assertThat(mergedContent2.parts().get()).hasSize(1); + assertThat(mergedContent2.parts().get().get(0).functionResponse().get().response().get()) + .containsExactly("status", "done"); // Last FR wins (frEvent4) + assertThat(result.get(8)).isEqualTo(inputEvents.get(8).content().get()); // u5 + assertThat(result.get(9)).isEqualTo(inputEvents.get(10).content().get()); // u6 + assertThat(result.get(10)).isEqualTo(inputEvents.get(12).content().get()); // u7 } @Test From 495bf95642b9159aa6040868fcaa97fed166035b Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Feb 2026 08:39:03 -0800 Subject: [PATCH 39/63] feat: Adding a .close() method to Runner, Agent and Plugins PiperOrigin-RevId: 868164137 --- .../java/com/google/adk/agents/BaseAgent.java | 13 +++++ .../java/com/google/adk/agents/LlmAgent.java | 21 +++++++ .../java/com/google/adk/plugins/Plugin.java | 10 ++++ .../com/google/adk/plugins/PluginManager.java | 13 +++++ .../java/com/google/adk/runner/Runner.java | 8 +++ .../com/google/adk/agents/BaseAgentTest.java | 51 +++++++++++++++++ .../com/google/adk/agents/LlmAgentTest.java | 55 +++++++++++++++++++ .../google/adk/plugins/PluginManagerTest.java | 27 +++++++++ .../com/google/adk/runner/RunnerTest.java | 20 +++++++ .../com/google/adk/testing/TestBaseAgent.java | 2 +- 10 files changed, 219 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 0db4dabb5..0a7a09864 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -32,9 +32,11 @@ import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; +import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -108,6 +110,17 @@ public BaseAgent( } } + /** + * Closes all sub-agents. + * + * @return a {@link Completable} that completes when all sub-agents are closed. + */ + public Completable close() { + List completables = new ArrayList<>(); + this.subAgents.forEach(subAgent -> completables.add(subAgent.close())); + return Completable.mergeDelayError(completables); + } + /** * Validates the agent name. * diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 87967bb6d..1893fb162 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -62,6 +62,7 @@ import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.Part; import com.google.genai.types.Schema; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; @@ -1055,6 +1056,26 @@ public static LlmAgent fromConfig(LlmAgentConfig config, String configAbsPath) return agent; } + @Override + public Completable close() { + List completables = new ArrayList<>(); + toolsets() + .forEach( + toolset -> + completables.add( + Completable.fromAction( + () -> { + try { + toolset.close(); + } catch (Exception e) { + logger.error("Failed to close toolset", e); + throw e; + } + }))); + completables.add(super.close()); + return Completable.mergeDelayError(completables); + } + private static void setCallbacksFromConfig(LlmAgentConfig config, Builder builder) throws ConfigurationException { ConfigAgentUtils.resolveAndSetCallback( diff --git a/core/src/main/java/com/google/adk/plugins/Plugin.java b/core/src/main/java/com/google/adk/plugins/Plugin.java index 97a9038d4..c9cda5680 100644 --- a/core/src/main/java/com/google/adk/plugins/Plugin.java +++ b/core/src/main/java/com/google/adk/plugins/Plugin.java @@ -87,6 +87,16 @@ default Completable afterRunCallback(InvocationContext invocationContext) { return Completable.complete(); } + /** + * Method executed when the runner is closed. + * + *

    This method is used for cleanup tasks such as closing network connections or releasing + * resources. + */ + default Completable close() { + return Completable.complete(); + } + /** * Callback executed before an agent's primary logic is invoked. * diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index d7ce6b819..a63d9a402 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -125,6 +125,19 @@ public Completable afterRunCallback(InvocationContext invocationContext) { e))); } + @Override + public Completable close() { + return Flowable.fromIterable(plugins) + .concatMapCompletableDelayError( + plugin -> + plugin + .close() + .doOnError( + e -> + logger.error( + "[{}] Error during callback 'close'", plugin.getName(), e))); + } + public Maybe runOnEventCallback(InvocationContext invocationContext, Event event) { return onEventCallback(invocationContext, event); } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index e543f7d69..09b9752a4 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -324,6 +324,14 @@ public PluginManager pluginManager() { return this.pluginManager; } + /** Closes all plugins, code executors, and releases any resources. */ + public Completable close() { + List completables = new ArrayList<>(); + completables.add(agent.close()); + completables.add(this.pluginManager.close()); + return Completable.mergeDelayError(completables); + } + /** * Appends a new user message to the session history with optional state delta. * diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index d435e90c3..bf68f905c 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -28,8 +28,10 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -40,6 +42,20 @@ public final class BaseAgentTest { private static final String TEST_AGENT_NAME = "testAgent"; private static final String TEST_AGENT_DESCRIPTION = "A test agent"; + private static class ClosableTestAgent extends TestBaseAgent { + final AtomicBoolean closed = new AtomicBoolean(false); + + ClosableTestAgent(String name, String description, List subAgents) { + super(name, description, null, subAgents, null, null); + } + + @Override + public Completable close() { + closed.set(true); + return super.close(); + } + } + @Test public void constructor_setsNameAndDescription() { String name = "testName"; @@ -362,4 +378,39 @@ public void constructor_duplicateSubAgentNames_throwsIllegalArgumentException() new TestBaseAgent( "agent", "description", null, ImmutableList.of(subAgent1, subAgent2), null, null)); } + + @Test + public void close_noSubAgents_completesSuccessfully() { + ClosableTestAgent agent = new ClosableTestAgent("agent", "description", ImmutableList.of()); + agent.close().blockingAwait(); + assertThat(agent.closed.get()).isTrue(); + } + + @Test + public void close_oneLevelSubAgents_closesAllSubAgents() { + ClosableTestAgent subAgent1 = new ClosableTestAgent("sub1", "sub1", ImmutableList.of()); + ClosableTestAgent subAgent2 = new ClosableTestAgent("sub2", "sub2", ImmutableList.of()); + ClosableTestAgent agent = + new ClosableTestAgent("agent", "description", ImmutableList.of(subAgent1, subAgent2)); + + agent.close().blockingAwait(); + + assertThat(agent.closed.get()).isTrue(); + assertThat(subAgent1.closed.get()).isTrue(); + assertThat(subAgent2.closed.get()).isTrue(); + } + + @Test + public void close_twoLevelsSubAgents_closesAllSubAgents() { + ClosableTestAgent subSubAgent = new ClosableTestAgent("subSub", "subSub", ImmutableList.of()); + ClosableTestAgent subAgent = new ClosableTestAgent("sub", "sub", ImmutableList.of(subSubAgent)); + ClosableTestAgent agent = + new ClosableTestAgent("agent", "description", ImmutableList.of(subAgent)); + + agent.close().blockingAwait(); + + assertThat(agent.closed.get()).isTrue(); + assertThat(subAgent.closed.get()).isTrue(); + assertThat(subSubAgent.closed.get()).isTrue(); + } } diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index ae50b5b8e..760f67c7b 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -42,17 +42,20 @@ import com.google.adk.testing.TestLlm; import com.google.adk.testing.TestUtils.EchoTool; import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Part; import com.google.genai.types.Schema; +import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -61,6 +64,20 @@ @RunWith(JUnit4.class) public final class LlmAgentTest { + private static class ClosableToolset implements BaseToolset { + final AtomicBoolean closed = new AtomicBoolean(false); + + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.empty(); + } + + @Override + public void close() { + closed.set(true); + } + } + @Test public void testRun_withNoCallbacks() { Content modelContent = Content.fromParts(Part.fromText("Real LLM response")); @@ -441,4 +458,42 @@ public void run_sequentialAgents_shareTempStateViaSession() { assertThat(request2.getFirstSystemInstruction().get()) .contains("Instruction for Agent2 based on Agent1 output: value1"); } + + @Test + public void close_closesToolsets() throws Exception { + ClosableToolset toolset1 = new ClosableToolset(); + ClosableToolset toolset2 = new ClosableToolset(); + LlmAgent agent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .tools(toolset1, toolset2) + .build(); + agent.close().blockingAwait(); + assertThat(toolset1.closed.get()).isTrue(); + assertThat(toolset2.closed.get()).isTrue(); + } + + @Test + public void close_closesToolsetsOnException() throws Exception { + ClosableToolset toolset1 = + new ClosableToolset() { + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.empty(); + } + + @Override + public void close() { + super.close(); + throw new RuntimeException("toolset1 failed to close"); + } + }; + ClosableToolset toolset2 = new ClosableToolset(); + LlmAgent agent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .tools(toolset1, toolset2) + .build(); + agent.close().test().assertError(RuntimeException.class); + assertThat(toolset1.closed.get()).isTrue(); + assertThat(toolset2.closed.get()).isTrue(); + } } diff --git a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java index ae42bb27e..4ae856fc7 100644 --- a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java +++ b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java @@ -333,4 +333,31 @@ public void onToolErrorCallback_singlePlugin() { verify(plugin1).onToolErrorCallback(mockTool, toolArgs, mockToolContext, mockThrowable); } + + @Test + public void close_allComplete() { + when(plugin1.close()).thenReturn(Completable.complete()); + when(plugin2.close()).thenReturn(Completable.complete()); + pluginManager.registerPlugin(plugin1); + pluginManager.registerPlugin(plugin2); + + pluginManager.close().test().assertResult(); + + verify(plugin1).close(); + verify(plugin2).close(); + } + + @Test + public void close_plugin1Fails() { + RuntimeException testException = new RuntimeException("Test"); + when(plugin1.close()).thenReturn(Completable.error(testException)); + when(plugin2.close()).thenReturn(Completable.complete()); + pluginManager.registerPlugin(plugin1); + pluginManager.registerPlugin(plugin2); + + pluginManager.close().test().assertError(testException); + + verify(plugin1).close(); + verify(plugin2).close(); + } } diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index a01f4201e..86b0a81ec 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -1108,6 +1108,26 @@ public void runAsync_withToolConfirmation() { .inOrder(); } + @Test + public void close_closesPluginsAndCodeExecutors() { + BasePlugin plugin = mockPlugin("close_test_plugin"); + when(plugin.close()).thenReturn(Completable.complete()); + LlmAgent agentWithCodeExecutor = createTestAgentBuilder(testLlm).build(); + Runner runner = + Runner.builder() + .app( + App.builder() + .name("test") + .rootAgent(agentWithCodeExecutor) + .plugins(ImmutableList.of(plugin)) + .build()) + .build(); + + runner.close().blockingAwait(); + + verify(plugin).close(); + } + public static class Tools { private Tools() {} diff --git a/core/src/test/java/com/google/adk/testing/TestBaseAgent.java b/core/src/test/java/com/google/adk/testing/TestBaseAgent.java index e3e5a632c..001993a59 100644 --- a/core/src/test/java/com/google/adk/testing/TestBaseAgent.java +++ b/core/src/test/java/com/google/adk/testing/TestBaseAgent.java @@ -26,7 +26,7 @@ import java.util.function.Supplier; /** A test agent that returns events from a supplier. */ -public final class TestBaseAgent extends BaseAgent { +public class TestBaseAgent extends BaseAgent { private final Supplier> eventSupplier; private int invocationCount = 0; private InvocationContext lastInvocationContext; From be35b2277e8291336013623cb9f0c86f62ed1f43 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Tue, 10 Feb 2026 08:53:35 -0800 Subject: [PATCH 40/63] fix: javadocs in ResponseConverter PiperOrigin-RevId: 868170491 --- .../java/com/google/adk/a2a/converters/ResponseConverter.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java index 785ce6f38..61ab84c90 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java @@ -163,7 +163,7 @@ public static Message eventToMessage(Event event, String contextId) { * empty optional if the event should be ignored (e.g. if the event is not a final update for * TaskArtifactUpdateEvent or if the message is empty for TaskStatusUpdateEvent). * - * @throws an {@link IllegalArgumentException} if the event type is not supported. + * @throws IllegalArgumentException if the event type is not supported. */ public static Optional clientEventToEvent( ClientEvent event, InvocationContext invocationContext) { @@ -182,7 +182,7 @@ public static Optional clientEventToEvent( * the event is not a final update for TaskArtifactUpdateEvent or if the message is empty for * TaskStatusUpdateEvent. * - * @throws an {@link IllegalArgumentException} if the task update type is not supported. + * @throws IllegalArgumentException if the task update type is not supported. */ private static Optional handleTaskUpdate( TaskUpdateEvent event, InvocationContext context) { From 3338565cff976fdad1eda1fccafef58c9d4a51ba Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Feb 2026 10:52:52 -0800 Subject: [PATCH 41/63] feat: Token count estimation fallback for tail retention compaction PiperOrigin-RevId: 868224060 --- .../TailRetentionEventCompactor.java | 75 +++++++++++-------- .../TailRetentionEventCompactorTest.java | 38 +++++++++- 2 files changed, 79 insertions(+), 34 deletions(-) diff --git a/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java index b084de860..e193e7686 100644 --- a/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java +++ b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java @@ -73,37 +73,12 @@ public Completable compact(Session session, BaseSessionService sessionService) { logger.debug("Running tail retention event compaction for session {}", session.id()); return Maybe.just(session.events()) - .filter(this::shouldCompact) - .flatMap(events -> getCompactionEvents(events)) + .flatMap(this::getCompactionEvents) .flatMap(summarizer::summarizeEvents) .flatMapSingle(e -> sessionService.appendEvent(session, e)) .ignoreElement(); } - private boolean shouldCompact(List events) { - int count = getLatestPromptTokenCount(events).orElse(0); - - // TODO b/480013930 - Add a way to estimate the prompt token if the usage metadata is not - // available. - if (count <= tokenThreshold) { - logger.debug( - "Skipping compaction. Prompt token count {} is within threshold {}", - count, - tokenThreshold); - return false; - } - return true; - } - - private Optional getLatestPromptTokenCount(List events) { - return Lists.reverse(events).stream() - .map(Event::usageMetadata) - .flatMap(Optional::stream) - .map(GenerateContentResponseUsageMetadata::promptTokenCount) - .flatMap(Optional::stream) - .findFirst(); - } - /** * Identifies events to be compacted based on the tail retention strategy. * @@ -161,8 +136,19 @@ private Optional getLatestPromptTokenCount(List events) { * together. The new compaction event will cover the range from the start of the included * compaction event (C2, T=1) to the end of the new events (E4, T=4). * + * + * @param events The list of events to process. */ private Maybe> getCompactionEvents(List events) { + Optional count = getLatestPromptTokenCount(events); + if (count.isPresent() && count.get() <= tokenThreshold) { + logger.debug( + "Skipping compaction. Prompt token count {} is within threshold {}", + count.get(), + tokenThreshold); + return Maybe.empty(); + } + long compactionEndTimestamp = Long.MIN_VALUE; Event lastCompactionEvent = null; List eventsToSummarize = new ArrayList<>(); @@ -195,11 +181,6 @@ private Maybe> getCompactionEvents(List events) { } } - // If there are not enough events to summarize, we can return early. - if (eventsToSummarize.size() <= retentionSize) { - return Maybe.empty(); - } - // Add the last compaction event to the list of events to summarize. // This is to ensure that the last compaction event is included in the summary. if (lastCompactionEvent != null) { @@ -214,6 +195,22 @@ private Maybe> getCompactionEvents(List events) { Collections.reverse(eventsToSummarize); + if (count.isEmpty()) { + int estimatedCount = estimateTokenCount(eventsToSummarize); + if (estimatedCount <= tokenThreshold) { + logger.debug( + "Skipping compaction. Estimated prompt token count {} is within threshold {}", + estimatedCount, + tokenThreshold); + return Maybe.empty(); + } + } + + // If there are not enough events to summarize, we can return early. + if (eventsToSummarize.size() <= retentionSize) { + return Maybe.empty(); + } + // Apply retention: keep the most recent 'retentionSize' events out of the summary. // We do this by removing them from the list of events to be summarized. eventsToSummarize @@ -222,6 +219,22 @@ private Maybe> getCompactionEvents(List events) { return Maybe.just(eventsToSummarize); } + private int estimateTokenCount(List events) { + // A common rule of thumb is that one token roughly corresponds to 4 characters of text for + // common English text. + // See https://platform.openai.com/tokenizer + return events.stream().mapToInt(event -> event.stringifyContent().length()).sum() / 4; + } + + private Optional getLatestPromptTokenCount(List events) { + return Lists.reverse(events).stream() + .map(Event::usageMetadata) + .flatMap(Optional::stream) + .map(GenerateContentResponseUsageMetadata::promptTokenCount) + .flatMap(Optional::stream) + .findFirst(); + } + private static boolean isCompactEvent(Event event) { return event.actions() != null && event.actions().compaction().isPresent(); } diff --git a/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java index 3260fbe1e..7a4a3ddb9 100644 --- a/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java +++ b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java @@ -75,9 +75,13 @@ public void constructor_negativeRetentionSize_throwsException() { } @Test - // TODO: b/480013930 - Add a test case for estimating the prompt token if the usage metadata is - // not available. - public void compaction_skippedWhenTokenUsageMissing() { + public void compaction_skippedWhenEstimatedTokenUsageBelowThreshold() { + // Threshold is 100. + // Event1: "Event1" -> length 6. + // Retain1: "Retain1" -> length 7. + // Retain2: "Retain2" -> length 7. + // Total length = 20. Estimated tokens = 20 / 4 = 5. + // 5 <= 100 -> Skip. EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100); ImmutableList events = ImmutableList.of( @@ -92,6 +96,34 @@ public void compaction_skippedWhenTokenUsageMissing() { verify(mockSessionService, never()).appendEvent(any(), any()); } + @Test + public void compaction_happensWhenEstimatedTokenUsageAboveThreshold() { + // Threshold is 2. + // Event1: "Event1" -> length 6. + // Retain1: "Retain1" -> length 7. + // Retain2: "Retain2" -> length 7. + // Total eligible for estimation (including retained ones as per current logic): + // Logic: getCompactionEvents returns [Event1, Retain1, Retain2] for estimation. + // Total length = 20. Estimated tokens = 20 / 4 = 5. + // 5 > 2 -> Compact. + EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 2); + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + createEvent(3, "Retain2")); // No usage metadata + Session session = Session.builder("id").events(events).build(); + Event summaryEvent = createEvent(4, "Summary"); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(summaryEvent)); + when(mockSessionService.appendEvent(any(), any())).thenReturn(Single.just(summaryEvent)); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer).summarizeEvents(any()); + verify(mockSessionService).appendEvent(eq(session), eq(summaryEvent)); + } + @Test public void compaction_skippedWhenTokenUsageBelowThreshold() { // Threshold is 300, usage is 200. From ee459b3198d19972744514d1e74f076ee2bd32a7 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Feb 2026 11:15:07 -0800 Subject: [PATCH 42/63] feat: Add Compact processor to SingleFlow PiperOrigin-RevId: 868235383 --- .../google/adk/flows/llmflows/SingleFlow.java | 1 + .../adk/flows/llmflows/SingleFlowTest.java | 35 +++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 core/src/test/java/com/google/adk/flows/llmflows/SingleFlowTest.java diff --git a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java index cc2fb443a..de45ba702 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java @@ -30,6 +30,7 @@ public class SingleFlow extends BaseLlmFlow { new RequestConfirmationLlmRequestProcessor(), new Instructions(), new Identity(), + new Compaction(), new Contents(), new Examples(), CodeExecution.requestProcessor); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/SingleFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/SingleFlowTest.java new file mode 100644 index 000000000..ccb10a3a7 --- /dev/null +++ b/core/src/test/java/com/google/adk/flows/llmflows/SingleFlowTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.flows.llmflows; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SingleFlowTest { + + @Test + public void requestProcessors_containsCompaction() { + boolean hasCompaction = + SingleFlow.REQUEST_PROCESSORS.stream() + .anyMatch(processor -> processor instanceof Compaction); + assertThat(hasCompaction).isTrue(); + } +} From 8acb1eafb099723dfae065d8b9339bb5180aa26f Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Feb 2026 13:43:04 -0800 Subject: [PATCH 43/63] feat: Updating the Tracing implementation and updating BaseAgent.runLive This update should make the java tracing consistent with Python ADK: 1. **`BaseAgent.java`**: * The `runAsync` and `runLive` methods have been modified to create the `InvocationContext` before starting the tracing span. * The span name for agent invocations has been changed from `"agent_run []"` to `"invoke_agent "`. * A new `Tracing.traceAgentInvocation` method is now called to add initial tracing attributes for the agent invocation. * In `runLive`, the `runLiveImpl` execution is now wrapped with calls to `beforeAgentCallback` and `afterAgentCallback` to ensure proper tracing of these lifecycle events. 2. **`Tracing.java`**: * The OpenTelemetry tracer name has been updated from `"com.google.adk"` to `"gcp.vertex.agent"`. * A new `traceAgentInvocation` method has been added to set standard attributes for agent invocation spans, including `gen_ai.operation.name`, `gen_ai.agent.description`, `gen_ai.agent.name`, and `gen_ai.conversation.id`. * Attribute keys used in `traceToolCall`, `traceToolResponse`, `traceCallLlm`, and `traceSendData` have been updated to use the `"gcp.vertex.agent."` prefix instead of `"adk."` or `"com.google.adk"`. * The serialization of message content (like tool call arguments, tool responses, and data) is now guarded by the `CAPTURE_MESSAGE_CONTENT_IN_SPANS` flag. When disabled, empty JSON objects are recorded instead. * `traceToolResponse` now includes logic to extract and trace the `tool_call.id` and the tool response content from `FunctionResponse` objects. * `traceCallLlm` now captures additional LLM request and response details, such as `gen_ai.request.top_p`, `gen_ai.request.max_tokens`, `gen_ai.usage.input_tokens`, `gen_ai.usage.output_tokens`, and `gen_ai.response.finish_reasons`. PiperOrigin-RevId: 868299966 --- .../java/com/google/adk/agents/BaseAgent.java | 46 +- .../google/adk/flows/llmflows/Functions.java | 7 +- .../com/google/adk/telemetry/Tracing.java | 229 ++++++++-- .../com/google/adk/agents/BaseAgentTest.java | 214 ++++++++++ .../adk/telemetry/ContextPropagationTest.java | 395 +++++++++++++----- .../com/google/adk/testing/TestCallback.java | 7 + 6 files changed, 744 insertions(+), 154 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 0a7a09864..20d7dfa4f 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -304,15 +304,12 @@ public Flowable runAsync(InvocationContext parentContext) { Tracer tracer = Tracing.getTracer(); return Flowable.defer( () -> { + InvocationContext invocationContext = createInvocationContext(parentContext); Span span = - tracer - .spanBuilder("agent_run [" + name() + "]") - .setParent(Context.current()) - .startSpan(); + tracer.spanBuilder("invoke_agent " + name()).setParent(Context.current()).startSpan(); + Tracing.traceAgentInvocation(span, name(), description(), invocationContext); Context spanContext = Context.current().with(span); - InvocationContext invocationContext = createInvocationContext(parentContext); - return Tracing.traceFlowable( spanContext, span, @@ -443,16 +440,41 @@ public Flowable runLive(InvocationContext parentContext) { Tracer tracer = Tracing.getTracer(); return Flowable.defer( () -> { + InvocationContext invocationContext = createInvocationContext(parentContext); Span span = - tracer - .spanBuilder("agent_run [" + name() + "]") - .setParent(Context.current()) - .startSpan(); + tracer.spanBuilder("invoke_agent " + name()).setParent(Context.current()).startSpan(); + Tracing.traceAgentInvocation(span, name(), description(), invocationContext); Context spanContext = Context.current().with(span); - InvocationContext invocationContext = createInvocationContext(parentContext); + return Tracing.traceFlowable( + spanContext, + span, + () -> + callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), beforeAgentCallback), + invocationContext) + .flatMapPublisher( + beforeEventOpt -> { + if (invocationContext.endInvocation()) { + return Flowable.fromOptional(beforeEventOpt); + } + + Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); + Flowable mainEvents = + Flowable.defer(() -> runLiveImpl(invocationContext)); + Flowable afterEvents = + Flowable.defer( + () -> + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), + afterAgentCallback), + invocationContext) + .flatMapPublisher(Flowable::fromOptional)); - return Tracing.traceFlowable(spanContext, span, () -> runLiveImpl(invocationContext)); + return Flowable.concat(beforeEvents, mainEvents, afterEvents); + })); }); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index a6fb74d88..5a855445a 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -181,7 +181,7 @@ public static Maybe handleFunctionCalls( Span mergedSpan = tracer.spanBuilder("tool_response").setParent(Context.current()).startSpan(); try (Scope scope = mergedSpan.makeCurrent()) { - Tracing.traceToolResponse(invocationContext, mergedEvent.id(), mergedEvent); + Tracing.traceToolResponse(mergedEvent.id(), mergedEvent); } finally { mergedSpan.end(); } @@ -571,7 +571,8 @@ private static Maybe> callTool( .setParent(parentContext) .startSpan(); try (Scope scope = span.makeCurrent()) { - Tracing.traceToolCall(args); + Tracing.traceToolCall( + tool.name(), tool.description(), tool.getClass().getSimpleName(), args); return tool.runAsync(args, toolContext) .toMaybe() .doOnError(span::recordException) @@ -620,7 +621,7 @@ private static Event buildResponseEvent( .content(Content.builder().role("user").parts(partFunctionResponse).build()) .actions(toolContext.eventActions()) .build(); - Tracing.traceToolResponse(invocationContext, event.id(), event); + Tracing.traceToolResponse(event.id(), event); return event; } finally { span.end(); diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 23054b674..36c6e3e58 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -26,9 +26,12 @@ import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; @@ -37,7 +40,9 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,8 +57,59 @@ public class Tracing { private static final Logger log = LoggerFactory.getLogger(Tracing.class); + private static final AttributeKey> GEN_AI_RESPONSE_FINISH_REASONS = + AttributeKey.stringArrayKey("gen_ai.response.finish_reasons"); + + private static final AttributeKey GEN_AI_OPERATION_NAME = + AttributeKey.stringKey("gen_ai.operation.name"); + private static final AttributeKey GEN_AI_AGENT_DESCRIPTION = + AttributeKey.stringKey("gen_ai.agent.description"); + private static final AttributeKey GEN_AI_AGENT_NAME = + AttributeKey.stringKey("gen_ai.agent.name"); + private static final AttributeKey GEN_AI_CONVERSATION_ID = + AttributeKey.stringKey("gen_ai.conversation.id"); + private static final AttributeKey GEN_AI_SYSTEM = AttributeKey.stringKey("gen_ai.system"); + private static final AttributeKey GEN_AI_TOOL_CALL_ID = + AttributeKey.stringKey("gen_ai.tool_call.id"); + private static final AttributeKey GEN_AI_TOOL_DESCRIPTION = + AttributeKey.stringKey("gen_ai.tool.description"); + private static final AttributeKey GEN_AI_TOOL_NAME = + AttributeKey.stringKey("gen_ai.tool.name"); + private static final AttributeKey GEN_AI_TOOL_TYPE = + AttributeKey.stringKey("gen_ai.tool.type"); + private static final AttributeKey GEN_AI_REQUEST_MODEL = + AttributeKey.stringKey("gen_ai.request.model"); + private static final AttributeKey GEN_AI_REQUEST_TOP_P = + AttributeKey.doubleKey("gen_ai.request.top_p"); + private static final AttributeKey GEN_AI_REQUEST_MAX_TOKENS = + AttributeKey.longKey("gen_ai.request.max_tokens"); + private static final AttributeKey GEN_AI_USAGE_INPUT_TOKENS = + AttributeKey.longKey("gen_ai.usage.input_tokens"); + private static final AttributeKey GEN_AI_USAGE_OUTPUT_TOKENS = + AttributeKey.longKey("gen_ai.usage.output_tokens"); + + private static final AttributeKey ADK_TOOL_CALL_ARGS = + AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"); + private static final AttributeKey ADK_LLM_REQUEST = + AttributeKey.stringKey("gcp.vertex.agent.llm_request"); + private static final AttributeKey ADK_LLM_RESPONSE = + AttributeKey.stringKey("gcp.vertex.agent.llm_response"); + private static final AttributeKey ADK_INVOCATION_ID = + AttributeKey.stringKey("gcp.vertex.agent.invocation_id"); + private static final AttributeKey ADK_EVENT_ID = + AttributeKey.stringKey("gcp.vertex.agent.event_id"); + private static final AttributeKey ADK_TOOL_RESPONSE = + AttributeKey.stringKey("gcp.vertex.agent.tool_response"); + private static final AttributeKey ADK_SESSION_ID = + AttributeKey.stringKey("gcp.vertex.agent.session_id"); + private static final AttributeKey ADK_DATA = + AttributeKey.stringKey("gcp.vertex.agent.data"); + + private static final TypeReference> MAP_TYPE_REFERENCE = + new TypeReference>() {}; + @SuppressWarnings("NonFinalStaticField") - private static Tracer tracer = GlobalOpenTelemetry.getTracer("com.google.adk"); + private static Tracer tracer = GlobalOpenTelemetry.getTracer("gcp.vertex.agent"); private static final boolean CAPTURE_MESSAGE_CONTENT_IN_SPANS = Boolean.parseBoolean( @@ -66,52 +122,105 @@ public static void setTracerForTesting(Tracer tracer) { Tracing.tracer = tracer; } + /** + * Sets span attributes immediately available on agent invocation according to OTEL semconv + * version 1.37. + * + * @param span Span on which attributes are set. + * @param agentName Agent name from which attributes are gathered. + * @param agentDescription Agent description from which attributes are gathered. + * @param invocationContext InvocationContext from which attributes are gathered. + */ + public static void traceAgentInvocation( + Span span, String agentName, String agentDescription, InvocationContext invocationContext) { + span.setAttribute(GEN_AI_OPERATION_NAME, "invoke_agent"); + span.setAttribute(GEN_AI_AGENT_DESCRIPTION, agentDescription); + span.setAttribute(GEN_AI_AGENT_NAME, agentName); + if (invocationContext.session() != null && invocationContext.session().id() != null) { + span.setAttribute(GEN_AI_CONVERSATION_ID, invocationContext.session().id()); + } + } + /** * Traces tool call arguments. * * @param args The arguments to the tool call. */ - public static void traceToolCall(Map args) { + public static void traceToolCall( + String toolName, String toolDescription, String toolType, Map args) { Span span = Span.current(); if (span == null || !span.getSpanContext().isValid()) { log.trace("traceToolCall: No valid span in current context."); return; } - span.setAttribute("gen_ai.system", "com.google.adk"); - try { - span.setAttribute("adk.tool_call_args", JsonBaseModel.getMapper().writeValueAsString(args)); - } catch (JsonProcessingException e) { - log.warn("traceToolCall: Failed to serialize tool call args to JSON", e); + span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); + if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { + try { + span.setAttribute(ADK_TOOL_CALL_ARGS, JsonBaseModel.getMapper().writeValueAsString(args)); + } catch (JsonProcessingException e) { + log.warn("traceToolCall: Failed to serialize tool call args to JSON", e); + } + } else { + span.setAttribute(ADK_TOOL_CALL_ARGS, "{}"); } + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); } /** * Traces tool response event. * - * @param invocationContext The invocation context for the current agent run. * @param eventId The ID of the event. * @param functionResponseEvent The function response event. */ - public static void traceToolResponse( - InvocationContext invocationContext, String eventId, Event functionResponseEvent) { + public static void traceToolResponse(String eventId, Event functionResponseEvent) { Span span = Span.current(); if (span == null || !span.getSpanContext().isValid()) { log.trace("traceToolResponse: No valid span in current context."); return; } - span.setAttribute("gen_ai.system", "com.google.adk"); - span.setAttribute("adk.invocation_id", invocationContext.invocationId()); - span.setAttribute("adk.event_id", eventId); - span.setAttribute("adk.tool_response", functionResponseEvent.toJson()); + span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); + span.setAttribute(ADK_EVENT_ID, eventId); - // Setting empty llm request and response (as the AdkDevServer UI expects these) - span.setAttribute("adk.llm_request", "{}"); - span.setAttribute("adk.llm_response", "{}"); - if (invocationContext.session() != null && invocationContext.session().id() != null) { - span.setAttribute("adk.session_id", invocationContext.session().id()); + String toolCallId = ""; + Object toolResponse = ""; + + Optional optionalFunctionResponse = + functionResponseEvent.functionResponses().stream().findFirst(); + + if (optionalFunctionResponse.isPresent()) { + FunctionResponse functionResponse = optionalFunctionResponse.get(); + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + + if (!(toolResponse instanceof Map)) { + toolResponse = ImmutableMap.of("result", toolResponse); } + + if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { + try { + span.setAttribute( + ADK_TOOL_RESPONSE, JsonBaseModel.getMapper().writeValueAsString(toolResponse)); + } catch (JsonProcessingException e) { + log.warn("traceToolResponse: Failed to serialize tool response to JSON", e); + span.setAttribute(ADK_TOOL_RESPONSE, "{\"error\": \"serialization failed\"}"); + } + } else { + span.setAttribute(ADK_TOOL_RESPONSE, "{}"); + } + + // Setting empty llm request and response (as the AdkDevServer UI expects these) + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); } /** @@ -162,32 +271,62 @@ public static void traceCallLlm( return; } - span.setAttribute("gen_ai.system", "com.google.adk"); - llmRequest.model().ifPresent(modelName -> span.setAttribute("gen_ai.request.model", modelName)); - span.setAttribute("adk.invocation_id", invocationContext.invocationId()); - span.setAttribute("adk.event_id", eventId); + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); + span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); + span.setAttribute(ADK_EVENT_ID, eventId); if (invocationContext.session() != null && invocationContext.session().id() != null) { - span.setAttribute("adk.session_id", invocationContext.session().id()); + span.setAttribute(ADK_SESSION_ID, invocationContext.session().id()); } else { log.trace( "traceCallLlm: InvocationContext session or session ID is null, cannot set" - + " adk.session_id"); + + " gcp.vertex.agent.session_id"); } if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { try { span.setAttribute( - "adk.llm_request", + ADK_LLM_REQUEST, JsonBaseModel.getMapper().writeValueAsString(buildLlmRequestForTrace(llmRequest))); - span.setAttribute("adk.llm_response", llmResponse.toJson()); + span.setAttribute(ADK_LLM_RESPONSE, llmResponse.toJson()); } catch (JsonProcessingException e) { log.warn("traceCallLlm: Failed to serialize LlmRequest or LlmResponse to JSON", e); } } else { - span.setAttribute("adk.llm_request", "{}"); - span.setAttribute("adk.llm_response", "{}"); + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); } + llmRequest + .config() + .ifPresent( + config -> { + config + .topP() + .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() + .ifPresent( + maxTokens -> + span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() + .ifPresent( + tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + .ifPresent( + reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); } /** @@ -205,29 +344,27 @@ public static void traceSendData( return; } - span.setAttribute("adk.invocation_id", invocationContext.invocationId()); + span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); if (eventId != null && !eventId.isEmpty()) { - span.setAttribute("adk.event_id", eventId); + span.setAttribute(ADK_EVENT_ID, eventId); } if (invocationContext.session() != null && invocationContext.session().id() != null) { - span.setAttribute("adk.session_id", invocationContext.session().id()); + span.setAttribute(ADK_SESSION_ID, invocationContext.session().id()); } - - try { - List> dataList = new ArrayList<>(); - if (data != null) { - for (Content content : data) { - if (content != null) { - dataList.add( - JsonBaseModel.getMapper() - .convertValue(content, new TypeReference>() {})); - } - } + if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { + try { + ImmutableList> dataList = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(content -> content != null) + .map(content -> JsonBaseModel.getMapper().convertValue(content, MAP_TYPE_REFERENCE)) + .collect(toImmutableList()); + span.setAttribute(ADK_DATA, JsonBaseModel.toJsonString(dataList)); + } catch (IllegalStateException e) { + log.warn("traceSendData: Failed to serialize data to JSON", e); } - span.setAttribute("adk.data", JsonBaseModel.toJsonString(dataList)); - } catch (IllegalStateException e) { - log.warn("traceSendData: Failed to serialize data to JSON", e); + } else { + span.setAttribute(ADK_DATA, "{}"); } } diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index bf68f905c..5e2fa5792 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -94,6 +94,21 @@ public void rootAgent_returnsRootAgent() { assertThat(subSubAgent.rootAgent()).isEqualTo(agent); assertThat(subAgent.rootAgent()).isEqualTo(agent); assertThat(agent.rootAgent()).isEqualTo(agent); + assertThat(subSubAgent.parentAgent()).isEqualTo(subAgent); + assertThat(subAgent.parentAgent()).isEqualTo(agent); + assertThat(agent.parentAgent()).isNull(); + } + + @Test + public void subAgents_returnsSubAgents() { + TestBaseAgent subAgent1 = + new TestBaseAgent("subAgent1", "subAgent1", null, ImmutableList.of(), null, null); + TestBaseAgent subAgent2 = + new TestBaseAgent("subAgent2", "subAgent2", null, ImmutableList.of(), null, null); + TestBaseAgent agent = + new TestBaseAgent( + "agent", "description", null, ImmutableList.of(subAgent1, subAgent2), null, null); + assertThat(agent.subAgents()).containsExactly(subAgent1, subAgent2).inOrder(); } @Test @@ -354,6 +369,199 @@ public void runLive_invokesRunLiveImpl() { assertThat(runLiveCallback.wasCalled()).isTrue(); } + @Test + public void + runLive_beforeAgentCallbackReturnsContent_endsInvocationAndSkipsRunLiveImplAndAfterCallback() { + var runLiveImpl = TestCallback.returningEmpty(); + Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); + var beforeCallback = TestCallback.returning(callbackContent); + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier("main_output")); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(callbackContent); + assertThat(runLiveImpl.wasCalled()).isFalse(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isFalse(); + } + + @Test + public void runLive_firstBeforeCallbackReturnsContent_skipsSecondBeforeCallback() { + Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); + var beforeCallback1 = TestCallback.returning(callbackContent); + var beforeCallback2 = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of( + beforeCallback1.asBeforeAgentCallback(), beforeCallback2.asBeforeAgentCallback()), + ImmutableList.of(), + TestCallback.returningEmpty().asRunLiveImplSupplier("main_output")); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + var unused = agent.runLive(invocationContext).toList().blockingGet(); + assertThat(beforeCallback1.wasCalled()).isTrue(); + assertThat(beforeCallback2.wasCalled()).isFalse(); + } + + @Test + public void + runLive_beforeCallbackReturnsEmptyAndAfterCallbackReturnsEmpty_invokesRunLiveImplAndAfterCallbacks() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + var beforeCallback = TestCallback.returningEmpty(); + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runLive_afterCallbackReturnsContent_invokesRunLiveImplAndAfterCallbacksAndReturnsAllContent() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); + var beforeCallback = TestCallback.returningEmpty(); + var afterCallback = TestCallback.returning(afterCallbackContent); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + assertThat(results.get(1).content()).hasValue(afterCallbackContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runLive_beforeCallbackMutatesStateAndReturnsEmpty_invokesRunLiveImplAndReturnsStateEvent() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + BeforeAgentCallback beforeCallback = + new BeforeAgentCallback() { + @Override + public Maybe call(CallbackContext context) { + context.state().put("key", "value"); + return Maybe.empty(); + } + }; + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + // State event from before callback + assertThat(results.get(0).content()).isEmpty(); + assertThat(results.get(0).actions().stateDelta()).containsEntry("key", "value"); + // Content event from runLiveImpl + assertThat(results.get(1).content()).hasValue(runLiveImplContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runLive_afterCallbackMutatesStateAndReturnsEmpty_invokesRunLiveImplAndReturnsStateEvent() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + var beforeCallback = TestCallback.returningEmpty(); + AfterAgentCallback afterCallback = + new AfterAgentCallback() { + @Override + public Maybe call(CallbackContext context) { + context.state().put("key", "value"); + return Maybe.empty(); + } + }; + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + // Content event from runLiveImpl + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + // State event from after callback + assertThat(results.get(1).content()).isEmpty(); + assertThat(results.get(1).actions().stateDelta()).containsEntry("key", "value"); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + } + + @Test + public void runLive_firstAfterCallbackReturnsContent_skipsSecondAfterCallback() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); + var afterCallback1 = TestCallback.returning(afterCallbackContent); + var afterCallback2 = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(), + ImmutableList.of( + afterCallback1.asAfterAgentCallback(), afterCallback2.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + assertThat(results.get(1).content()).hasValue(afterCallbackContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(afterCallback1.wasCalled()).isTrue(); + assertThat(afterCallback2.wasCalled()).isFalse(); + } + @Test public void constructor_invalidName_throwsIllegalArgumentException() { assertThrows( @@ -379,6 +587,12 @@ public void constructor_duplicateSubAgentNames_throwsIllegalArgumentException() "agent", "description", null, ImmutableList.of(subAgent1, subAgent2), null, null)); } + @Test + @SuppressWarnings("DoNotCall") + public void fromConfig_throwsUnsupportedOperationException() { + assertThrows(UnsupportedOperationException.class, () -> BaseAgent.fromConfig(null, null)); + } + @Test public void close_noSubAgents_completesSuccessfully() { ClosableTestAgent agent = new ClosableTestAgent("agent", "description", ImmutableList.of()); diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 438522493..7d493c526 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -16,22 +16,39 @@ package com.google.adk.telemetry; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; - -import io.opentelemetry.api.GlobalOpenTelemetry; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FinishReason; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.Part; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; -import io.opentelemetry.sdk.OpenTelemetrySdk; -import io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter; -import io.opentelemetry.sdk.trace.SdkTracerProvider; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; -import io.opentelemetry.sdk.trace.export.SimpleSpanProcessor; +import io.reactivex.rxjava3.core.Flowable; import java.util.List; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; /** * Tests for OpenTelemetry context propagation in ADK. @@ -39,31 +56,28 @@ *

    Verifies that spans created by ADK properly link to parent contexts when available, enabling * proper distributed tracing across async boundaries. */ -class ContextPropagationTest { +@RunWith(JUnit4.class) +public class ContextPropagationTest { + @Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); - private InMemorySpanExporter spanExporter; private Tracer tracer; + private Tracer originalTracer; + + @Before + public void setup() { + this.originalTracer = Tracing.getTracer(); + Tracing.setTracerForTesting( + openTelemetryRule.getOpenTelemetry().getTracer("ContextPropagationTest")); + tracer = openTelemetryRule.getOpenTelemetry().getTracer("test"); + } - @BeforeEach - void setup() { - // Reset GlobalOpenTelemetry state - GlobalOpenTelemetry.resetForTest(); - - spanExporter = InMemorySpanExporter.create(); - - SdkTracerProvider tracerProvider = - SdkTracerProvider.builder() - .addSpanProcessor(SimpleSpanProcessor.create(spanExporter)) - .build(); - - OpenTelemetrySdk sdk = - OpenTelemetrySdk.builder().setTracerProvider(tracerProvider).buildAndRegisterGlobal(); - - tracer = sdk.getTracer("test"); + @After + public void tearDown() { + Tracing.setTracerForTesting(originalTracer); } @Test - void testToolCallSpanLinksToParent() { + public void testToolCallSpanLinksToParent() { // Given: Parent span is active Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -82,8 +96,8 @@ void testToolCallSpanLinksToParent() { } // Then: tool_call should be child of parent - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans: parent and tool_call"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 2 spans: parent and tool_call", 2, spans.size()); SpanData parentSpanData = spans.stream() @@ -99,41 +113,43 @@ void testToolCallSpanLinksToParent() { // Verify parent-child relationship assertEquals( + "Tool call should have same trace ID as parent", parentSpanData.getSpanContext().getTraceId(), - toolCallSpanData.getSpanContext().getTraceId(), - "Tool call should have same trace ID as parent"); + toolCallSpanData.getSpanContext().getTraceId()); assertEquals( + "Tool call's parent should be the parent span", parentSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId(), - "Tool call's parent should be the parent span"); + toolCallSpanData.getParentSpanContext().getSpanId()); } @Test - void testToolCallWithoutParentCreatesRootSpan() { + public void testToolCallWithoutParentCreatesRootSpan() { // Given: No parent span active // When: ADK creates tool_call span with setParent(Context.current()) - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); + try (Scope s = Context.root().makeCurrent()) { + Span toolCallSpan = + tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - try (Scope scope = toolCallSpan.makeCurrent()) { - // Work - } finally { - toolCallSpan.end(); + try (Scope scope = toolCallSpan.makeCurrent()) { + // Work + } finally { + toolCallSpan.end(); + } } // Then: Should create root span (backward compatible) - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(1, spans.size(), "Should have exactly 1 span"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have exactly 1 span", 1, spans.size()); SpanData toolCallSpanData = spans.get(0); assertFalse( - toolCallSpanData.getParentSpanContext().isValid(), - "Tool call should be root span when no parent exists"); + "Tool call should be root span when no parent exists", + toolCallSpanData.getParentSpanContext().isValid()); } @Test - void testNestedSpanHierarchy() { + public void testNestedSpanHierarchy() { // Test: parent → invocation → tool_call → tool_response hierarchy Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -168,8 +184,8 @@ void testNestedSpanHierarchy() { } // Verify complete hierarchy - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(4, spans.size(), "Should have 4 spans in the hierarchy"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 4 spans in the hierarchy", 4, spans.size()); String parentTraceId = spans.stream() @@ -182,9 +198,9 @@ void testNestedSpanHierarchy() { spans.forEach( span -> assertEquals( + "All spans should be in same trace", parentTraceId, - span.getSpanContext().getTraceId(), - "All spans should be in same trace")); + span.getSpanContext().getTraceId())); // Verify parent-child relationships SpanData parentSpanData = findSpanByName(spans, "parent"); @@ -194,25 +210,25 @@ void testNestedSpanHierarchy() { // invocation should be child of parent assertEquals( + "Invocation should be child of parent", parentSpanData.getSpanContext().getSpanId(), - invocationSpanData.getParentSpanContext().getSpanId(), - "Invocation should be child of parent"); + invocationSpanData.getParentSpanContext().getSpanId()); // tool_call should be child of invocation assertEquals( + "Tool call should be child of invocation", invocationSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId(), - "Tool call should be child of invocation"); + toolCallSpanData.getParentSpanContext().getSpanId()); // tool_response should be child of tool_call assertEquals( + "Tool response should be child of tool call", toolCallSpanData.getSpanContext().getSpanId(), - toolResponseSpanData.getParentSpanContext().getSpanId(), - "Tool response should be child of tool call"); + toolResponseSpanData.getParentSpanContext().getSpanId()); } @Test - void testMultipleSpansInParallel() { + public void testMultipleSpansInParallel() { // Test: Multiple tool calls in parallel should all link to same parent Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -234,8 +250,8 @@ void testMultipleSpansInParallel() { } // Verify all tool calls link to same parent - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(4, spans.size(), "Should have 4 spans: 1 parent + 3 tool calls"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 4 spans: 1 parent + 3 tool calls", 4, spans.size()); SpanData parentSpanData = findSpanByName(spans, "parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); @@ -245,59 +261,59 @@ void testMultipleSpansInParallel() { List toolCallSpans = spans.stream().filter(s -> s.getName().startsWith("tool_call")).toList(); - assertEquals(3, toolCallSpans.size(), "Should have 3 tool call spans"); + assertEquals("Should have 3 tool call spans", 3, toolCallSpans.size()); toolCallSpans.forEach( span -> { assertEquals( + "Tool call should have same trace ID as parent", parentTraceId, - span.getSpanContext().getTraceId(), - "Tool call should have same trace ID as parent"); + span.getSpanContext().getTraceId()); assertEquals( + "Tool call should have parent as parent span", parentSpanId, - span.getParentSpanContext().getSpanId(), - "Tool call should have parent as parent span"); + span.getParentSpanContext().getSpanId()); }); } @Test - void testAgentRunSpanLinksToInvocation() { - // Test: agent_run span should link to invocation span + public void testInvokeAgentSpanLinksToInvocation() { + // Test: invoke_agent span should link to invocation span Span invocationSpan = tracer.spanBuilder("invocation").startSpan(); try (Scope invocationScope = invocationSpan.makeCurrent()) { - Span agentRunSpan = - tracer.spanBuilder("agent_run [test-agent]").setParent(Context.current()).startSpan(); + Span invokeAgentSpan = + tracer.spanBuilder("invoke_agent test-agent").setParent(Context.current()).startSpan(); - try (Scope agentScope = agentRunSpan.makeCurrent()) { + try (Scope agentScope = invokeAgentSpan.makeCurrent()) { // Simulate agent work } finally { - agentRunSpan.end(); + invokeAgentSpan.end(); } } finally { invocationSpan.end(); } - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans: invocation and agent_run"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 2 spans: invocation and invoke_agent", 2, spans.size()); SpanData invocationSpanData = findSpanByName(spans, "invocation"); - SpanData agentRunSpanData = findSpanByName(spans, "agent_run [test-agent]"); + SpanData invokeAgentSpanData = findSpanByName(spans, "invoke_agent test-agent"); assertEquals( + "Agent run should be child of invocation", invocationSpanData.getSpanContext().getSpanId(), - agentRunSpanData.getParentSpanContext().getSpanId(), - "Agent run should be child of invocation"); + invokeAgentSpanData.getParentSpanContext().getSpanId()); } @Test - void testCallLlmSpanLinksToAgentRun() { + public void testCallLlmSpanLinksToAgentRun() { // Test: call_llm span should link to agent_run span - Span agentRunSpan = tracer.spanBuilder("agent_run [test-agent]").startSpan(); + Span invokeAgentSpan = tracer.spanBuilder("invoke_agent test-agent").startSpan(); - try (Scope agentScope = agentRunSpan.makeCurrent()) { + try (Scope agentScope = invokeAgentSpan.makeCurrent()) { Span callLlmSpan = tracer.spanBuilder("call_llm").setParent(Context.current()).startSpan(); try (Scope llmScope = callLlmSpan.makeCurrent()) { @@ -306,43 +322,76 @@ void testCallLlmSpanLinksToAgentRun() { callLlmSpan.end(); } } finally { - agentRunSpan.end(); + invokeAgentSpan.end(); } - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans: agent_run and call_llm"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 2 spans: invoke_agent and call_llm", 2, spans.size()); - SpanData agentRunSpanData = findSpanByName(spans, "agent_run [test-agent]"); + SpanData invokeAgentSpanData = findSpanByName(spans, "invoke_agent test-agent"); SpanData callLlmSpanData = findSpanByName(spans, "call_llm"); assertEquals( - agentRunSpanData.getSpanContext().getSpanId(), - callLlmSpanData.getParentSpanContext().getSpanId(), - "Call LLM should be child of agent run"); + "Call LLM should be child of agent run", + invokeAgentSpanData.getSpanContext().getSpanId(), + callLlmSpanData.getParentSpanContext().getSpanId()); } @Test - void testSpanCreatedWithinParentScopeIsCorrectlyParented() { + public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { // Test: Simulates creating a span within the scope of a parent Span parentSpan = tracer.spanBuilder("invocation").startSpan(); try (Scope scope = parentSpan.makeCurrent()) { - Span agentSpan = tracer.spanBuilder("agent_run").setParent(Context.current()).startSpan(); + Span agentSpan = tracer.spanBuilder("invoke_agent").setParent(Context.current()).startSpan(); agentSpan.end(); } finally { parentSpan.end(); } - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 2 spans", 2, spans.size()); SpanData parentSpanData = findSpanByName(spans, "invocation"); - SpanData agentSpanData = findSpanByName(spans, "agent_run"); + SpanData agentSpanData = findSpanByName(spans, "invoke_agent"); + + assertEquals( + "Agent span should be a child of the invocation span", + parentSpanData.getSpanContext().getSpanId(), + agentSpanData.getParentSpanContext().getSpanId()); + } + @Test + public void testTraceFlowable() throws InterruptedException { + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + try (Scope s = parentSpan.makeCurrent()) { + Span flowableSpan = tracer.spanBuilder("flowable").setParent(Context.current()).startSpan(); + Flowable flowable = + Tracing.traceFlowable( + Context.current().with(flowableSpan), + flowableSpan, + () -> + Flowable.just(1, 2, 3) + .map( + i -> { + assertEquals( + flowableSpan.getSpanContext().getSpanId(), + Span.current().getSpanContext().getSpanId()); + return i * 2; + })); + flowable.test().await().assertComplete(); + } finally { + parentSpan.end(); + } + + List spans = openTelemetryRule.getSpans(); + assertEquals(2, spans.size()); + SpanData parentSpanData = findSpanByName(spans, "parent"); + SpanData flowableSpanData = findSpanByName(spans, "flowable"); assertEquals( parentSpanData.getSpanContext().getSpanId(), - agentSpanData.getParentSpanContext().getSpanId(), - "Agent span should be a child of the invocation span"); + flowableSpanData.getParentSpanContext().getSpanId()); + assertTrue(flowableSpanData.hasEnded()); } private SpanData findSpanByName(List spans, String name) { @@ -351,4 +400,164 @@ private SpanData findSpanByName(List spans, String name) { .findFirst() .orElseThrow(() -> new AssertionError("Span not found: " + name)); } + + @Test + public void testTraceAgentInvocation() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Tracing.traceAgentInvocation( + span, + "test-agent", + "test-description", + InvocationContext.builder() + .invocationId("inv-1") + .session(Session.builder("session-1").build()) + .build()); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("invoke_agent", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals("test-agent", attrs.get(AttributeKey.stringKey("gen_ai.agent.name"))); + assertEquals("test-description", attrs.get(AttributeKey.stringKey("gen_ai.agent.description"))); + assertEquals("session-1", attrs.get(AttributeKey.stringKey("gen_ai.conversation.id"))); + } + + @Test + public void testTraceToolCall() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Tracing.traceToolCall( + "tool-name", "tool-description", "tool-type", ImmutableMap.of("arg1", "value1")); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals("tool-name", attrs.get(AttributeKey.stringKey("gen_ai.tool.name"))); + assertEquals("tool-description", attrs.get(AttributeKey.stringKey("gen_ai.tool.description"))); + assertEquals("tool-type", attrs.get(AttributeKey.stringKey("gen_ai.tool.type"))); + assertEquals( + "{\"arg1\":\"value1\"}", + attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"))); + assertEquals("{}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_request"))); + assertEquals("{}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_response"))); + } + + @Test + public void testTraceToolResponse() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Event functionResponseEvent = + Event.builder() + .id("event-1") + .content( + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .name("tool-name") + .id("tool-call-id") + .response(ImmutableMap.of("result", "tool-result")) + .build()) + .build())) + .build(); + Tracing.traceToolResponse("event-1", functionResponseEvent); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); + assertEquals("tool-call-id", attrs.get(AttributeKey.stringKey("gen_ai.tool_call.id"))); + assertEquals( + "{\"result\":\"tool-result\"}", + attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_response"))); + } + + @Test + public void testTraceCallLlm() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-pro") + .contents(ImmutableList.of(Content.fromParts(Part.fromText("hello")))) + .config(GenerateContentConfig.builder().topP(0.9f).maxOutputTokens(100).build()) + .build(); + LlmResponse llmResponse = + LlmResponse.builder() + .content(Content.builder().parts(Part.fromText("world")).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build()) + .build(); + Tracing.traceCallLlm( + InvocationContext.builder() + .invocationId("inv-1") + .session(Session.builder("session-1").build()) + .build(), + "event-1", + llmRequest, + llmResponse); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); + assertEquals("gemini-pro", attrs.get(AttributeKey.stringKey("gen_ai.request.model"))); + assertEquals("inv-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); + assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); + assertEquals("session-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.session_id"))); + assertEquals(0.9d, attrs.get(AttributeKey.doubleKey("gen_ai.request.top_p")), 0.01); + assertEquals(100L, (long) attrs.get(AttributeKey.longKey("gen_ai.request.max_tokens"))); + assertEquals(10L, (long) attrs.get(AttributeKey.longKey("gen_ai.usage.input_tokens"))); + assertEquals(20L, (long) attrs.get(AttributeKey.longKey("gen_ai.usage.output_tokens"))); + assertEquals( + ImmutableList.of("stop"), + attrs.get(AttributeKey.stringArrayKey("gen_ai.response.finish_reasons"))); + assertTrue( + attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_request")).contains("gemini-pro")); + assertTrue(attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_response")).contains("STOP")); + } + + @Test + public void testTraceSendData() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Tracing.traceSendData( + InvocationContext.builder() + .invocationId("inv-1") + .session(Session.builder("session-1").build()) + .build(), + "event-1", + ImmutableList.of(Content.builder().role("user").parts(Part.fromText("hello")).build())); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("inv-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); + assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); + assertEquals("session-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.session_id"))); + assertTrue(attrs.get(AttributeKey.stringKey("gcp.vertex.agent.data")).contains("hello")); + } } diff --git a/core/src/test/java/com/google/adk/testing/TestCallback.java b/core/src/test/java/com/google/adk/testing/TestCallback.java index 434d85e6f..6f35f5a3c 100644 --- a/core/src/test/java/com/google/adk/testing/TestCallback.java +++ b/core/src/test/java/com/google/adk/testing/TestCallback.java @@ -115,6 +115,13 @@ public Supplier> asRunLiveImplSupplier(Content content) { }); } + /** + * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} + */ + public Supplier> asRunLiveImplSupplier(String contentText) { + return asRunLiveImplSupplier(Content.fromParts(Part.fromText(contentText))); + } + @SuppressWarnings("unchecked") // This cast is safe if T is Content. public BeforeAgentCallback asBeforeAgentCallback() { return (unusedCtx) -> (Maybe) callMaybe(); From 968a9a8944bd7594efc51ed0b5201804133f350e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 10 Feb 2026 14:17:47 -0800 Subject: [PATCH 44/63] feat: Add ContextCacheConfig to InvocationContext This change introduces ContextCacheConfig to the InvocationContext and Runner classes, allowing context cache settings to be configured per invocation. PiperOrigin-RevId: 868315679 --- .../google/adk/agents/InvocationContext.java | 23 +++++++++++++++++++ .../java/com/google/adk/runner/Runner.java | 14 +++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 3b460b073..2a0cc90c9 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -56,6 +56,7 @@ public class InvocationContext { private final Map endOfAgents; private final ResumabilityConfig resumabilityConfig; @Nullable private final EventsCompactionConfig eventsCompactionConfig; + @Nullable private final ContextCacheConfig contextCacheConfig; private final InvocationCostManager invocationCostManager; private final Map callbackContextData; @@ -81,6 +82,7 @@ protected InvocationContext(Builder builder) { this.endOfAgents = builder.endOfAgents; this.resumabilityConfig = builder.resumabilityConfig; this.eventsCompactionConfig = builder.eventsCompactionConfig; + this.contextCacheConfig = builder.contextCacheConfig; this.invocationCostManager = builder.invocationCostManager; this.callbackContextData = builder.callbackContextData; } @@ -400,6 +402,11 @@ public Optional eventsCompactionConfig() { return Optional.ofNullable(eventsCompactionConfig); } + /** Returns the context cache configuration for the current agent run. */ + public Optional contextCacheConfig() { + return Optional.ofNullable(contextCacheConfig); + } + /** Returns whether to pause the invocation right after this [event]. */ public boolean shouldPauseInvocation(Event event) { if (!isResumable()) { @@ -472,6 +479,7 @@ private Builder(InvocationContext context) { this.endOfAgents = new ConcurrentHashMap<>(context.endOfAgents); this.resumabilityConfig = context.resumabilityConfig; this.eventsCompactionConfig = context.eventsCompactionConfig; + this.contextCacheConfig = context.contextCacheConfig; this.invocationCostManager = context.invocationCostManager; this.callbackContextData = context.callbackContextData; } @@ -493,6 +501,7 @@ private Builder(InvocationContext context) { private Map endOfAgents = new ConcurrentHashMap<>(); private ResumabilityConfig resumabilityConfig = new ResumabilityConfig(); @Nullable private EventsCompactionConfig eventsCompactionConfig; + @Nullable private ContextCacheConfig contextCacheConfig; private InvocationCostManager invocationCostManager = new InvocationCostManager(); private Map callbackContextData = new ConcurrentHashMap<>(); @@ -730,6 +739,18 @@ public Builder eventsCompactionConfig(@Nullable EventsCompactionConfig eventsCom return this; } + /** + * Sets the context cache configuration for the current agent run. + * + * @param contextCacheConfig the context cache configuration. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder contextCacheConfig(@Nullable ContextCacheConfig contextCacheConfig) { + this.contextCacheConfig = contextCacheConfig; + return this; + } + /** * Sets the callback context data for the invocation. * @@ -778,6 +799,7 @@ public boolean equals(Object o) { && Objects.equals(endOfAgents, that.endOfAgents) && Objects.equals(resumabilityConfig, that.resumabilityConfig) && Objects.equals(eventsCompactionConfig, that.eventsCompactionConfig) + && Objects.equals(contextCacheConfig, that.contextCacheConfig) && Objects.equals(invocationCostManager, that.invocationCostManager) && Objects.equals(callbackContextData, that.callbackContextData); } @@ -802,6 +824,7 @@ public int hashCode() { endOfAgents, resumabilityConfig, eventsCompactionConfig, + contextCacheConfig, invocationCostManager, callbackContextData); } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 09b9752a4..31026ee6e 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -18,6 +18,7 @@ import com.google.adk.agents.ActiveStreamingTool; import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.ContextCacheConfig; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; @@ -75,6 +76,7 @@ public class Runner { private final PluginManager pluginManager; private final ResumabilityConfig resumabilityConfig; @Nullable private final EventsCompactionConfig eventsCompactionConfig; + @Nullable private final ContextCacheConfig contextCacheConfig; /** Builder for {@link Runner}. */ public static class Builder { @@ -138,6 +140,7 @@ public Runner build() { List buildPlugins; ResumabilityConfig buildResumabilityConfig; EventsCompactionConfig buildEventsCompactionConfig; + ContextCacheConfig buildContextCacheConfig; if (this.app != null) { if (this.agent != null) { @@ -154,12 +157,14 @@ public Runner build() { ? this.app.resumabilityConfig() : new ResumabilityConfig(); buildEventsCompactionConfig = this.app.eventsCompactionConfig(); + buildContextCacheConfig = this.app.contextCacheConfig(); } else { buildAgent = this.agent; buildAppName = this.appName; buildPlugins = this.plugins; buildResumabilityConfig = new ResumabilityConfig(); buildEventsCompactionConfig = null; + buildContextCacheConfig = null; } if (buildAgent == null) { @@ -182,7 +187,8 @@ public Runner build() { memoryService, buildPlugins, buildResumabilityConfig, - buildEventsCompactionConfig); + buildEventsCompactionConfig, + buildContextCacheConfig); } } @@ -257,6 +263,7 @@ public Runner( memoryService, plugins, resumabilityConfig, + null, null); } @@ -274,7 +281,8 @@ protected Runner( @Nullable BaseMemoryService memoryService, List plugins, ResumabilityConfig resumabilityConfig, - @Nullable EventsCompactionConfig eventsCompactionConfig) { + @Nullable EventsCompactionConfig eventsCompactionConfig, + @Nullable ContextCacheConfig contextCacheConfig) { this.agent = agent; this.appName = appName; this.artifactService = artifactService; @@ -283,6 +291,7 @@ protected Runner( this.pluginManager = new PluginManager(plugins); this.resumabilityConfig = resumabilityConfig; this.eventsCompactionConfig = createEventsCompactionConfig(agent, eventsCompactionConfig); + this.contextCacheConfig = contextCacheConfig; } /** @@ -644,6 +653,7 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { .session(session) .resumabilityConfig(this.resumabilityConfig) .eventsCompactionConfig(this.eventsCompactionConfig) + .contextCacheConfig(this.contextCacheConfig) .agent(this.findAgentToRun(session, rootAgent)); } From 4b6ebded7539a2a667a1eee8860627bfa3a7180e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Feb 2026 03:19:08 -0800 Subject: [PATCH 45/63] refactor: Deprecating endOfInvocation in EventActions Python and Vertex Session Store moved to use endOfAgent, so we should use that as well. PiperOrigin-RevId: 868587270 --- .../com/google/adk/events/EventActions.java | 34 +++++++++++-------- .../adk/sessions/SessionJsonConverter.java | 3 +- .../google/adk/events/EventActionsTest.java | 4 +-- .../sessions/SessionJsonConverterTest.java | 2 +- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 6543ec823..157deda63 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -46,7 +46,6 @@ public class EventActions extends JsonBaseModel { private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent; private ConcurrentMap agentState; - private Optional endInvocation; private Optional compaction; private Optional rewindBeforeInvocationId; @@ -61,7 +60,6 @@ public EventActions() { this.requestedAuthConfigs = new ConcurrentHashMap<>(); this.requestedToolConfirmations = new ConcurrentHashMap<>(); this.endOfAgent = false; - this.endInvocation = Optional.empty(); this.compaction = Optional.empty(); this.agentState = new ConcurrentHashMap<>(); this.rewindBeforeInvocationId = Optional.empty(); @@ -77,7 +75,6 @@ private EventActions(Builder builder) { this.requestedAuthConfigs = builder.requestedAuthConfigs; this.requestedToolConfirmations = builder.requestedToolConfirmations; this.endOfAgent = builder.endOfAgent; - this.endInvocation = builder.endInvocation; this.compaction = builder.compaction; this.agentState = builder.agentState; this.rewindBeforeInvocationId = builder.rewindBeforeInvocationId; @@ -194,17 +191,28 @@ public void setEndOfAgent(boolean endOfAgent) { this.endOfAgent = endOfAgent; } - @JsonProperty("endInvocation") + /** + * @deprecated Use {@link #endOfAgent()} instead. + */ + @Deprecated public Optional endInvocation() { - return endInvocation; + return endOfAgent ? Optional.of(true) : Optional.empty(); } + /** + * @deprecated Use {@link #setEndOfAgent(boolean)} instead. + */ + @Deprecated public void setEndInvocation(Optional endInvocation) { - this.endInvocation = endInvocation; + this.endOfAgent = endInvocation.orElse(false); } + /** + * @deprecated Use {@link #setEndOfAgent(boolean)} instead. + */ + @Deprecated public void setEndInvocation(boolean endInvocation) { - this.endInvocation = Optional.of(endInvocation); + this.endOfAgent = endInvocation; } @JsonProperty("compaction") @@ -260,7 +268,6 @@ public boolean equals(Object o) { && Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs) && Objects.equals(requestedToolConfirmations, that.requestedToolConfirmations) && (endOfAgent == that.endOfAgent) - && Objects.equals(endInvocation, that.endInvocation) && Objects.equals(compaction, that.compaction) && Objects.equals(agentState, that.agentState) && Objects.equals(rewindBeforeInvocationId, that.rewindBeforeInvocationId); @@ -278,7 +285,6 @@ public int hashCode() { requestedAuthConfigs, requestedToolConfirmations, endOfAgent, - endInvocation, compaction, agentState, rewindBeforeInvocationId); @@ -295,7 +301,6 @@ public static class Builder { private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent = false; - private Optional endInvocation; private Optional compaction; private ConcurrentMap agentState; private Optional rewindBeforeInvocationId; @@ -309,7 +314,6 @@ public Builder() { this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); this.requestedToolConfirmations = new ConcurrentHashMap<>(); - this.endInvocation = Optional.empty(); this.compaction = Optional.empty(); this.agentState = new ConcurrentHashMap<>(); this.rewindBeforeInvocationId = Optional.empty(); @@ -326,7 +330,6 @@ private Builder(EventActions eventActions) { this.requestedToolConfirmations = new ConcurrentHashMap<>(eventActions.requestedToolConfirmations()); this.endOfAgent = eventActions.endOfAgent(); - this.endInvocation = eventActions.endInvocation(); this.compaction = eventActions.compaction(); this.agentState = new ConcurrentHashMap<>(eventActions.agentState()); this.rewindBeforeInvocationId = eventActions.rewindBeforeInvocationId(); @@ -396,10 +399,14 @@ public Builder endOfAgent(boolean endOfAgent) { return this; } + /** + * @deprecated Use {@link #endOfAgent(boolean)} instead. + */ @CanIgnoreReturnValue @JsonProperty("endInvocation") + @Deprecated public Builder endInvocation(boolean endInvocation) { - this.endInvocation = Optional.of(endInvocation); + this.endOfAgent = endInvocation; return this; } @@ -435,7 +442,6 @@ public Builder merge(EventActions other) { this.requestedAuthConfigs.putAll(other.requestedAuthConfigs()); this.requestedToolConfirmations.putAll(other.requestedToolConfirmations()); this.endOfAgent = other.endOfAgent(); - other.endInvocation().ifPresent(this::endInvocation); other.compaction().ifPresent(this::compaction); this.agentState.putAll(other.agentState()); other.rewindBeforeInvocationId().ifPresent(this::rewindBeforeInvocationId); diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index 3781ef537..f39605e3a 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -109,7 +109,7 @@ static String convertEventToJson(Event event, boolean useIsoString) { actionsJson.put("transferAgent", v); }); actions.escalate().ifPresent(v -> actionsJson.put("escalate", v)); - actions.endInvocation().ifPresent(v -> actionsJson.put("endOfAgent", v)); + actionsJson.put("endOfAgent", actions.endOfAgent()); putIfNotEmpty(actionsJson, "requestedAuthConfigs", actions.requestedAuthConfigs()); putIfNotEmpty( actionsJson, "requestedToolConfirmations", actions.requestedToolConfirmations()); @@ -182,7 +182,6 @@ static Event fromApiEvent(Map apiEvent) { Boolean endOfAgent = (Boolean) actionsMap.get("endOfAgent"); if (endOfAgent != null) { eventActionsBuilder.endOfAgent(endOfAgent); - eventActionsBuilder.endInvocation(endOfAgent); } eventActionsBuilder.requestedAuthConfigs( Optional.ofNullable(actionsMap.get("requestedAuthConfigs")) diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 9ea88b40a..7a58de575 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -84,7 +84,7 @@ public void merge_mergesAllFields() { ImmutableMap.of("config2", new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))))) .requestedToolConfirmations( new ConcurrentHashMap<>(ImmutableMap.of("tool2", TOOL_CONFIRMATION))) - .endInvocation(true) + .endOfAgent(true) .build(); EventActions merged = eventActions1.toBuilder().merge(eventActions2).build(); @@ -103,7 +103,7 @@ public void merge_mergesAllFields() { new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))); assertThat(merged.requestedToolConfirmations()) .containsExactly("tool1", TOOL_CONFIRMATION, "tool2", TOOL_CONFIRMATION); - assertThat(merged.endInvocation()).hasValue(true); + assertThat(merged.endOfAgent()).isTrue(); assertThat(merged.compaction()).hasValue(COMPACTION); } diff --git a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java index 3d1a845ed..f947a7140 100644 --- a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java +++ b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java @@ -189,7 +189,7 @@ public void convertEventToJson_complexActions_success() throws JsonProcessingExc EventActions.builder() .requestedAuthConfigs(authConfigs) .requestedToolConfirmations(toolConfirmations) - .endInvocation(true) + .endOfAgent(true) .build(); GenerateContentResponseUsageMetadata usageMetadata = From 4f2b5de5917272abadb134fbea412988cad70a3c Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Wed, 11 Feb 2026 05:00:29 -0800 Subject: [PATCH 46/63] chore: run validation workflow in release profile PiperOrigin-RevId: 868617839 --- .github/workflows/validation.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/validation.yml b/.github/workflows/validation.yml index eeb16e1ff..d9035a579 100644 --- a/.github/workflows/validation.yml +++ b/.github/workflows/validation.yml @@ -37,8 +37,8 @@ jobs: ${{ runner.os }}-maven-${{ matrix.java-version }}- ${{ runner.os }}-maven- - - name: Compile and test (all) modules with Java ${{ matrix.java-version }} - run: ./mvnw clean test + - name: Package and test (all) modules with Java ${{ matrix.java-version }} + run: ./mvnw -Prelease clean package - name: Detected wrongly formatted files run: git status && git diff --exit-code From a364e302006f2c93b0b1feecc5adf38e3a55cd7b Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Feb 2026 05:14:19 -0800 Subject: [PATCH 47/63] test: Adding tests to make sure that tracing works across threads This testing found an improvement in how span propogation works in Function Calling. I would like to do a significant refactoring to cleanup all of the sprinkled tracing code. This step is necessary to confirm proper behavior before the refactor. PiperOrigin-RevId: 868623517 --- .../google/adk/flows/llmflows/Functions.java | 122 +++++----- .../com/google/adk/agents/LlmAgentTest.java | 142 ++++++++++++ .../adk/telemetry/ContextPropagationTest.java | 214 ++++++++++++------ 3 files changed, 360 insertions(+), 118 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 5a855445a..75289bd38 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -246,34 +246,41 @@ private static Function> getFunctionCallMapper( Map toolConfirmations, boolean isLive) { Context parentContext = Context.current(); - return functionCall -> { - BaseTool tool = tools.get(functionCall.name().get()); - ToolContext toolContext = - ToolContext.builder(invocationContext) - .functionCallId(functionCall.id().orElse("")) - .toolConfirmation(functionCall.id().map(toolConfirmations::get).orElse(null)) - .build(); - - Map functionArgs = functionCall.args().orElse(new HashMap<>()); - - Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) - .switchIfEmpty( - Maybe.defer( - () -> { - try (Scope scope = parentContext.makeCurrent()) { - return isLive - ? processFunctionLive( - invocationContext, tool, toolContext, functionCall, functionArgs) - : callTool(tool, functionArgs, toolContext); - } - })); - - try (Scope scope = parentContext.makeCurrent()) { - return postProcessFunctionResult( - maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive); - } - }; + return functionCall -> + Maybe.defer( + () -> { + try (Scope scope = parentContext.makeCurrent()) { + BaseTool tool = tools.get(functionCall.name().get()); + ToolContext toolContext = + ToolContext.builder(invocationContext) + .functionCallId(functionCall.id().orElse("")) + .toolConfirmation( + functionCall.id().map(toolConfirmations::get).orElse(null)) + .build(); + + Map functionArgs = functionCall.args().orElse(new HashMap<>()); + + Maybe> maybeFunctionResult = + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + .switchIfEmpty( + Maybe.defer( + () -> { + try (Scope innerScope = parentContext.makeCurrent()) { + return isLive + ? processFunctionLive( + invocationContext, + tool, + toolContext, + functionCall, + functionArgs) + : callTool(tool, functionArgs, toolContext); + } + })); + + return postProcessFunctionResult( + maybeFunctionResult, invocationContext, tool, functionArgs, toolContext); + } + }); } /** @@ -376,42 +383,49 @@ private static Maybe postProcessFunctionResult( InvocationContext invocationContext, BaseTool tool, Map functionArgs, - ToolContext toolContext, - boolean isLive) { + ToolContext toolContext) { Context parentContext = Context.current(); return maybeFunctionResult .map(Optional::of) .defaultIfEmpty(Optional.empty()) .onErrorResumeNext( t -> - handleOnToolErrorCallback(invocationContext, tool, functionArgs, toolContext, t) - .map(isLive ? Optional::ofNullable : Optional::of) + Maybe.defer( + () -> { + try (Scope scope = parentContext.makeCurrent()) { + return handleOnToolErrorCallback( + invocationContext, tool, functionArgs, toolContext, t); + } + }) + .map(Optional::ofNullable) .switchIfEmpty(Single.error(t))) .flatMapMaybe( optionalInitialResult -> { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - Maybe> afterToolResultMaybe = - maybeInvokeAfterToolCall( - invocationContext, tool, functionArgs, toolContext, initialFunctionResult); - - return afterToolResultMaybe - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - try (Scope scope = parentContext.makeCurrent()) { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); + try (Scope scope = parentContext.makeCurrent()) { + Map initialFunctionResult = optionalInitialResult.orElse(null); + + Maybe> afterToolResultMaybe = + maybeInvokeAfterToolCall( + invocationContext, tool, functionArgs, toolContext, initialFunctionResult); + + return afterToolResultMaybe + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + try (Scope innerScope = parentContext.makeCurrent()) { + Map finalFunctionResult = + finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + Event functionResponseEvent = + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext); + return Maybe.just(functionResponseEvent); } - Event functionResponseEvent = - buildResponseEvent( - tool, finalFunctionResult, toolContext, invocationContext); - return Maybe.just(functionResponseEvent); - } - }); + }); + } }); } diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 760f67c7b..ce8be8dfb 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -22,8 +22,10 @@ import static com.google.adk.testing.TestUtils.createTestAgent; import static com.google.adk.testing.TestUtils.createTestAgentBuilder; import static com.google.adk.testing.TestUtils.createTestLlm; +import static com.google.adk.testing.TestUtils.createTextLlmResponse; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import com.google.adk.agents.Callbacks.AfterModelCallback; @@ -39,16 +41,22 @@ import com.google.adk.models.Model; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; +import com.google.adk.telemetry.Tracing; import com.google.adk.testing.TestLlm; import com.google.adk.testing.TestUtils.EchoTool; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Part; import com.google.genai.types.Schema; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import io.opentelemetry.sdk.trace.data.SpanData; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; @@ -56,6 +64,9 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -63,6 +74,20 @@ /** Unit tests for {@link LlmAgent}. */ @RunWith(JUnit4.class) public final class LlmAgentTest { + @Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); + + private Tracer originalTracer; + + @Before + public void setup() { + this.originalTracer = Tracing.getTracer(); + Tracing.setTracerForTesting(openTelemetryRule.getOpenTelemetry().getTracer("gcp.vertex.agent")); + } + + @After + public void tearDown() { + Tracing.setTracerForTesting(originalTracer); + } private static class ClosableToolset implements BaseToolset { final AtomicBoolean closed = new AtomicBoolean(false); @@ -496,4 +521,121 @@ public void close() { assertThat(toolset1.closed.get()).isTrue(); assertThat(toolset2.closed.get()).isTrue(); } + + @Test + public void runAsync_createsInvokeAgentSpan() throws InterruptedException { + Content modelContent = Content.fromParts(Part.fromText("response")); + TestLlm testLlm = createTestLlm(createLlmResponse(modelContent)); + LlmAgent agent = createTestAgent(testLlm); + InvocationContext invocationContext = createInvocationContext(agent); + + agent.runAsync(invocationContext).test().await().assertComplete(); + + List spans = openTelemetryRule.getSpans(); + assertThat(spans.stream().anyMatch(s -> s.getName().equals("invoke_agent test agent"))) + .isTrue(); + } + + @Test + public void runAsync_withTools_createsToolSpans() throws InterruptedException { + ImmutableMap echoArgs = ImmutableMap.of("arg", "value"); + Content contentWithFunctionCall = + Content.fromParts(Part.fromText("text"), Part.fromFunctionCall("echo_tool", echoArgs)); + Content finalResponse = Content.fromParts(Part.fromText("finished")); + TestLlm testLlm = + createTestLlm(createLlmResponse(contentWithFunctionCall), createLlmResponse(finalResponse)); + LlmAgent agent = createTestAgentBuilder(testLlm).tools(new EchoTool()).build(); + InvocationContext invocationContext = createInvocationContext(agent); + + agent.runAsync(invocationContext).test().await().assertComplete(); + + List spans = openTelemetryRule.getSpans(); + SpanData agentSpan = findSpanByName(spans, "invoke_agent test agent"); + List llmSpans = findSpansByName(spans, "call_llm"); + List toolCallSpans = findSpansByName(spans, "tool_call [echo_tool]"); + List toolResponseSpans = findSpansByName(spans, "tool_response [echo_tool]"); + + assertThat(llmSpans).hasSize(2); + assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + String agentSpanId = agentSpan.getSpanContext().getSpanId(); + llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); + toolCallSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); + toolResponseSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); + } + + @Test + public void runAsync_afterToolCallback_propagatesContext() throws InterruptedException { + ImmutableMap echoArgs = ImmutableMap.of("arg", "value"); + Content contentWithFunctionCall = + Content.fromParts(Part.fromText("text"), Part.fromFunctionCall("echo_tool", echoArgs)); + Content finalResponse = Content.fromParts(Part.fromText("finished")); + TestLlm testLlm = + createTestLlm(createLlmResponse(contentWithFunctionCall), createLlmResponse(finalResponse)); + + AfterToolCallback afterToolCallback = + (invCtx, tool, input, toolCtx, response) -> { + // Verify that the OpenTelemetry context is correctly propagated to the callback. + assertThat(Span.current().getSpanContext().isValid()).isTrue(); + return Maybe.empty(); + }; + + LlmAgent agent = + createTestAgentBuilder(testLlm) + .tools(new EchoTool()) + .afterToolCallback(ImmutableList.of(afterToolCallback)) + .build(); + InvocationContext invocationContext = createInvocationContext(agent); + + agent.runAsync(invocationContext).test().await().assertComplete(); + + List spans = openTelemetryRule.getSpans(); + findSpanByName(spans, "invoke_agent test agent"); + } + + @Test + public void runAsync_withSubAgents_createsSpans() throws InterruptedException { + LlmAgent subAgent = + createTestAgentBuilder(createTestLlm(createTextLlmResponse("sub response"))) + .name("sub-agent") + .build(); + + // Force a transfer to sub-agent using a callback + AfterModelCallback transferCallback = + (ctx, response) -> { + ctx.eventActions().setTransferToAgent(subAgent.name()); + return Maybe.empty(); + }; + + TestLlm testLlm = createTestLlm(createTextLlmResponse("initial")); + LlmAgent agent = + createTestAgentBuilder(testLlm) + .subAgents(subAgent) + .afterModelCallback(ImmutableList.of(transferCallback)) + .build(); + InvocationContext invocationContext = createInvocationContext(agent); + + agent.runAsync(invocationContext).test().await().assertComplete(); + + List spans = openTelemetryRule.getSpans(); + assertThat(spans.stream().anyMatch(s -> s.getName().equals("invoke_agent test agent"))) + .isTrue(); + assertThat(spans.stream().anyMatch(s -> s.getName().equals("invoke_agent sub-agent"))).isTrue(); + + List llmSpans = findSpansByName(spans, "call_llm"); + assertThat(llmSpans).hasSize(2); // One for main agent, one for sub agent + } + + private List findSpansByName(List spans, String name) { + return spans.stream().filter(s -> s.getName().equals(name)).toList(); + } + + @CanIgnoreReturnValue + private SpanData findSpanByName(List spans, String name) { + return spans.stream() + .filter(s -> s.getName().equals(name)) + .findFirst() + .orElseThrow(() -> new AssertionError("Span not found: " + name)); + } } diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 7d493c526..3588a5b53 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -20,10 +20,14 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LiveRequestQueue; +import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.adk.runner.Runner; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -42,7 +46,9 @@ import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.List; +import java.util.Optional; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -96,20 +102,8 @@ public void testToolCallSpanLinksToParent() { } // Then: tool_call should be child of parent - List spans = openTelemetryRule.getSpans(); - assertEquals("Should have 2 spans: parent and tool_call", 2, spans.size()); - - SpanData parentSpanData = - spans.stream() - .filter(s -> s.getName().equals("parent")) - .findFirst() - .orElseThrow(() -> new AssertionError("Parent span not found")); - - SpanData toolCallSpanData = - spans.stream() - .filter(s -> s.getName().equals("tool_call [testTool]")) - .findFirst() - .orElseThrow(() -> new AssertionError("Tool call span not found")); + SpanData parentSpanData = findSpanByName("parent"); + SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); // Verify parent-child relationship assertEquals( @@ -184,29 +178,19 @@ public void testNestedSpanHierarchy() { } // Verify complete hierarchy - List spans = openTelemetryRule.getSpans(); - assertEquals("Should have 4 spans in the hierarchy", 4, spans.size()); - - String parentTraceId = - spans.stream() - .filter(s -> s.getName().equals("parent")) - .findFirst() - .map(s -> s.getSpanContext().getTraceId()) - .orElseThrow(() -> new AssertionError("Parent span not found")); + SpanData parentSpanData = findSpanByName("parent"); + String parentTraceId = parentSpanData.getSpanContext().getTraceId(); // All spans should have same trace ID - spans.forEach( - span -> - assertEquals( - "All spans should be in same trace", - parentTraceId, - span.getSpanContext().getTraceId())); + for (SpanData span : openTelemetryRule.getSpans()) { + assertEquals( + "All spans should be in same trace", parentTraceId, span.getSpanContext().getTraceId()); + } // Verify parent-child relationships - SpanData parentSpanData = findSpanByName(spans, "parent"); - SpanData invocationSpanData = findSpanByName(spans, "invocation"); - SpanData toolCallSpanData = findSpanByName(spans, "tool_call [testTool]"); - SpanData toolResponseSpanData = findSpanByName(spans, "tool_response [testTool]"); + SpanData invocationSpanData = findSpanByName("invocation"); + SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); + SpanData toolResponseSpanData = findSpanByName("tool_response [testTool]"); // invocation should be child of parent assertEquals( @@ -250,16 +234,15 @@ public void testMultipleSpansInParallel() { } // Verify all tool calls link to same parent - List spans = openTelemetryRule.getSpans(); - assertEquals("Should have 4 spans: 1 parent + 3 tool calls", 4, spans.size()); - - SpanData parentSpanData = findSpanByName(spans, "parent"); + SpanData parentSpanData = findSpanByName("parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); String parentSpanId = parentSpanData.getSpanContext().getSpanId(); // All tool calls should have same trace ID and parent span ID List toolCallSpans = - spans.stream().filter(s -> s.getName().startsWith("tool_call")).toList(); + openTelemetryRule.getSpans().stream() + .filter(s -> s.getName().startsWith("tool_call")) + .toList(); assertEquals("Should have 3 tool call spans", 3, toolCallSpans.size()); @@ -295,11 +278,8 @@ public void testInvokeAgentSpanLinksToInvocation() { invocationSpan.end(); } - List spans = openTelemetryRule.getSpans(); - assertEquals("Should have 2 spans: invocation and invoke_agent", 2, spans.size()); - - SpanData invocationSpanData = findSpanByName(spans, "invocation"); - SpanData invokeAgentSpanData = findSpanByName(spans, "invoke_agent test-agent"); + SpanData invocationSpanData = findSpanByName("invocation"); + SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); assertEquals( "Agent run should be child of invocation", @@ -325,11 +305,8 @@ public void testCallLlmSpanLinksToAgentRun() { invokeAgentSpan.end(); } - List spans = openTelemetryRule.getSpans(); - assertEquals("Should have 2 spans: invoke_agent and call_llm", 2, spans.size()); - - SpanData invokeAgentSpanData = findSpanByName(spans, "invoke_agent test-agent"); - SpanData callLlmSpanData = findSpanByName(spans, "call_llm"); + SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); + SpanData callLlmSpanData = findSpanByName("call_llm"); assertEquals( "Call LLM should be child of agent run", @@ -349,11 +326,8 @@ public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { parentSpan.end(); } - List spans = openTelemetryRule.getSpans(); - assertEquals("Should have 2 spans", 2, spans.size()); - - SpanData parentSpanData = findSpanByName(spans, "invocation"); - SpanData agentSpanData = findSpanByName(spans, "invoke_agent"); + SpanData parentSpanData = findSpanByName("invocation"); + SpanData agentSpanData = findSpanByName("invoke_agent"); assertEquals( "Agent span should be a child of the invocation span", @@ -384,23 +358,14 @@ public void testTraceFlowable() throws InterruptedException { parentSpan.end(); } - List spans = openTelemetryRule.getSpans(); - assertEquals(2, spans.size()); - SpanData parentSpanData = findSpanByName(spans, "parent"); - SpanData flowableSpanData = findSpanByName(spans, "flowable"); + SpanData parentSpanData = findSpanByName("parent"); + SpanData flowableSpanData = findSpanByName("flowable"); assertEquals( parentSpanData.getSpanContext().getSpanId(), flowableSpanData.getParentSpanContext().getSpanId()); assertTrue(flowableSpanData.hasEnded()); } - private SpanData findSpanByName(List spans, String name) { - return spans.stream() - .filter(s -> s.getName().equals(name)) - .findFirst() - .orElseThrow(() -> new AssertionError("Span not found: " + name)); - } - @Test public void testTraceAgentInvocation() { Span span = tracer.spanBuilder("test").startSpan(); @@ -560,4 +525,125 @@ public void testTraceSendData() { assertEquals("session-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.session_id"))); assertTrue(attrs.get(AttributeKey.stringKey("gcp.vertex.agent.data")).contains("hello")); } + + // Agent that emits one event on a computation thread. + private static class TestAgent extends BaseAgent { + TestAgent() { + super("test-agent", "test-description", null, null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext context) { + return Flowable.just( + Event.builder().content(Content.fromParts(Part.fromText("test"))).build()) + .subscribeOn(Schedulers.computation()); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.just( + Event.builder().content(Content.fromParts(Part.fromText("test"))).build()) + .subscribeOn(Schedulers.computation()); + } + } + + @Test + public void baseAgentRunAsync_propagatesContext() throws InterruptedException { + BaseAgent agent = new TestAgent(); + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + try (Scope s = parentSpan.makeCurrent()) { + agent + .runAsync( + InvocationContext.builder() + .invocationId("inv-1") + .session(Session.builder("session-1").build()) + .build()) + .test() + .await() + .assertComplete(); + } finally { + parentSpan.end(); + } + SpanData parent = findSpanByName("parent"); + SpanData agentSpan = findSpanByName("invoke_agent test-agent"); + assertEquals(parent.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + } + + @Test + public void runnerRunAsync_propagatesContext() throws InterruptedException { + BaseAgent agent = new TestAgent(); + Runner runner = Runner.builder().agent(agent).appName("test-app").build(); + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + try (Scope s = parentSpan.makeCurrent()) { + Session session = + runner + .sessionService() + .createSession("test-app", "user-1", null, "session-1") + .blockingGet(); + Content newMessage = Content.fromParts(Part.fromText("hi")); + RunConfig runConfig = RunConfig.builder().build(); + runner + .runAsync(session.userId(), session.id(), newMessage, runConfig, null) + .test() + .await() + .assertComplete(); + } finally { + parentSpan.end(); + } + SpanData parent = findSpanByName("parent"); + SpanData invocation = findSpanByName("invocation"); + SpanData agentSpan = findSpanByName("invoke_agent test-agent"); + assertEquals( + parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); + assertEquals( + invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + } + + @Test + public void runnerRunLive_propagatesContext() throws InterruptedException { + BaseAgent agent = new TestAgent(); + Runner runner = Runner.builder().agent(agent).appName("test-app").build(); + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + try (Scope s = parentSpan.makeCurrent()) { + Session session = Session.builder("session-1").userId("user-1").appName("test-app").build(); + Content newMessage = Content.fromParts(Part.fromText("hi")); + RunConfig runConfig = RunConfig.builder().build(); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + liveRequestQueue.content(newMessage); + liveRequestQueue.close(); + runner.runLive(session, liveRequestQueue, runConfig).test().await().assertComplete(); + } finally { + parentSpan.end(); + } + SpanData parent = findSpanByName("parent"); + SpanData invocation = findSpanByName("invocation"); + SpanData agentSpan = findSpanByName("invoke_agent test-agent"); + assertEquals( + parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); + assertEquals( + invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + } + + /** + * Finds a span by name, polling multiple times. + * + *

    This is necessary because spans might be created in separate threads, and we cannot always + * rely on `.await()` to ensure all spans are available immediately. + */ + private SpanData findSpanByName(String name) { + for (int i = 0; i < 15; i++) { + Optional span = + openTelemetryRule.getSpans().stream().filter(s -> s.getName().equals(name)).findFirst(); + if (span.isPresent()) { + return span.get(); + } + try { + Thread.sleep(10 * i); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + throw new AssertionError("Span not found after polling: " + name); + } } From 8a1fffa52b1c34ad1b9465a7575c14609619ef64 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Feb 2026 06:27:41 -0800 Subject: [PATCH 48/63] refactor: Code clean up in BaseAgent 1. **Extracting Common Run Logic:** A new private method `run` is introduced to encapsulate the common setup and execution flow shared by both `runAsync` and `runLive`. This includes creating the `InvocationContext`, setting up tracing, and handling before/after callbacks. 2. **Reducing Duplication:** Both `runAsync` and `runLive` now call this new `run` method, passing their specific implementation (`runAsyncImpl` or `runLiveImpl`) as a function. This eliminates a significant amount of duplicated code. 3. **Minor Cleanups:** * The `createInvocationContext` method now uses `parentContext.branch().filter(...).ifPresent(...)` for a more concise optional handling. * An unnecessary `agentCallbacks == null` check is removed from `callCallback`, as the list should not be null. PiperOrigin-RevId: 868646905 --- .../java/com/google/adk/agents/BaseAgent.java | 64 ++++++------------- 1 file changed, 20 insertions(+), 44 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 20d7dfa4f..72fc5883a 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -288,9 +288,10 @@ private InvocationContext createInvocationContext(InvocationContext parentContex InvocationContext.Builder builder = parentContext.toBuilder(); builder.agent(this); // Check for branch to be truthy (not None, not empty string), - if (parentContext.branch().filter(s -> !s.isEmpty()).isPresent()) { - builder.branch(parentContext.branch().get() + "." + name()); - } + parentContext + .branch() + .filter(s -> !s.isEmpty()) + .ifPresent(branch -> builder.branch(branch + "." + name())); return builder.build(); } @@ -301,6 +302,19 @@ private InvocationContext createInvocationContext(InvocationContext parentContex * @return stream of agent-generated events. */ public Flowable runAsync(InvocationContext parentContext) { + return run(parentContext, this::runAsyncImpl); + } + + /** + * Runs the agent with the given implementation. + * + * @param parentContext Parent context to inherit. + * @param runImplementation The agent-specific logic to run. + * @return stream of agent-generated events. + */ + private Flowable run( + InvocationContext parentContext, + Function> runImplementation) { Tracer tracer = Tracing.getTracer(); return Flowable.defer( () -> { @@ -326,7 +340,7 @@ public Flowable runAsync(InvocationContext parentContext) { Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); Flowable mainEvents = - Flowable.defer(() -> runAsyncImpl(invocationContext)); + Flowable.defer(() -> runImplementation.apply(invocationContext)); Flowable afterEvents = Flowable.defer( () -> @@ -382,7 +396,7 @@ private ImmutableList>> afterCallbacksT private Single> callCallback( List>> agentCallbacks, InvocationContext invocationContext) { - if (agentCallbacks == null || agentCallbacks.isEmpty()) { + if (agentCallbacks.isEmpty()) { return Single.just(Optional.empty()); } @@ -437,45 +451,7 @@ private Single> callCallback( * @return stream of agent-generated events. */ public Flowable runLive(InvocationContext parentContext) { - Tracer tracer = Tracing.getTracer(); - return Flowable.defer( - () -> { - InvocationContext invocationContext = createInvocationContext(parentContext); - Span span = - tracer.spanBuilder("invoke_agent " + name()).setParent(Context.current()).startSpan(); - Tracing.traceAgentInvocation(span, name(), description(), invocationContext); - Context spanContext = Context.current().with(span); - - return Tracing.traceFlowable( - spanContext, - span, - () -> - callCallback( - beforeCallbacksToFunctions( - invocationContext.pluginManager(), beforeAgentCallback), - invocationContext) - .flatMapPublisher( - beforeEventOpt -> { - if (invocationContext.endInvocation()) { - return Flowable.fromOptional(beforeEventOpt); - } - - Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); - Flowable mainEvents = - Flowable.defer(() -> runLiveImpl(invocationContext)); - Flowable afterEvents = - Flowable.defer( - () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), - afterAgentCallback), - invocationContext) - .flatMapPublisher(Flowable::fromOptional)); - - return Flowable.concat(beforeEvents, mainEvents, afterEvents); - })); - }); + return run(parentContext, this::runLiveImpl); } /** From 0502c2141724a238bbf5f7a72e1951cbb401a3e8 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Feb 2026 08:22:43 -0800 Subject: [PATCH 49/63] feat: Adding validation to InvocationContext 'session_service', 'invocation_id', 'agent', and 'session' fields PiperOrigin-RevId: 868688936 --- .../google/adk/agents/InvocationContext.java | 25 +++++- .../adk/agents/InvocationContextTest.java | 80 +++++++++++++++++-- .../adk/flows/llmflows/ContentsTest.java | 6 +- .../adk/flows/llmflows/ExamplesTest.java | 8 ++ ...stConfirmationLlmRequestProcessorTest.java | 13 +-- .../adk/telemetry/ContextPropagationTest.java | 66 +++++++-------- .../com/google/adk/tools/AgentToolTest.java | 4 + .../com/google/adk/tools/ExampleToolTest.java | 13 +-- .../google/adk/tools/FunctionToolTest.java | 47 +++-------- .../retrieval/VertexAiRagRetrievalTest.java | 31 ++++++- 10 files changed, 201 insertions(+), 92 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 2a0cc90c9..afee5065a 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -16,6 +16,8 @@ package com.google.adk.agents; +import static com.google.common.base.Strings.isNullOrEmpty; + import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; @@ -768,12 +770,33 @@ public Builder callbackContextData(Map callbackContextData) { * * @throws IllegalStateException if any required parameters are missing. */ - // TODO: b/462183912 - Add validation for required parameters. public InvocationContext build() { + validate(this); return new InvocationContext(this); } } + /** + * Validates the required parameters fields: invocationId, agent, session, and sessionService. + * + * @param builder the builder to validate. + * @throws IllegalStateException if any required parameters are missing. + */ + private static void validate(Builder builder) { + if (isNullOrEmpty(builder.invocationId)) { + throw new IllegalStateException("Invocation ID must be non-empty."); + } + if (builder.agent == null) { + throw new IllegalStateException("Agent must be set."); + } + if (builder.session == null) { + throw new IllegalStateException("Session must be set."); + } + if (builder.sessionService == null) { + throw new IllegalStateException("Session service must be set."); + } + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index c1cb30180..bbfbb74bd 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -17,6 +17,7 @@ package com.google.adk.agents; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; import com.google.adk.apps.ResumabilityConfig; @@ -764,7 +765,7 @@ public void testSetEndInvocation() { } @Test - @SuppressWarnings("deprecation") // Testing deprecated methods. + // Testing deprecated methods. public void testBranch() { InvocationContext context = InvocationContext.builder() @@ -785,7 +786,7 @@ public void testBranch() { } @Test - @SuppressWarnings("deprecation") // Testing deprecated methods. + // Testing deprecated methods. public void testDeprecatedCreateMethods() { InvocationContext context1 = InvocationContext.builder() @@ -855,7 +856,7 @@ public void testEventsCompactionConfig() { } @Test - @SuppressWarnings("deprecation") // Testing deprecated methods. + // Testing deprecated methods. public void testBuilderOptionalParameters() { InvocationContext context = InvocationContext.builder() @@ -874,7 +875,7 @@ public void testBuilderOptionalParameters() { } @Test - @SuppressWarnings("deprecation") // Testing deprecated methods. + // Testing deprecated methods. public void testDeprecatedConstructor() { InvocationContext context = new InvocationContext( @@ -906,7 +907,7 @@ public void testDeprecatedConstructor() { } @Test - @SuppressWarnings("deprecation") // Testing deprecated methods. + // Testing deprecated methods. public void testDeprecatedConstructor_11params() { InvocationContext context = new InvocationContext( @@ -986,4 +987,73 @@ public void populateAgentStates_populatesAgentStatesAndEndOfAgents() { assertThat(context.endOfAgents()).hasSize(1); assertThat(context.endOfAgents()).containsEntry("agent1", true); } + + @Test + public void build_missingInvocationId_null_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .invocationId(null) + .session(session); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Invocation ID must be non-empty."); + } + + @Test + public void build_missingInvocationId_empty_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .invocationId("") + .session(session); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Invocation ID must be non-empty."); + } + + @Test + public void build_missingAgent_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .session(session); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Agent must be set."); + } + + @Test + public void build_missingSession_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Session must be set."); + } + + @Test + public void build_missingSessionService_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Session service must be set."); + } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index a8a862b51..ea8571a0b 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -28,6 +28,7 @@ import com.google.adk.events.EventCompaction; import com.google.adk.models.LlmRequest; import com.google.adk.models.Model; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -53,7 +54,8 @@ public final class ContentsTest { private static final String AGENT = "agent"; private static final String OTHER_AGENT = "other_agent"; - private final Contents contentsProcessor = new Contents(); + private static final Contents contentsProcessor = new Contents(); + private static final InMemorySessionService sessionService = new InMemorySessionService(); @Test public void rearrangeLatest_emptyList_returnsEmptyList() { @@ -900,6 +902,7 @@ private List runContentsProcessorWithIncludeContents( .invocationId("test-invocation") .agent(agent) .session(session) + .sessionService(sessionService) .build(); LlmRequest initialRequest = LlmRequest.builder().build(); @@ -929,6 +932,7 @@ private List runContentsProcessorWithModelName(List events, Stri .invocationId("test-invocation") .agent(agent) .session(session) + .sessionService(sessionService) .build(); LlmRequest initialRequest = LlmRequest.builder().build(); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java index 4a6216029..7d1615dc2 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java @@ -24,6 +24,8 @@ import com.google.adk.examples.BaseExampleProvider; import com.google.adk.examples.Example; import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; @@ -35,6 +37,8 @@ @RunWith(JUnit4.class) public final class ExamplesTest { + private static final InMemorySessionService sessionService = new InMemorySessionService(); + private static class TestExampleProvider implements BaseExampleProvider { @Override public List getExamples(String query) { @@ -55,6 +59,8 @@ public void processRequest_withExampleProvider_addsExamplesToInstructions() { InvocationContext context = InvocationContext.builder() .invocationId("invocation1") + .session(Session.builder("session1").build()) + .sessionService(sessionService) .agent(agent) .userContent(Content.fromParts(Part.fromText("what is up?"))) .runConfig(RunConfig.builder().build()) @@ -76,6 +82,8 @@ public void processRequest_withoutExampleProvider_doesNotAddExamplesToInstructio InvocationContext context = InvocationContext.builder() .invocationId("invocation1") + .session(Session.builder("session1").build()) + .sessionService(sessionService) .agent(agent) .userContent(Content.fromParts(Part.fromText("what is up?"))) .runConfig(RunConfig.builder().build()) diff --git a/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java index 26d456b70..55adeb39e 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java @@ -27,6 +27,7 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.plugins.PluginManager; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; import com.google.adk.testing.TestUtils.EchoTool; @@ -60,6 +61,7 @@ public class RequestConfirmationLlmRequestProcessorTest { Optional.of(ORIGINAL_FUNCTION_CALL_ARGS))); private static final FunctionCall FUNCTION_CALL = FunctionCall.builder().id(FUNCTION_CALL_ID).name(ECHO_TOOL_NAME).args(ARGS).build(); + private static final InMemorySessionService sessionService = new InMemorySessionService(); private static final Event REQUEST_CONFIRMATION_EVENT = Event.builder() @@ -93,7 +95,7 @@ public void runAsync_withConfirmation_callsOriginalFunction() { .events(ImmutableList.of(REQUEST_CONFIRMATION_EVENT, USER_CONFIRMATION_EVENT)) .build(); - InvocationContext context = createInvocationContext(agent, session); + InvocationContext context = buildInvocationContext(agent, session); RequestProcessor.RequestProcessingResult result = processor.processRequest(context, LlmRequest.builder().build()).blockingGet(); @@ -132,7 +134,7 @@ public void runAsync_withConfirmationAndToolAlreadyCalled_doesNotCallOriginalFun REQUEST_CONFIRMATION_EVENT, USER_CONFIRMATION_EVENT, toolResponseEvent)) .build(); - InvocationContext context = createInvocationContext(agent, session); + InvocationContext context = buildInvocationContext(agent, session); RequestProcessor.RequestProcessingResult result = processor.processRequest(context, LlmRequest.builder().build()).blockingGet(); @@ -149,7 +151,7 @@ public void runAsync_noEvents_empty() { assertThat( processor .processRequest( - createInvocationContext(agent, session), LlmRequest.builder().build()) + buildInvocationContext(agent, session), LlmRequest.builder().build()) .blockingGet() .events()) .isEmpty(); @@ -164,18 +166,19 @@ public void runAsync_noUserConfirmationEvent_empty() { assertThat( processor .processRequest( - createInvocationContext(agent, session), LlmRequest.builder().build()) + buildInvocationContext(agent, session), LlmRequest.builder().build()) .blockingGet() .events()) .isEmpty(); } - private static InvocationContext createInvocationContext(LlmAgent agent, Session session) { + private static InvocationContext buildInvocationContext(LlmAgent agent, Session session) { return InvocationContext.builder() .pluginManager(new PluginManager()) .invocationId(InvocationContext.newInvocationContextId()) .agent(agent) .session(session) + .sessionService(sessionService) .build(); } diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 3588a5b53..ece1bdad1 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -23,11 +23,13 @@ import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequestQueue; +import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.runner.Runner; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -68,6 +70,8 @@ public class ContextPropagationTest { private Tracer tracer; private Tracer originalTracer; + private LlmAgent agent; + private InMemorySessionService sessionService; @Before public void setup() { @@ -75,6 +79,8 @@ public void setup() { Tracing.setTracerForTesting( openTelemetryRule.getOpenTelemetry().getTracer("ContextPropagationTest")); tracer = openTelemetryRule.getOpenTelemetry().getTracer("test"); + agent = LlmAgent.builder().name("test_agent").description("test-description").build(); + sessionService = new InMemorySessionService(); } @After @@ -371,13 +377,7 @@ public void testTraceAgentInvocation() { Span span = tracer.spanBuilder("test").startSpan(); try (Scope scope = span.makeCurrent()) { Tracing.traceAgentInvocation( - span, - "test-agent", - "test-description", - InvocationContext.builder() - .invocationId("inv-1") - .session(Session.builder("session-1").build()) - .build()); + span, "test-agent", "test-description", buildInvocationContext()); } finally { span.end(); } @@ -388,7 +388,7 @@ public void testTraceAgentInvocation() { assertEquals("invoke_agent", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals("test-agent", attrs.get(AttributeKey.stringKey("gen_ai.agent.name"))); assertEquals("test-description", attrs.get(AttributeKey.stringKey("gen_ai.agent.description"))); - assertEquals("session-1", attrs.get(AttributeKey.stringKey("gen_ai.conversation.id"))); + assertEquals("test-session", attrs.get(AttributeKey.stringKey("gen_ai.conversation.id"))); } @Test @@ -470,14 +470,7 @@ public void testTraceCallLlm() { .totalTokenCount(30) .build()) .build(); - Tracing.traceCallLlm( - InvocationContext.builder() - .invocationId("inv-1") - .session(Session.builder("session-1").build()) - .build(), - "event-1", - llmRequest, - llmResponse); + Tracing.traceCallLlm(buildInvocationContext(), "event-1", llmRequest, llmResponse); } finally { span.end(); } @@ -487,9 +480,10 @@ public void testTraceCallLlm() { Attributes attrs = spanData.getAttributes(); assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); assertEquals("gemini-pro", attrs.get(AttributeKey.stringKey("gen_ai.request.model"))); - assertEquals("inv-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); + assertEquals( + "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); - assertEquals("session-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.session_id"))); + assertEquals("test-session", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.session_id"))); assertEquals(0.9d, attrs.get(AttributeKey.doubleKey("gen_ai.request.top_p")), 0.01); assertEquals(100L, (long) attrs.get(AttributeKey.longKey("gen_ai.request.max_tokens"))); assertEquals(10L, (long) attrs.get(AttributeKey.longKey("gen_ai.usage.input_tokens"))); @@ -507,10 +501,7 @@ public void testTraceSendData() { Span span = tracer.spanBuilder("test").startSpan(); try (Scope scope = span.makeCurrent()) { Tracing.traceSendData( - InvocationContext.builder() - .invocationId("inv-1") - .session(Session.builder("session-1").build()) - .build(), + buildInvocationContext(), "event-1", ImmutableList.of(Content.builder().role("user").parts(Part.fromText("hello")).build())); } finally { @@ -520,9 +511,10 @@ public void testTraceSendData() { assertEquals(1, spans.size()); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); - assertEquals("inv-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); + assertEquals( + "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); - assertEquals("session-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.session_id"))); + assertEquals("test-session", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.session_id"))); assertTrue(attrs.get(AttributeKey.stringKey("gcp.vertex.agent.data")).contains("hello")); } @@ -552,15 +544,7 @@ public void baseAgentRunAsync_propagatesContext() throws InterruptedException { BaseAgent agent = new TestAgent(); Span parentSpan = tracer.spanBuilder("parent").startSpan(); try (Scope s = parentSpan.makeCurrent()) { - agent - .runAsync( - InvocationContext.builder() - .invocationId("inv-1") - .session(Session.builder("session-1").build()) - .build()) - .test() - .await() - .assertComplete(); + agent.runAsync(buildInvocationContext()).test().await().assertComplete(); } finally { parentSpan.end(); } @@ -578,7 +562,7 @@ public void runnerRunAsync_propagatesContext() throws InterruptedException { Session session = runner .sessionService() - .createSession("test-app", "user-1", null, "session-1") + .createSession("test-app", "test-user", null, "test-session") .blockingGet(); Content newMessage = Content.fromParts(Part.fromText("hi")); RunConfig runConfig = RunConfig.builder().build(); @@ -605,7 +589,8 @@ public void runnerRunLive_propagatesContext() throws InterruptedException { Runner runner = Runner.builder().agent(agent).appName("test-app").build(); Span parentSpan = tracer.spanBuilder("parent").startSpan(); try (Scope s = parentSpan.makeCurrent()) { - Session session = Session.builder("session-1").userId("user-1").appName("test-app").build(); + Session session = + Session.builder("test-session").userId("test-user").appName("test-app").build(); Content newMessage = Content.fromParts(Part.fromText("hi")); RunConfig runConfig = RunConfig.builder().build(); LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); @@ -646,4 +631,15 @@ private SpanData findSpanByName(String name) { } throw new AssertionError("Span not found after polling: " + name); } + + private InvocationContext buildInvocationContext() { + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + return InvocationContext.builder() + .sessionService(sessionService) + .session(session) + .agent(agent) + .invocationId("test-invocation-id") + .build(); + } } diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index 2fc17b94d..94d36f5ec 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -28,6 +28,7 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.SequentialAgent; import com.google.adk.models.LlmResponse; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; import com.google.adk.utils.ComponentRegistry; @@ -48,6 +49,8 @@ @RunWith(JUnit4.class) public final class AgentToolTest { + private static final InMemorySessionService sessionService = new InMemorySessionService(); + @Test public void fromConfig_withRegisteredAgent_returnsAgentTool() throws Exception { LlmAgent testAgent = @@ -661,6 +664,7 @@ private static ToolContext createToolContext(BaseAgent agent) { .invocationId(InvocationContext.newInvocationContextId()) .agent(agent) .session(Session.builder("123").build()) + .sessionService(sessionService) .build()) .build(); } diff --git a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java index e4d92f197..4e80ed0ff 100644 --- a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java +++ b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java @@ -41,7 +41,7 @@ public final class ExampleToolTest { /** Helper to create a minimal agent & context for testing. */ - private InvocationContext newContext() { + private InvocationContext buildInvocationContext() { TestLlm testLlm = new TestLlm(() -> Flowable.just(LlmResponse.builder().build())); LlmAgent agent = TestUtils.createTestAgent(testLlm); return TestUtils.createInvocationContext(agent); @@ -58,7 +58,7 @@ private static Example makeExample(String in, String out) { public void processLlmRequest_withInlineExamples_appendsFewShot() { ExampleTool tool = ExampleTool.builder().addExample(makeExample("qin", "qout")).build(); - InvocationContext ctx = newContext(); + InvocationContext ctx = buildInvocationContext(); LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait(); @@ -75,7 +75,7 @@ public void processLlmRequest_withInlineExamples_appendsFewShot() { public void processLlmRequest_withProvider_appendsFewShot() { ExampleTool tool = ExampleTool.builder().setExampleProvider(ProviderHolder.EXAMPLES).build(); - InvocationContext ctx = newContext(); + InvocationContext ctx = buildInvocationContext(); LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait(); @@ -91,12 +91,13 @@ public void processLlmRequest_withProvider_appendsFewShot() { @Test public void processLlmRequest_withEmptyUserContent_doesNotAppendFewShot() { ExampleTool tool = ExampleTool.builder().addExample(makeExample("qin", "qout")).build(); - InvocationContext ctxWithContent = newContext(); + InvocationContext ctxWithContent = buildInvocationContext(); InvocationContext ctx = InvocationContext.builder() .invocationId(ctxWithContent.invocationId()) .agent(ctxWithContent.agent()) .session(ctxWithContent.session()) + .sessionService(ctxWithContent.sessionService()) .userContent(Content.fromParts(Part.fromText(""))) .runConfig(ctxWithContent.runConfig()) .build(); @@ -120,7 +121,7 @@ public void fromConfig_withInlineExamples_buildsTool() throws Exception { "output", ImmutableList.of(Content.fromParts(Part.fromText("a")))))); ExampleTool tool = ExampleTool.fromConfig(args); - InvocationContext ctx = newContext(); + InvocationContext ctx = buildInvocationContext(); LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait(); @@ -144,7 +145,7 @@ public void fromConfig_withProviderReference_buildsTool() throws Exception { "examples", ExampleToolTest.ProviderHolder.class.getName() + ".EXAMPLES"); ExampleTool tool = ExampleTool.fromConfig(args); - InvocationContext ctx = newContext(); + InvocationContext ctx = buildInvocationContext(); LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait(); diff --git a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java index 5816c427a..ec2ca1f76 100644 --- a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java +++ b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java @@ -20,7 +20,9 @@ import static org.junit.Assert.assertThrows; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; import com.google.adk.events.ToolConfirmation; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -41,6 +43,16 @@ /** Unit tests for {@link FunctionTool}. */ @RunWith(JUnit4.class) public final class FunctionToolTest { + private static final ToolContext toolContext = + ToolContext.builder( + InvocationContext.builder() + .agent(LlmAgent.builder().name("test-agent").build()) + .session(Session.builder("123").build()) + .sessionService(new InMemorySessionService()) + .invocationId("invocation-id") + .build()) + .functionCallId("functionCallId") + .build(); @Test public void create_withNonSerializableParameter_raisesIllegalArgumentException() { @@ -233,11 +245,6 @@ public void create_withParameterizedList() { @Test public void call_withAllSupportedParameterTypes() throws Exception { FunctionTool tool = FunctionTool.create(Functions.class, "returnAllSupportedParametersAsMap"); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync( @@ -579,11 +586,6 @@ public void call_nonStaticWithAllSupportedParameterTypes() throws Exception { Functions functions = new Functions(); FunctionTool tool = FunctionTool.create(functions, "nonStaticReturnAllSupportedParametersAsMap"); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync( @@ -630,11 +632,6 @@ public void runAsync_withRequireConfirmation() throws Exception { Method method = Functions.class.getMethod("returnsMap"); FunctionTool tool = new FunctionTool(null, method, /* isLongRunning= */ false, /* requireConfirmation= */ true); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); // First call, should request confirmation Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); @@ -663,11 +660,6 @@ public void create_instanceMethodWithConfirmation_requestsConfirmation() throws Functions functions = new Functions(); Method method = Functions.class.getMethod("nonStaticVoidReturnWithoutSchema"); FunctionTool tool = FunctionTool.create(functions, method, /* requireConfirmation= */ true); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); assertThat(result) @@ -680,11 +672,6 @@ public void create_instanceMethodWithConfirmation_requestsConfirmation() throws public void create_staticMethodWithConfirmation_requestsConfirmation() throws Exception { Method method = Functions.class.getMethod("voidReturnWithoutSchema"); FunctionTool tool = FunctionTool.create(method, /* requireConfirmation= */ true); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); assertThat(result) @@ -698,11 +685,6 @@ public void create_classMethodNameWithConfirmation_requestsConfirmation() throws FunctionTool tool = FunctionTool.create( Functions.class, "voidReturnWithoutSchema", /* requireConfirmation= */ true); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); assertThat(result) @@ -717,11 +699,6 @@ public void create_instanceMethodNameWithConfirmation_requestsConfirmation() thr FunctionTool tool = FunctionTool.create( functions, "nonStaticVoidReturnWithoutSchema", /* requireConfirmation= */ true); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); assertThat(result) diff --git a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java index c501d6a43..77cae3b40 100644 --- a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java +++ b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java @@ -6,7 +6,9 @@ import static org.mockito.Mockito.when; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.adk.tools.ToolContext; import com.google.cloud.aiplatform.v1.RagContexts; @@ -38,9 +40,11 @@ public final class VertexAiRagRetrievalTest { @Rule public final MockitoRule mockito = MockitoJUnit.rule(); @Mock private VertexRagServiceClient vertexRagServiceClient; + @Mock private BaseSessionService sessionService; @Test public void runAsync_withResults_returnsContexts() throws Exception { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); ImmutableList ragResources = ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build()); Double vectorDistanceThreshold = 0.5; @@ -55,7 +59,11 @@ public void runAsync_withResults_returnsContexts() throws Exception { String query = "test query"; ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .sessionService(sessionService) + .build()) .functionCallId("functionCallId") .build(); RetrieveContextsRequest expectedRequest = @@ -85,6 +93,7 @@ public void runAsync_withResults_returnsContexts() throws Exception { @Test public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); ImmutableList ragResources = ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build()); Double vectorDistanceThreshold = 0.5; @@ -99,7 +108,11 @@ public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception { String query = "test query"; ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .sessionService(sessionService) + .build()) .functionCallId("functionCallId") .build(); RetrieveContextsRequest expectedRequest = @@ -129,6 +142,7 @@ public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception { @Test public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); // This test's behavior depends on the GOOGLE_GENAI_USE_VERTEXAI environment variable boolean useVertexAi = Boolean.parseBoolean(System.getenv("GOOGLE_GENAI_USE_VERTEXAI")); ImmutableList ragResources = @@ -145,7 +159,11 @@ public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() { LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-2-pro"); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .sessionService(sessionService) + .build()) .functionCallId("functionCallId") .build(); @@ -197,6 +215,7 @@ public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() { @Test public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() { + LlmAgent agent = LlmAgent.builder().name("test-agent").build(); ImmutableList ragResources = ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build()); Double vectorDistanceThreshold = 0.5; @@ -211,7 +230,11 @@ public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() { LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-1-pro"); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) + InvocationContext.builder() + .agent(agent) + .session(Session.builder("123").build()) + .sessionService(sessionService) + .build()) .functionCallId("functionCallId") .build(); GenerateContentConfig initialConfig = GenerateContentConfig.builder().build(); From ddb00efc1a1f531448b9f4dae28d647c6ffdf420 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Feb 2026 10:18:45 -0800 Subject: [PATCH 50/63] feat: Extend google_search support to Gemini 3 in Java ADK PiperOrigin-RevId: 868735240 --- .../main/java/com/google/adk/tools/GoogleSearchTool.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java b/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java index ffd9601a8..6f89754cf 100644 --- a/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java +++ b/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java @@ -28,8 +28,8 @@ import org.slf4j.LoggerFactory; /** - * A built-in tool that is automatically invoked by Gemini 2 models to retrieve search results from - * Google Search. + * A built-in tool that is automatically invoked by Gemini 2 and 3 models to retrieve search results + * from Google Search. * *

    This tool operates internally within the model and does not require or perform local code * execution. @@ -76,7 +76,7 @@ public Completable processLlmRequest( updatedToolsBuilder.add( Tool.builder().googleSearchRetrieval(GoogleSearchRetrieval.builder().build()).build()); configBuilder.tools(updatedToolsBuilder.build()); - } else if (model != null && model.startsWith("gemini-2")) { + } else if (model != null && (model.startsWith("gemini-2") || model.startsWith("gemini-3"))) { updatedToolsBuilder.add(Tool.builder().googleSearch(GoogleSearch.builder().build()).build()); configBuilder.tools(updatedToolsBuilder.build()); From a30bdf92b80e3c2555d49bc14295aa1ac14d6f36 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 11 Feb 2026 18:54:52 -0800 Subject: [PATCH 51/63] test: Update some tests that use InvocationContext PiperOrigin-RevId: 868944170 --- .../adk/flows/llmflows/ContentsTest.java | 15 ++-- .../com/google/adk/tools/AgentToolTest.java | 14 +++- .../google/adk/tools/FunctionToolTest.java | 30 +++++--- .../retrieval/VertexAiRagRetrievalTest.java | 70 ++++++++----------- 4 files changed, 63 insertions(+), 66 deletions(-) diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index ea8571a0b..b5df658ba 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -36,7 +36,6 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -892,11 +891,8 @@ private List runContentsProcessorWithIncludeContents( List events, LlmAgent.IncludeContents includeContents) { LlmAgent agent = LlmAgent.builder().name(AGENT).includeContents(includeContents).build(); Session session = - Session.builder("test-session") - .appName("test-app") - .userId("test-user") - .events(new ArrayList<>(events)) - .build(); + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + session.events().addAll(events); InvocationContext context = InvocationContext.builder() .invocationId("test-invocation") @@ -922,11 +918,8 @@ private List runContentsProcessorWithModelName(List events, Stri Mockito.doReturn(model).when(agent).resolvedModel(); Session session = - Session.builder("test-session") - .appName("test-app") - .userId("test-user") - .events(new ArrayList<>(events)) - .build(); + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + session.events().addAll(events); InvocationContext context = InvocationContext.builder() .invocationId("test-invocation") diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index 94d36f5ec..3a5390027 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -41,6 +41,7 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import java.util.Map; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -49,7 +50,12 @@ @RunWith(JUnit4.class) public final class AgentToolTest { - private static final InMemorySessionService sessionService = new InMemorySessionService(); + private InMemorySessionService sessionService; + + @Before + public void setUp() { + sessionService = new InMemorySessionService(); + } @Test public void fromConfig_withRegisteredAgent_returnsAgentTool() throws Exception { @@ -658,12 +664,14 @@ public void declaration_emptySequentialAgent_fallsBackToRequest() { .build()); } - private static ToolContext createToolContext(BaseAgent agent) { + private ToolContext createToolContext(BaseAgent agent) { + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); return ToolContext.builder( InvocationContext.builder() .invocationId(InvocationContext.newInvocationContextId()) .agent(agent) - .session(Session.builder("123").build()) + .session(session) .sessionService(sessionService) .build()) .build(); diff --git a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java index ec2ca1f76..0939c6506 100644 --- a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java +++ b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java @@ -36,6 +36,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -43,16 +44,25 @@ /** Unit tests for {@link FunctionTool}. */ @RunWith(JUnit4.class) public final class FunctionToolTest { - private static final ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder() - .agent(LlmAgent.builder().name("test-agent").build()) - .session(Session.builder("123").build()) - .sessionService(new InMemorySessionService()) - .invocationId("invocation-id") - .build()) - .functionCallId("functionCallId") - .build(); + private LlmAgent agent; + private InMemorySessionService sessionService; + private ToolContext toolContext; + + @Before + public void setUp() { + agent = LlmAgent.builder().name("test-agent").build(); + sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .agent(agent) + .session(session) + .sessionService(sessionService) + .invocationId("invocation-id") + .build(); + toolContext = ToolContext.builder(invocationContext).functionCallId("functionCallId").build(); + } @Test public void create_withNonSerializableParameter_raisesIllegalArgumentException() { diff --git a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java index 77cae3b40..6f04a7ef8 100644 --- a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java +++ b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java @@ -8,7 +8,7 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.models.LlmRequest; -import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.tools.ToolContext; import com.google.cloud.aiplatform.v1.RagContexts; @@ -27,6 +27,7 @@ import com.google.genai.types.VertexRagStore; import com.google.genai.types.VertexRagStoreRagResource; import java.util.Map; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -40,11 +41,18 @@ public final class VertexAiRagRetrievalTest { @Rule public final MockitoRule mockito = MockitoJUnit.rule(); @Mock private VertexRagServiceClient vertexRagServiceClient; - @Mock private BaseSessionService sessionService; + + private InMemorySessionService sessionService; + private LlmAgent agent; + + @Before + public void setUp() { + sessionService = new InMemorySessionService(); + agent = LlmAgent.builder().name("test-agent").build(); + } @Test public void runAsync_withResults_returnsContexts() throws Exception { - LlmAgent agent = LlmAgent.builder().name("test-agent").build(); ImmutableList ragResources = ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build()); Double vectorDistanceThreshold = 0.5; @@ -57,15 +65,7 @@ public void runAsync_withResults_returnsContexts() throws Exception { ragResources, vectorDistanceThreshold); String query = "test query"; - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder() - .agent(agent) - .session(Session.builder("123").build()) - .sessionService(sessionService) - .build()) - .functionCallId("functionCallId") - .build(); + ToolContext toolContext = buildToolContext(); RetrieveContextsRequest expectedRequest = RetrieveContextsRequest.newBuilder() .setParent("projects/test-project/locations/us-central1") @@ -93,7 +93,6 @@ public void runAsync_withResults_returnsContexts() throws Exception { @Test public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception { - LlmAgent agent = LlmAgent.builder().name("test-agent").build(); ImmutableList ragResources = ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build()); Double vectorDistanceThreshold = 0.5; @@ -106,15 +105,7 @@ public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception { ragResources, vectorDistanceThreshold); String query = "test query"; - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder() - .agent(agent) - .session(Session.builder("123").build()) - .sessionService(sessionService) - .build()) - .functionCallId("functionCallId") - .build(); + ToolContext toolContext = buildToolContext(); RetrieveContextsRequest expectedRequest = RetrieveContextsRequest.newBuilder() .setParent("projects/test-project/locations/us-central1") @@ -142,7 +133,6 @@ public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception { @Test public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() { - LlmAgent agent = LlmAgent.builder().name("test-agent").build(); // This test's behavior depends on the GOOGLE_GENAI_USE_VERTEXAI environment variable boolean useVertexAi = Boolean.parseBoolean(System.getenv("GOOGLE_GENAI_USE_VERTEXAI")); ImmutableList ragResources = @@ -157,15 +147,7 @@ public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() { ragResources, vectorDistanceThreshold); LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-2-pro"); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder() - .agent(agent) - .session(Session.builder("123").build()) - .sessionService(sessionService) - .build()) - .functionCallId("functionCallId") - .build(); + ToolContext toolContext = buildToolContext(); tool.processLlmRequest(llmRequestBuilder, toolContext).blockingAwait(); @@ -215,7 +197,6 @@ public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() { @Test public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() { - LlmAgent agent = LlmAgent.builder().name("test-agent").build(); ImmutableList ragResources = ImmutableList.of(RagResource.newBuilder().setRagCorpus("corpus1").build()); Double vectorDistanceThreshold = 0.5; @@ -228,15 +209,7 @@ public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() { ragResources, vectorDistanceThreshold); LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-1-pro"); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder() - .agent(agent) - .session(Session.builder("123").build()) - .sessionService(sessionService) - .build()) - .functionCallId("functionCallId") - .build(); + ToolContext toolContext = buildToolContext(); GenerateContentConfig initialConfig = GenerateContentConfig.builder().build(); llmRequestBuilder.config(initialConfig); @@ -264,4 +237,17 @@ public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() { .build())) .build()); } + + private ToolContext buildToolContext() { + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + return ToolContext.builder( + InvocationContext.builder() + .invocationId(InvocationContext.newInvocationContextId()) + .agent(agent) + .session(session) + .sessionService(sessionService) + .build()) + .build(); + } } From e989ae1337a84fd6686504050d2a3bf2db15c32c Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Thu, 12 Feb 2026 06:53:42 -0800 Subject: [PATCH 52/63] fix: pass mutable function args map to beforeToolCallback PiperOrigin-RevId: 869194554 --- .../google/adk/flows/llmflows/Functions.java | 9 +- .../com/google/adk/agents/CallbacksTest.java | 119 ++++++++++++++++++ 2 files changed, 126 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 75289bd38..26f14d24b 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -472,8 +472,12 @@ private static Maybe> maybeInvokeBeforeToolCall( if (invocationContext.agent() instanceof LlmAgent) { LlmAgent agent = (LlmAgent) invocationContext.agent(); + HashMap mutableFunctionArgs = new HashMap<>(functionArgs); + Maybe> pluginResult = - invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); + invocationContext + .pluginManager() + .beforeToolCallback(tool, mutableFunctionArgs, toolContext); List callbacks = agent.canonicalBeforeToolCallbacks(); if (callbacks.isEmpty()) { @@ -486,7 +490,8 @@ private static Maybe> maybeInvokeBeforeToolCall( Flowable.fromIterable(callbacks) .concatMapMaybe( callback -> - callback.call(invocationContext, tool, functionArgs, toolContext)) + callback.call( + invocationContext, tool, mutableFunctionArgs, toolContext)) .firstElement()); return pluginResult.switchIfEmpty(callbackResult); diff --git a/core/src/test/java/com/google/adk/agents/CallbacksTest.java b/core/src/test/java/com/google/adk/agents/CallbacksTest.java index f5ba8aafd..11087e6d6 100644 --- a/core/src/test/java/com/google/adk/agents/CallbacksTest.java +++ b/core/src/test/java/com/google/adk/agents/CallbacksTest.java @@ -1057,6 +1057,125 @@ public void handleFunctionCalls_withChainedToolCallbacks_overridesResultAndPasse assertThat(invocationContext.session().state()).containsExactlyEntriesIn(stateAddedByBc2); } + @Test + public void + handleFunctionCalls_withChainedBeforeToolCallbacks_firstModifiesArgsSecondReturnsResponse() { + ImmutableMap originalArgs = ImmutableMap.of("arg1", "val1"); + ImmutableMap modifiedArgsByCb1 = + ImmutableMap.of("arg1", "val1", "arg2", "val2"); + ImmutableMap responseFromCb2 = ImmutableMap.of("result", "from cb2"); + + Callbacks.BeforeToolCallbackSync cb1 = + (invocationContext, tool, input, toolContext) -> { + input.put("arg2", "val2"); + return Optional.empty(); + }; + + Callbacks.BeforeToolCallbackSync cb2 = + (invocationContext, tool, input, toolContext) -> { + if (input.equals(modifiedArgsByCb1)) { + return Optional.of(responseFromCb2); + } + return Optional.empty(); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .beforeToolCallback(ImmutableList.of(cb1, cb2)) + .build()); + + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.fromText("..."), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("fc_id") + .name("echo_tool") + .args(originalArgs) + .build()) + .build())) + .build(); + + Event functionResponseEvent = + Functions.handleFunctionCalls( + invocationContext, + event, + ImmutableMap.of("echo_tool", new TestUtils.FailingEchoTool())) + .blockingGet(); + + assertThat(getFunctionResponse(functionResponseEvent)).isEqualTo(responseFromCb2); + } + + @Test + public void + handleFunctionCalls_withPluginAndAgentBeforeToolCallbacks_pluginModifiesArgsAgentSeesThem() { + ImmutableMap originalArgs = ImmutableMap.of("arg1", "val1"); + ImmutableMap modifiedArgsByPlugin = + ImmutableMap.of("arg1", "val1", "arg2", "val2"); + ImmutableMap responseFromAgentCb = ImmutableMap.of("result", "from agent cb"); + + Plugin testPlugin = + new Plugin() { + @Override + public String getName() { + return "test_plugin"; + } + + @Override + public Maybe> beforeToolCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext) { + toolArgs.put("arg2", "val2"); + return Maybe.empty(); + } + }; + + Callbacks.BeforeToolCallbackSync agentCb = + (invocationContext, tool, input, toolContext) -> { + if (input.equals(modifiedArgsByPlugin)) { + return Optional.of(responseFromAgentCb); + } + return Optional.empty(); + }; + + LlmAgent agent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .beforeToolCallbackSync(agentCb) + .build(); + + InvocationContext invocationContext = + createInvocationContext(agent).toBuilder() + .pluginManager(new PluginManager(ImmutableList.of(testPlugin))) + .build(); + + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.fromText("..."), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("fc_id") + .name("echo_tool") + .args(originalArgs) + .build()) + .build())) + .build(); + + Event functionResponseEvent = + Functions.handleFunctionCalls( + invocationContext, + event, + ImmutableMap.of("echo_tool", new TestUtils.FailingEchoTool())) + .blockingGet(); + + assertThat(getFunctionResponse(functionResponseEvent)).isEqualTo(responseFromAgentCb); + } + @Test public void agentRunAsync_withToolCallbacks_inspectsArgsAndReturnsResponse() { TestUtils.EchoTool echoTool = new TestUtils.EchoTool(); From 8b887fda1d03c6d76a3f3354bdf84b5976b1650e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 12 Feb 2026 13:08:42 -0800 Subject: [PATCH 53/63] refactor: Fixing EventAction's artifaceDelta type Map now matches with Vertex AI Session API and the python artifactDelta implementation. The new implementation unblocks rewind functionality. Background: - Artifact service stores artifacts with a key like `{app}_{user}_{sessionId}_{file}_{version}` - Session service has the app, user, session id. It saves EventAction with an economic `[file] : [version]` map. The Python ADK always had it this way, and somehow Java diverged. PiperOrigin-RevId: 869350765 --- .../google/adk/agents/CallbackContext.java | 2 +- .../com/google/adk/events/EventActions.java | 11 +++++----- .../adk/flows/llmflows/CodeExecution.java | 6 ++---- .../adk/sessions/SessionJsonConverter.java | 20 ++++++++++--------- .../google/adk/events/EventActionsTest.java | 6 +++--- .../java/com/google/adk/events/EventTest.java | 4 +--- .../sessions/SessionJsonConverterTest.java | 14 +++++-------- 7 files changed, 28 insertions(+), 35 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/CallbackContext.java b/core/src/main/java/com/google/adk/agents/CallbackContext.java index 808c737a3..49298451b 100644 --- a/core/src/main/java/com/google/adk/agents/CallbackContext.java +++ b/core/src/main/java/com/google/adk/agents/CallbackContext.java @@ -134,7 +134,7 @@ public Completable saveArtifact(String filename, Part artifact) { invocationContext.session().id(), filename, artifact) - .doOnSuccess(unusedVersion -> this.eventActions.artifactDelta().put(filename, artifact)) + .doOnSuccess(version -> this.eventActions.artifactDelta().put(filename, version)) .ignoreElement(); } } diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 157deda63..07488a171 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -22,7 +22,6 @@ import com.google.adk.agents.BaseAgentState; import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; -import com.google.genai.types.Part; import java.util.HashSet; import java.util.Objects; import java.util.Optional; @@ -38,7 +37,7 @@ public class EventActions extends JsonBaseModel { private Optional skipSummarization; private ConcurrentMap stateDelta; - private ConcurrentMap artifactDelta; + private ConcurrentMap artifactDelta; private Set deletedArtifactIds; private Optional transferToAgent; private Optional escalate; @@ -117,11 +116,11 @@ public void removeStateByKey(String key) { } @JsonProperty("artifactDelta") - public ConcurrentMap artifactDelta() { + public ConcurrentMap artifactDelta() { return artifactDelta; } - public void setArtifactDelta(ConcurrentMap artifactDelta) { + public void setArtifactDelta(ConcurrentMap artifactDelta) { this.artifactDelta = artifactDelta; } @@ -294,7 +293,7 @@ public int hashCode() { public static class Builder { private Optional skipSummarization; private ConcurrentMap stateDelta; - private ConcurrentMap artifactDelta; + private ConcurrentMap artifactDelta; private Set deletedArtifactIds; private Optional transferToAgent; private Optional escalate; @@ -351,7 +350,7 @@ public Builder stateDelta(ConcurrentMap value) { @CanIgnoreReturnValue @JsonProperty("artifactDelta") - public Builder artifactDelta(ConcurrentMap value) { + public Builder artifactDelta(ConcurrentMap value) { this.artifactDelta = value; return this; } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java index bb1789609..1f99cf4a2 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java @@ -448,11 +448,9 @@ private static Single postProcessCodeExecutionResult( .toList() .map( versions -> { - ConcurrentMap artifactDelta = new ConcurrentHashMap<>(); + ConcurrentMap artifactDelta = new ConcurrentHashMap<>(); for (int i = 0; i < versions.size(); i++) { - artifactDelta.put( - codeExecutionResult.outputFiles().get(i).name(), - Part.fromText(String.valueOf(versions.get(i)))); + artifactDelta.put(codeExecutionResult.outputFiles().get(i).name(), versions.get(i)); } eventActionsBuilder.artifactDelta(artifactDelta); return Event.builder() diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index f39605e3a..71b072695 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -29,7 +29,6 @@ import com.google.genai.types.FinishReason; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.GroundingMetadata; -import com.google.genai.types.Part; import java.io.UncheckedIOException; import java.time.Instant; import java.util.Collection; @@ -109,7 +108,9 @@ static String convertEventToJson(Event event, boolean useIsoString) { actionsJson.put("transferAgent", v); }); actions.escalate().ifPresent(v -> actionsJson.put("escalate", v)); - actionsJson.put("endOfAgent", actions.endOfAgent()); + if (actions.endOfAgent()) { + actionsJson.put("endOfAgent", actions.endOfAgent()); + } putIfNotEmpty(actionsJson, "requestedAuthConfigs", actions.requestedAuthConfigs()); putIfNotEmpty( actionsJson, "requestedToolConfirmations", actions.requestedToolConfirmations()); @@ -297,18 +298,19 @@ private static Instant convertToInstant(Object timestampObj) { * @return A {@link ConcurrentMap} representing the artifact delta. */ @SuppressWarnings("unchecked") - private static ConcurrentMap convertToArtifactDeltaMap(Object artifactDeltaObj) { + private static ConcurrentMap convertToArtifactDeltaMap(Object artifactDeltaObj) { if (!(artifactDeltaObj instanceof Map)) { return new ConcurrentHashMap<>(); } - ConcurrentMap artifactDeltaMap = new ConcurrentHashMap<>(); - Map> rawMap = (Map>) artifactDeltaObj; - for (Map.Entry> entry : rawMap.entrySet()) { + ConcurrentMap artifactDeltaMap = new ConcurrentHashMap<>(); + Map rawMap = (Map) artifactDeltaObj; + for (Map.Entry entry : rawMap.entrySet()) { try { - Part part = objectMapper.convertValue(entry.getValue(), Part.class); - artifactDeltaMap.put(entry.getKey(), part); + Integer value = objectMapper.convertValue(entry.getValue(), Integer.class); + artifactDeltaMap.put(entry.getKey(), value); } catch (IllegalArgumentException e) { - logger.warn("Error converting artifactDelta value to Part for key: {}", entry.getKey(), e); + logger.warn( + "Error converting artifactDelta value to Integer for key: {}", entry.getKey(), e); } } return artifactDeltaMap; diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 7a58de575..94cd399df 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -63,7 +63,7 @@ public void merge_mergesAllFields() { EventActions.builder() .skipSummarization(true) .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1"))) - .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", PART))) + .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", 1))) .deletedArtifactIds(ImmutableSet.of("deleted1")) .requestedAuthConfigs( new ConcurrentHashMap<>( @@ -75,7 +75,7 @@ public void merge_mergesAllFields() { EventActions eventActions2 = EventActions.builder() .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key2", "value2"))) - .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", PART))) + .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", 2))) .deletedArtifactIds(ImmutableSet.of("deleted2")) .transferToAgent("agentId") .escalate(true) @@ -91,7 +91,7 @@ public void merge_mergesAllFields() { assertThat(merged.skipSummarization()).hasValue(true); assertThat(merged.stateDelta()).containsExactly("key1", "value1", "key2", "value2"); - assertThat(merged.artifactDelta()).containsExactly("artifact1", PART, "artifact2", PART); + assertThat(merged.artifactDelta()).containsExactly("artifact1", 1, "artifact2", 2); assertThat(merged.deletedArtifactIds()).containsExactly("deleted1", "deleted2"); assertThat(merged.transferToAgent()).hasValue("agentId"); assertThat(merged.escalate()).hasValue(true); diff --git a/core/src/test/java/com/google/adk/events/EventTest.java b/core/src/test/java/com/google/adk/events/EventTest.java index d6de97f7f..cbfb6ef0b 100644 --- a/core/src/test/java/com/google/adk/events/EventTest.java +++ b/core/src/test/java/com/google/adk/events/EventTest.java @@ -45,9 +45,7 @@ public final class EventTest { EventActions.builder() .skipSummarization(true) .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) - .artifactDelta( - new ConcurrentHashMap<>( - ImmutableMap.of("artifact_key", Part.builder().text("artifact_value").build()))) + .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact_key", 1))) .transferToAgent("agent_id") .escalate(true) .requestedAuthConfigs( diff --git a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java index f947a7140..f6120cf08 100644 --- a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java +++ b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java @@ -39,9 +39,7 @@ public void convertEventToJson_fullEvent_success() throws JsonProcessingExceptio EventActions.builder() .skipSummarization(true) .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) - .artifactDelta( - new ConcurrentHashMap<>( - ImmutableMap.of("artifact", Part.fromText("artifact_text")))) + .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact", 1))) .transferToAgent("agent") .escalate(true) .build(); @@ -80,8 +78,7 @@ public void convertEventToJson_fullEvent_success() throws JsonProcessingExceptio JsonNode actionsNode = jsonNode.get("actions"); assertThat(actionsNode.get("skipSummarization").asBoolean()).isTrue(); assertThat(actionsNode.get("stateDelta").get("key").asText()).isEqualTo("value"); - assertThat(actionsNode.get("artifactDelta").get("artifact").get("text").asText()) - .isEqualTo("artifact_text"); + assertThat(actionsNode.get("artifactDelta").get("artifact").asInt()).isEqualTo(1); assertThat(actionsNode.get("transferAgent").asText()).isEqualTo("agent"); assertThat(actionsNode.get("escalate").asBoolean()).isTrue(); } @@ -131,8 +128,7 @@ public void fromApiEvent_fullEvent_success() { Map actions = new HashMap<>(); actions.put("skipSummarization", true); actions.put("stateDelta", ImmutableMap.of("key", "value")); - actions.put( - "artifactDelta", ImmutableMap.of("artifact", ImmutableMap.of("text", "artifact_text"))); + actions.put("artifactDelta", ImmutableMap.of("artifact", 1)); actions.put("transferAgent", "agent"); actions.put("escalate", true); apiEvent.put("actions", actions); @@ -154,7 +150,7 @@ public void fromApiEvent_fullEvent_success() { EventActions eventActions = event.actions(); assertThat(eventActions.skipSummarization()).hasValue(true); assertThat(eventActions.stateDelta()).containsEntry("key", "value"); - assertThat(eventActions.artifactDelta().get("artifact").text()).hasValue("artifact_text"); + assertThat(eventActions.artifactDelta()).containsEntry("artifact", 1); assertThat(eventActions.transferToAgent()).hasValue("agent"); assertThat(eventActions.escalate()).hasValue(true); } @@ -383,7 +379,7 @@ public void fromApiEvent_withInvalidArtifactDelta_skipsInvalidEntries() { apiEvent.put("timestamp", "2023-01-01T00:00:00Z"); Map artifactDelta = new HashMap<>(); - artifactDelta.put("valid", ImmutableMap.of("text", "valid_text")); + artifactDelta.put("valid", 1); artifactDelta.put("invalid", "not-a-map"); Map actions = new HashMap<>(); From b89f2ebdbc32bf5deb7cbd298d9b4840023d080e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 12 Feb 2026 14:25:03 -0800 Subject: [PATCH 54/63] refactor: Removing fields that were intended for a future feature The ADK Python code has pause/resume/rewind functionality. It's going to be some time before ADK Java gets it. For now, it's worth cleaning up the half-baked feature. PiperOrigin-RevId: 869384320 --- .../com/google/adk/agents/BaseAgentState.java | 39 --- .../google/adk/agents/InvocationContext.java | 123 +-------- .../main/java/com/google/adk/apps/App.java | 18 +- .../google/adk/apps/ResumabilityConfig.java | 28 --- .../com/google/adk/events/EventActions.java | 56 +---- .../java/com/google/adk/runner/Runner.java | 56 +---- .../adk/agents/InvocationContextTest.java | 238 +----------------- .../com/google/adk/runner/RunnerTest.java | 43 ---- 8 files changed, 11 insertions(+), 590 deletions(-) delete mode 100644 core/src/main/java/com/google/adk/agents/BaseAgentState.java delete mode 100644 core/src/main/java/com/google/adk/apps/ResumabilityConfig.java diff --git a/core/src/main/java/com/google/adk/agents/BaseAgentState.java b/core/src/main/java/com/google/adk/agents/BaseAgentState.java deleted file mode 100644 index dedcb93ab..000000000 --- a/core/src/main/java/com/google/adk/agents/BaseAgentState.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.agents; - -import com.google.adk.JsonBaseModel; - -/** Base class for all agent states. */ -public class BaseAgentState extends JsonBaseModel { - - protected BaseAgentState() {} - - /** Returns a new {@link Builder} for creating {@link BaseAgentState} instances. */ - public static Builder builder() { - return new Builder(); - } - - /** Builder for {@link BaseAgentState}. */ - public static class Builder { - private Builder() {} - - public BaseAgentState build() { - return new BaseAgentState(); - } - } -} diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index afee5065a..6457a8ca4 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -18,9 +18,7 @@ import static com.google.common.base.Strings.isNullOrEmpty; -import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; -import com.google.adk.events.Event; import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.Plugin; @@ -28,12 +26,9 @@ import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.adk.summarizer.EventsCompactionConfig; -import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.InlineMe; import com.google.genai.types.Content; -import com.google.genai.types.FunctionCall; -import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -54,9 +49,6 @@ public class InvocationContext { private final Session session; private final Optional userContent; private final RunConfig runConfig; - private final Map agentStates; - private final Map endOfAgents; - private final ResumabilityConfig resumabilityConfig; @Nullable private final EventsCompactionConfig eventsCompactionConfig; @Nullable private final ContextCacheConfig contextCacheConfig; private final InvocationCostManager invocationCostManager; @@ -80,13 +72,10 @@ protected InvocationContext(Builder builder) { this.userContent = builder.userContent; this.runConfig = builder.runConfig; this.endInvocation = builder.endInvocation; - this.agentStates = builder.agentStates; - this.endOfAgents = builder.endOfAgents; - this.resumabilityConfig = builder.resumabilityConfig; this.eventsCompactionConfig = builder.eventsCompactionConfig; this.contextCacheConfig = builder.contextCacheConfig; this.invocationCostManager = builder.invocationCostManager; - this.callbackContextData = builder.callbackContextData; + this.callbackContextData = new ConcurrentHashMap<>(builder.callbackContextData); } /** @@ -267,10 +256,7 @@ public String invocationId() { /** * Sets the [branch] ID for the current invocation. A branch represents a fork in the conversation * history. - * - * @deprecated Use {@link #toBuilder()} and {@link Builder#branch(String)} instead. */ - @Deprecated(forRemoval = true) public void branch(@Nullable String branch) { this.branch = Optional.ofNullable(branch); } @@ -321,16 +307,6 @@ public Map callbackContextData() { return callbackContextData; } - /** Returns agent-specific state saved within this invocation. */ - public Map agentStates() { - return agentStates; - } - - /** Returns map of agents that ended during this invocation. */ - public Map endOfAgents() { - return endOfAgents; - } - /** * Returns whether this invocation should be ended, e.g., due to reaching a terminal state or * error. @@ -369,36 +345,6 @@ public void incrementLlmCallsCount() throws LlmCallsLimitExceededException { this.invocationCostManager.incrementAndEnforceLlmCallsLimit(this.runConfig); } - /** Returns whether the current invocation is resumable. */ - public boolean isResumable() { - return resumabilityConfig.isResumable(); - } - - /** Returns ResumabilityConfig for this invocation. */ - public ResumabilityConfig resumabilityConfig() { - return resumabilityConfig; - } - - /** - * Populates agentStates and endOfAgents maps by reading session events for this invocation id. - */ - public void populateAgentStates(List events) { - events.stream() - .filter(event -> invocationId().equals(event.invocationId())) - .forEach( - event -> { - if (event.actions() != null) { - if (event.actions().agentState() != null - && !event.actions().agentState().isEmpty()) { - agentStates.putAll(event.actions().agentState()); - } - if (event.actions().endOfAgent()) { - endOfAgents.put(event.author(), true); - } - } - }); - } - /** Returns the events compaction configuration for the current agent run. */ public Optional eventsCompactionConfig() { return Optional.ofNullable(eventsCompactionConfig); @@ -409,23 +355,6 @@ public Optional contextCacheConfig() { return Optional.ofNullable(contextCacheConfig); } - /** Returns whether to pause the invocation right after this [event]. */ - public boolean shouldPauseInvocation(Event event) { - if (!isResumable()) { - return false; - } - - var longRunningToolIds = event.longRunningToolIds().orElse(ImmutableSet.of()); - if (longRunningToolIds.isEmpty()) { - return false; - } - - return event.functionCalls().stream() - .map(FunctionCall::id) - .flatMap(Optional::stream) - .anyMatch(functionCallId -> longRunningToolIds.contains(functionCallId)); - } - private static class InvocationCostManager { private int numberOfLlmCalls = 0; @@ -477,13 +406,10 @@ private Builder(InvocationContext context) { this.userContent = context.userContent; this.runConfig = context.runConfig; this.endInvocation = context.endInvocation; - this.agentStates = new ConcurrentHashMap<>(context.agentStates); - this.endOfAgents = new ConcurrentHashMap<>(context.endOfAgents); - this.resumabilityConfig = context.resumabilityConfig; this.eventsCompactionConfig = context.eventsCompactionConfig; this.contextCacheConfig = context.contextCacheConfig; this.invocationCostManager = context.invocationCostManager; - this.callbackContextData = context.callbackContextData; + this.callbackContextData = new ConcurrentHashMap<>(context.callbackContextData); } private BaseSessionService sessionService; @@ -499,9 +425,6 @@ private Builder(InvocationContext context) { private Optional userContent = Optional.empty(); private RunConfig runConfig = RunConfig.builder().build(); private boolean endInvocation = false; - private Map agentStates = new ConcurrentHashMap<>(); - private Map endOfAgents = new ConcurrentHashMap<>(); - private ResumabilityConfig resumabilityConfig = new ResumabilityConfig(); @Nullable private EventsCompactionConfig eventsCompactionConfig; @Nullable private ContextCacheConfig contextCacheConfig; private InvocationCostManager invocationCostManager = new InvocationCostManager(); @@ -693,42 +616,6 @@ public Builder endInvocation(boolean endInvocation) { return this; } - /** - * Sets agent-specific state saved within this invocation. - * - * @param agentStates agent-specific state saved within this invocation. - * @return this builder instance for chaining. - */ - @CanIgnoreReturnValue - public Builder agentStates(Map agentStates) { - this.agentStates = agentStates; - return this; - } - - /** - * Sets agent end-of-invocation status. - * - * @param endOfAgents agent end-of-invocation status. - * @return this builder instance for chaining. - */ - @CanIgnoreReturnValue - public Builder endOfAgents(Map endOfAgents) { - this.endOfAgents = endOfAgents; - return this; - } - - /** - * Sets the resumability configuration for the current agent run. - * - * @param resumabilityConfig the resumability configuration. - * @return this builder instance for chaining. - */ - @CanIgnoreReturnValue - public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) { - this.resumabilityConfig = resumabilityConfig; - return this; - } - /** * Sets the events compaction configuration for the current agent run. * @@ -818,9 +705,6 @@ public boolean equals(Object o) { && Objects.equals(session, that.session) && Objects.equals(userContent, that.userContent) && Objects.equals(runConfig, that.runConfig) - && Objects.equals(agentStates, that.agentStates) - && Objects.equals(endOfAgents, that.endOfAgents) - && Objects.equals(resumabilityConfig, that.resumabilityConfig) && Objects.equals(eventsCompactionConfig, that.eventsCompactionConfig) && Objects.equals(contextCacheConfig, that.contextCacheConfig) && Objects.equals(invocationCostManager, that.invocationCostManager) @@ -843,9 +727,6 @@ public int hashCode() { userContent, runConfig, endInvocation, - agentStates, - endOfAgents, - resumabilityConfig, eventsCompactionConfig, contextCacheConfig, invocationCostManager, diff --git a/core/src/main/java/com/google/adk/apps/App.java b/core/src/main/java/com/google/adk/apps/App.java index d6635d2e7..18e8753c7 100644 --- a/core/src/main/java/com/google/adk/apps/App.java +++ b/core/src/main/java/com/google/adk/apps/App.java @@ -41,7 +41,6 @@ public class App { private final BaseAgent rootAgent; private final ImmutableList plugins; @Nullable private final EventsCompactionConfig eventsCompactionConfig; - @Nullable private final ResumabilityConfig resumabilityConfig; @Nullable private final ContextCacheConfig contextCacheConfig; private App( @@ -49,13 +48,11 @@ private App( BaseAgent rootAgent, List plugins, @Nullable EventsCompactionConfig eventsCompactionConfig, - @Nullable ResumabilityConfig resumabilityConfig, @Nullable ContextCacheConfig contextCacheConfig) { this.name = name; this.rootAgent = rootAgent; this.plugins = ImmutableList.copyOf(plugins); this.eventsCompactionConfig = eventsCompactionConfig; - this.resumabilityConfig = resumabilityConfig; this.contextCacheConfig = contextCacheConfig; } @@ -76,11 +73,6 @@ public EventsCompactionConfig eventsCompactionConfig() { return eventsCompactionConfig; } - @Nullable - public ResumabilityConfig resumabilityConfig() { - return resumabilityConfig; - } - @Nullable public ContextCacheConfig contextCacheConfig() { return contextCacheConfig; @@ -92,7 +84,6 @@ public static class Builder { private BaseAgent rootAgent; private List plugins = ImmutableList.of(); @Nullable private EventsCompactionConfig eventsCompactionConfig; - @Nullable private ResumabilityConfig resumabilityConfig; @Nullable private ContextCacheConfig contextCacheConfig; @CanIgnoreReturnValue @@ -119,12 +110,6 @@ public Builder eventsCompactionConfig(EventsCompactionConfig eventsCompactionCon return this; } - @CanIgnoreReturnValue - public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) { - this.resumabilityConfig = resumabilityConfig; - return this; - } - @CanIgnoreReturnValue public Builder contextCacheConfig(ContextCacheConfig contextCacheConfig) { this.contextCacheConfig = contextCacheConfig; @@ -139,8 +124,7 @@ public App build() { throw new IllegalStateException("Root agent must be provided."); } validateAppName(name); - return new App( - name, rootAgent, plugins, eventsCompactionConfig, resumabilityConfig, contextCacheConfig); + return new App(name, rootAgent, plugins, eventsCompactionConfig, contextCacheConfig); } } diff --git a/core/src/main/java/com/google/adk/apps/ResumabilityConfig.java b/core/src/main/java/com/google/adk/apps/ResumabilityConfig.java deleted file mode 100644 index b80ce709c..000000000 --- a/core/src/main/java/com/google/adk/apps/ResumabilityConfig.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS language governing permissions and - * limitations under the License. - */ -package com.google.adk.apps; - -/** - * An app contains Resumability configuration for the agents. - * - * @param isResumable Whether the app is resumable. - */ -public record ResumabilityConfig(boolean isResumable) { - - /** Creates a new {@code ResumabilityConfig} with resumability disabled. */ - public ResumabilityConfig() { - this(false); - } -} diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 07488a171..6d8c698dd 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -19,7 +19,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.google.adk.JsonBaseModel; -import com.google.adk.agents.BaseAgentState; import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.HashSet; @@ -44,9 +43,7 @@ public class EventActions extends JsonBaseModel { private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent; - private ConcurrentMap agentState; private Optional compaction; - private Optional rewindBeforeInvocationId; /** Default constructor for Jackson. */ public EventActions() { @@ -60,8 +57,6 @@ public EventActions() { this.requestedToolConfirmations = new ConcurrentHashMap<>(); this.endOfAgent = false; this.compaction = Optional.empty(); - this.agentState = new ConcurrentHashMap<>(); - this.rewindBeforeInvocationId = Optional.empty(); } private EventActions(Builder builder) { @@ -75,8 +70,6 @@ private EventActions(Builder builder) { this.requestedToolConfirmations = builder.requestedToolConfirmations; this.endOfAgent = builder.endOfAgent; this.compaction = builder.compaction; - this.agentState = builder.agentState; - this.rewindBeforeInvocationId = builder.rewindBeforeInvocationId; } @JsonProperty("skipSummarization") @@ -223,25 +216,6 @@ public void setCompaction(Optional compaction) { this.compaction = compaction; } - @JsonProperty("agentState") - @JsonInclude(JsonInclude.Include.NON_EMPTY) - public ConcurrentMap agentState() { - return agentState; - } - - public void setAgentState(ConcurrentMap agentState) { - this.agentState = agentState; - } - - @JsonProperty("rewindBeforeInvocationId") - public Optional rewindBeforeInvocationId() { - return rewindBeforeInvocationId; - } - - public void setRewindBeforeInvocationId(@Nullable String rewindBeforeInvocationId) { - this.rewindBeforeInvocationId = Optional.ofNullable(rewindBeforeInvocationId); - } - public static Builder builder() { return new Builder(); } @@ -267,9 +241,7 @@ public boolean equals(Object o) { && Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs) && Objects.equals(requestedToolConfirmations, that.requestedToolConfirmations) && (endOfAgent == that.endOfAgent) - && Objects.equals(compaction, that.compaction) - && Objects.equals(agentState, that.agentState) - && Objects.equals(rewindBeforeInvocationId, that.rewindBeforeInvocationId); + && Objects.equals(compaction, that.compaction); } @Override @@ -284,9 +256,7 @@ public int hashCode() { requestedAuthConfigs, requestedToolConfirmations, endOfAgent, - compaction, - agentState, - rewindBeforeInvocationId); + compaction); } /** Builder for {@link EventActions}. */ @@ -301,8 +271,6 @@ public static class Builder { private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent = false; private Optional compaction; - private ConcurrentMap agentState; - private Optional rewindBeforeInvocationId; public Builder() { this.skipSummarization = Optional.empty(); @@ -314,8 +282,6 @@ public Builder() { this.requestedAuthConfigs = new ConcurrentHashMap<>(); this.requestedToolConfirmations = new ConcurrentHashMap<>(); this.compaction = Optional.empty(); - this.agentState = new ConcurrentHashMap<>(); - this.rewindBeforeInvocationId = Optional.empty(); } private Builder(EventActions eventActions) { @@ -330,8 +296,6 @@ private Builder(EventActions eventActions) { new ConcurrentHashMap<>(eventActions.requestedToolConfirmations()); this.endOfAgent = eventActions.endOfAgent(); this.compaction = eventActions.compaction(); - this.agentState = new ConcurrentHashMap<>(eventActions.agentState()); - this.rewindBeforeInvocationId = eventActions.rewindBeforeInvocationId(); } @CanIgnoreReturnValue @@ -416,20 +380,6 @@ public Builder compaction(EventCompaction value) { return this; } - @CanIgnoreReturnValue - @JsonProperty("agentState") - public Builder agentState(ConcurrentMap agentState) { - this.agentState = agentState; - return this; - } - - @CanIgnoreReturnValue - @JsonProperty("rewindBeforeInvocationId") - public Builder rewindBeforeInvocationId(String rewindBeforeInvocationId) { - this.rewindBeforeInvocationId = Optional.ofNullable(rewindBeforeInvocationId); - return this; - } - @CanIgnoreReturnValue public Builder merge(EventActions other) { other.skipSummarization().ifPresent(this::skipSummarization); @@ -442,8 +392,6 @@ public Builder merge(EventActions other) { this.requestedToolConfirmations.putAll(other.requestedToolConfirmations()); this.endOfAgent = other.endOfAgent(); other.compaction().ifPresent(this::compaction); - this.agentState.putAll(other.agentState()); - other.rewindBeforeInvocationId().ifPresent(this::rewindBeforeInvocationId); return this; } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 31026ee6e..0095d3fb6 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -24,7 +24,6 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; -import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; @@ -74,7 +73,6 @@ public class Runner { private final BaseSessionService sessionService; @Nullable private final BaseMemoryService memoryService; private final PluginManager pluginManager; - private final ResumabilityConfig resumabilityConfig; @Nullable private final EventsCompactionConfig eventsCompactionConfig; @Nullable private final ContextCacheConfig contextCacheConfig; @@ -138,7 +136,6 @@ public Runner build() { BaseAgent buildAgent; String buildAppName; List buildPlugins; - ResumabilityConfig buildResumabilityConfig; EventsCompactionConfig buildEventsCompactionConfig; ContextCacheConfig buildContextCacheConfig; @@ -152,17 +149,12 @@ public Runner build() { buildAgent = this.app.rootAgent(); buildPlugins = this.app.plugins(); buildAppName = this.appName == null ? this.app.name() : this.appName; - buildResumabilityConfig = - this.app.resumabilityConfig() != null - ? this.app.resumabilityConfig() - : new ResumabilityConfig(); buildEventsCompactionConfig = this.app.eventsCompactionConfig(); buildContextCacheConfig = this.app.contextCacheConfig(); } else { buildAgent = this.agent; buildAppName = this.appName; buildPlugins = this.plugins; - buildResumabilityConfig = new ResumabilityConfig(); buildEventsCompactionConfig = null; buildContextCacheConfig = null; } @@ -186,7 +178,6 @@ public Runner build() { sessionService, memoryService, buildPlugins, - buildResumabilityConfig, buildEventsCompactionConfig, buildContextCacheConfig); } @@ -208,14 +199,7 @@ public Runner( BaseArtifactService artifactService, BaseSessionService sessionService, @Nullable BaseMemoryService memoryService) { - this( - agent, - appName, - artifactService, - sessionService, - memoryService, - ImmutableList.of(), - new ResumabilityConfig()); + this(agent, appName, artifactService, sessionService, memoryService, ImmutableList.of()); } /** @@ -231,40 +215,7 @@ public Runner( BaseSessionService sessionService, @Nullable BaseMemoryService memoryService, List plugins) { - this( - agent, - appName, - artifactService, - sessionService, - memoryService, - plugins, - new ResumabilityConfig()); - } - - /** - * Creates a new {@code Runner} with a list of plugins and resumability config. - * - * @deprecated Use {@link Runner.Builder} instead. - */ - @Deprecated - public Runner( - BaseAgent agent, - String appName, - BaseArtifactService artifactService, - BaseSessionService sessionService, - @Nullable BaseMemoryService memoryService, - List plugins, - ResumabilityConfig resumabilityConfig) { - this( - agent, - appName, - artifactService, - sessionService, - memoryService, - plugins, - resumabilityConfig, - null, - null); + this(agent, appName, artifactService, sessionService, memoryService, plugins, null, null); } /** @@ -280,7 +231,6 @@ protected Runner( BaseSessionService sessionService, @Nullable BaseMemoryService memoryService, List plugins, - ResumabilityConfig resumabilityConfig, @Nullable EventsCompactionConfig eventsCompactionConfig, @Nullable ContextCacheConfig contextCacheConfig) { this.agent = agent; @@ -289,7 +239,6 @@ protected Runner( this.sessionService = sessionService; this.memoryService = memoryService; this.pluginManager = new PluginManager(plugins); - this.resumabilityConfig = resumabilityConfig; this.eventsCompactionConfig = createEventsCompactionConfig(agent, eventsCompactionConfig); this.contextCacheConfig = contextCacheConfig; } @@ -651,7 +600,6 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { .pluginManager(this.pluginManager) .agent(rootAgent) .session(session) - .resumabilityConfig(this.resumabilityConfig) .eventsCompactionConfig(this.eventsCompactionConfig) .contextCacheConfig(this.contextCacheConfig) .agent(this.findAgentToRun(session, rootAgent)); diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index bbfbb74bd..0237261c5 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -20,22 +20,15 @@ import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; -import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; -import com.google.adk.events.Event; -import com.google.adk.events.EventActions; import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.adk.summarizer.EventsCompactionConfig; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; -import com.google.genai.types.FunctionCall; -import com.google.genai.types.Part; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -186,12 +179,12 @@ public void testToBuilder() { assertThat(copiedContext.activeStreamingTools()) .isEqualTo(originalContext.activeStreamingTools()); assertThat(copiedContext.callbackContextData()) - .isSameInstanceAs(originalContext.callbackContextData()); + .isEqualTo(originalContext.callbackContextData()); } @Test public void testBuildWithCallbackContextData() { - Map data = new ConcurrentHashMap<>(); + ConcurrentHashMap data = new ConcurrentHashMap<>(); data.put("key", "value"); InvocationContext context = InvocationContext.builder() @@ -203,7 +196,6 @@ public void testBuildWithCallbackContextData() { .build(); assertThat(context.callbackContextData()).isEqualTo(data); - assertThat(context.callbackContextData()).isSameInstanceAs(data); } @Test @@ -443,7 +435,7 @@ public void testEquals_differentValues() { .userContent(userContent) .runConfig(runConfig) .endInvocation(false) - .callbackContextData(ImmutableMap.of("key", "value")) + .callbackContextData(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) .build(); assertThat(context.equals(contextWithDiffCallbackContextData)).isFalse(); } @@ -508,161 +500,11 @@ public void testHashCode_differentValues() { .userContent(userContent) .runConfig(runConfig) .endInvocation(false) - .callbackContextData(ImmutableMap.of("key", "value")) + .callbackContextData(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) .build(); assertThat(context.hashCode()).isNotEqualTo(contextWithDiffCallbackContextData.hashCode()); } - @Test - public void isResumable_whenResumabilityConfigIsNotResumable_isFalse() { - InvocationContext context = - InvocationContext.builder() - .sessionService(mockSessionService) - .artifactService(mockArtifactService) - .memoryService(mockMemoryService) - .agent(mockAgent) - .session(session) - .resumabilityConfig(new ResumabilityConfig(false)) - .build(); - assertThat(context.isResumable()).isFalse(); - } - - @Test - public void isResumable_whenResumabilityConfigIsResumable_isTrue() { - InvocationContext context = - InvocationContext.builder() - .sessionService(mockSessionService) - .artifactService(mockArtifactService) - .memoryService(mockMemoryService) - .agent(mockAgent) - .session(session) - .resumabilityConfig(new ResumabilityConfig(true)) - .build(); - assertThat(context.isResumable()).isTrue(); - } - - @Test - public void shouldPauseInvocation_whenNotResumable_isFalse() { - InvocationContext context = - InvocationContext.builder() - .sessionService(mockSessionService) - .artifactService(mockArtifactService) - .memoryService(mockMemoryService) - .agent(mockAgent) - .session(session) - .resumabilityConfig(new ResumabilityConfig(false)) - .build(); - Event event = - Event.builder() - .longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))) - .content( - Content.builder() - .parts( - ImmutableList.of( - Part.builder() - .functionCall( - FunctionCall.builder().name("tool1").id("fc1").build()) - .build())) - .build()) - .build(); - assertThat(context.shouldPauseInvocation(event)).isFalse(); - } - - @Test - public void shouldPauseInvocation_whenResumableAndNoLongRunningToolIds_isFalse() { - InvocationContext context = - InvocationContext.builder() - .sessionService(mockSessionService) - .artifactService(mockArtifactService) - .memoryService(mockMemoryService) - .agent(mockAgent) - .session(session) - .resumabilityConfig(new ResumabilityConfig(true)) - .build(); - Event event = - Event.builder() - .content( - Content.builder() - .parts( - ImmutableList.of( - Part.builder() - .functionCall( - FunctionCall.builder().name("tool1").id("fc1").build()) - .build())) - .build()) - .build(); - assertThat(context.shouldPauseInvocation(event)).isFalse(); - } - - @Test - public void shouldPauseInvocation_whenResumableAndNoFunctionCalls_isFalse() { - InvocationContext context = - InvocationContext.builder() - .sessionService(mockSessionService) - .artifactService(mockArtifactService) - .memoryService(mockMemoryService) - .agent(mockAgent) - .session(session) - .resumabilityConfig(new ResumabilityConfig(true)) - .build(); - Event event = Event.builder().longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))).build(); - assertThat(context.shouldPauseInvocation(event)).isFalse(); - } - - @Test - public void shouldPauseInvocation_whenResumableAndNoMatchingFunctionCallId_isFalse() { - InvocationContext context = - InvocationContext.builder() - .sessionService(mockSessionService) - .artifactService(mockArtifactService) - .memoryService(mockMemoryService) - .agent(mockAgent) - .session(session) - .resumabilityConfig(new ResumabilityConfig(true)) - .build(); - Event event = - Event.builder() - .longRunningToolIds(Optional.of(ImmutableSet.of("fc2"))) - .content( - Content.builder() - .parts( - ImmutableList.of( - Part.builder() - .functionCall( - FunctionCall.builder().name("tool1").id("fc1").build()) - .build())) - .build()) - .build(); - assertThat(context.shouldPauseInvocation(event)).isFalse(); - } - - @Test - public void shouldPauseInvocation_whenResumableAndMatchingFunctionCallId_isTrue() { - InvocationContext context = - InvocationContext.builder() - .sessionService(mockSessionService) - .artifactService(mockArtifactService) - .memoryService(mockMemoryService) - .agent(mockAgent) - .session(session) - .resumabilityConfig(new ResumabilityConfig(true)) - .build(); - Event event = - Event.builder() - .longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))) - .content( - Content.builder() - .parts( - ImmutableList.of( - Part.builder() - .functionCall( - FunctionCall.builder().name("tool1").id("fc1").build()) - .build())) - .build()) - .build(); - assertThat(context.shouldPauseInvocation(event)).isTrue(); - } - @Test public void incrementLlmCallsCount_whenLimitNotExceeded_doesNotThrow() throws Exception { InvocationContext context = @@ -729,26 +571,6 @@ public void testSessionGetters() { assertThat(context.userId()).isEqualTo("test-user"); } - @Test - public void testAgentStatesAndEndOfAgents() { - BaseAgentState mockState = mock(BaseAgentState.class); - ImmutableMap states = ImmutableMap.of("agent1", mockState); - ImmutableMap endOfAgents = ImmutableMap.of("agent1", true); - - InvocationContext context = - InvocationContext.builder() - .sessionService(mockSessionService) - .artifactService(mockArtifactService) - .agent(mockAgent) - .session(session) - .agentStates(states) - .endOfAgents(endOfAgents) - .build(); - - assertThat(context.agentStates()).isEqualTo(states); - assertThat(context.endOfAgents()).isEqualTo(endOfAgents); - } - @Test public void testSetEndInvocation() { InvocationContext context = @@ -936,58 +758,6 @@ public void testDeprecatedConstructor_11params() { assertThat(context.endInvocation()).isTrue(); } - @Test - public void populateAgentStates_populatesAgentStatesAndEndOfAgents() { - InvocationContext context = - InvocationContext.builder() - .sessionService(mockSessionService) - .artifactService(mockArtifactService) - .agent(mockAgent) - .session(session) - .invocationId(testInvocationId) - .build(); - - BaseAgentState agent1State = mock(BaseAgentState.class); - ConcurrentHashMap agent1StateMap = new ConcurrentHashMap<>(); - agent1StateMap.put("agent1", agent1State); - Event event1 = - Event.builder() - .invocationId(testInvocationId) - .author("agent1") - .actions(EventActions.builder().agentState(agent1StateMap).endOfAgent(true).build()) - .build(); - Event event2 = - Event.builder() - .invocationId("other-invocation-id") - .author("agent2") - .actions(EventActions.builder().endOfAgent(true).build()) - .build(); - Event event3 = - Event.builder() - .invocationId(testInvocationId) - .author("agent3") - .actions(EventActions.builder().endOfAgent(false).build()) - .build(); - BaseAgentState agent4State = mock(BaseAgentState.class); - ConcurrentHashMap agent4StateMap = new ConcurrentHashMap<>(); - agent4StateMap.put("agent4", agent4State); - Event event4 = - Event.builder() - .invocationId(testInvocationId) - .author("agent4") - .actions(EventActions.builder().agentState(agent4StateMap).endOfAgent(false).build()) - .build(); - Event event5 = Event.builder().invocationId(testInvocationId).author("agent5").build(); - - context.populateAgentStates(ImmutableList.of(event1, event2, event3, event4, event5)); - - assertThat(context.agentStates()).hasSize(2); - assertThat(context.agentStates()).containsEntry("agent1", agent1State); - assertThat(context.agentStates()).containsEntry("agent4", agent4State); - assertThat(context.endOfAgents()).hasSize(1); - assertThat(context.endOfAgents()).containsEntry("agent1", true); - } - @Test public void build_missingInvocationId_null_throwsException() { InvocationContext.Builder builder = diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 86b0a81ec..421b79abb 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -36,7 +36,6 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; -import com.google.adk.apps.ResumabilityConfig; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; import com.google.adk.models.LlmResponse; @@ -929,48 +928,6 @@ public void runLive_createsInvocationSpan() { assertThat(invocationSpan.get().hasEnded()).isTrue(); } - @Test - public void resumabilityConfig_isResumable_isTrueInInvocationContext() { - ArgumentCaptor contextCaptor = - ArgumentCaptor.forClass(InvocationContext.class); - when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty()); - Runner runner = - Runner.builder() - .app( - App.builder() - .name("test") - .rootAgent(agent) - .plugins(ImmutableList.of(plugin)) - .resumabilityConfig(new ResumabilityConfig(true)) - .build()) - .build(); - Session session = runner.sessionService().createSession("test", "user").blockingGet(); - var unused = - runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); - assertThat(contextCaptor.getValue().isResumable()).isTrue(); - } - - @Test - public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() { - ArgumentCaptor contextCaptor = - ArgumentCaptor.forClass(InvocationContext.class); - when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty()); - Runner runner = - Runner.builder() - .app( - App.builder() - .name("test") - .rootAgent(agent) - .plugins(ImmutableList.of(plugin)) - .resumabilityConfig(new ResumabilityConfig(false)) - .build()) - .build(); - Session session = runner.sessionService().createSession("test", "user").blockingGet(); - var unused = - runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); - assertThat(contextCaptor.getValue().isResumable()).isFalse(); - } - @Test public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() { RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); From 936471ef894ae8a8adef9a5c3a2ed2cd0f9713a1 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 12 Feb 2026 14:59:52 -0800 Subject: [PATCH 55/63] refactor: Creating a long term Runner.runAsyncImpl PiperOrigin-RevId: 869400120 --- .../java/com/google/adk/runner/Runner.java | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 0095d3fb6..3ff778011 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -381,7 +381,7 @@ public Flowable runAsync( new IllegalArgumentException( String.format("Session not found: %s for user %s", sessionId, userId))); })) - .flatMapPublisher(session -> this.runAsync(session, newMessage, runConfig, stateDelta)); + .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -415,6 +415,23 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { + return runAsyncImpl(session, newMessage, runConfig, stateDelta); + } + + /** + * Runs the agent asynchronously using a provided Session object. + * + * @param session The session to run the agent in. + * @param newMessage The new message from the user to process. + * @param runConfig Configuration for the agent run. + * @param stateDelta Optional map of state updates to merge into the session for this run. + * @return A Flowable stream of {@link Event} objects generated by the agent during execution. + */ + protected Flowable runAsyncImpl( + Session session, + Content newMessage, + RunConfig runConfig, + @Nullable Map stateDelta) { Span span = Tracing.getTracer().spanBuilder("invocation").setParent(Context.current()).startSpan(); Context spanContext = Context.current().with(span); @@ -688,6 +705,7 @@ public Flowable runLive( * * @return stream of generated events. */ + @Deprecated(since = "0.5.0", forRemoval = true) public Flowable runWithSessionId( String sessionId, Content newMessage, RunConfig runConfig) { // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. From 799646e1380839d93bb7053bfd12a1f90ca04878 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Fri, 13 Feb 2026 05:34:20 -0800 Subject: [PATCH 56/63] chore: update pom.xml to prepare it for releasing PiperOrigin-RevId: 869686128 --- pom.xml | 60 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/pom.xml b/pom.xml index 89f0d2c0f..01dbd4201 100644 --- a/pom.xml +++ b/pom.xml @@ -405,15 +405,6 @@ - - org.sonatype.central - central-publishing-maven-plugin - 0.8.0 - true - - central - - com.spotify.fmt fmt-maven-plugin @@ -471,6 +462,34 @@ + + release-sonatype + + + + org.sonatype.central + central-publishing-maven-plugin + 0.8.0 + true + + central + + + + + + + central + Maven Central Repository + https://central.sonatype.com/api/v1/publisher + + + central + Maven Central Repository Snapshots + https://central.sonatype.com/repository/maven-snapshots/ + + + release @@ -480,14 +499,7 @@ maven-gpg-plugin 3.2.7 - ${gpg.keyname} - ${gpg.passphrase} - - --batch - --yes - --pinentry-mode - loopback - + bc @@ -545,16 +557,4 @@ https://www.google.com - - - central - Maven Central Repository - https://central.sonatype.com/api/v1/publisher - - - central - Maven Central Repository Snapshots - https://central.sonatype.com/repository/maven-snapshots/ - - - \ No newline at end of file + From e1b9f5a298fdb14784c50de75966805d7c214161 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Fri, 13 Feb 2026 05:39:07 -0800 Subject: [PATCH 57/63] chore: configure release-please github action PiperOrigin-RevId: 869687403 --- .github/workflows/release-please.yaml | 17 +++++++++++++++++ README.md | 4 ++-- core/src/main/java/com/google/adk/Version.java | 2 +- 3 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/release-please.yaml diff --git a/.github/workflows/release-please.yaml b/.github/workflows/release-please.yaml new file mode 100644 index 000000000..6d3142907 --- /dev/null +++ b/.github/workflows/release-please.yaml @@ -0,0 +1,17 @@ +'on': + push: + branches: + - main + workflow_dispatch: {} +permissions: + contents: write + issues: write + pull-requests: write +name: release-please +jobs: + release-please: + runs-on: ubuntu-latest + steps: + - uses: googleapis/release-please-action@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/README.md b/README.md index 11a3fbd5c..bd39b181e 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ debugging, versioning, and deployment anywhere – from your laptop to the cloud If you're using Maven, add the following to your dependencies: - + ```xml @@ -60,7 +60,7 @@ If you're using Maven, add the following to your dependencies: ``` - + To instead use an unreleased version, you could use ; see for an example illustrating this. diff --git a/core/src/main/java/com/google/adk/Version.java b/core/src/main/java/com/google/adk/Version.java index fbc73039e..8b10341ac 100644 --- a/core/src/main/java/com/google/adk/Version.java +++ b/core/src/main/java/com/google/adk/Version.java @@ -22,7 +22,7 @@ */ public final class Version { // Don't touch this, release-please should keep it up to date. - public static final String JAVA_ADK_VERSION = "0.5.0"; + public static final String JAVA_ADK_VERSION = "0.5.0"; // x-release-please-released-version private Version() {} } From 62c1350a7ea3e7b664dbdf5375ccf6263e7e449d Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Fri, 13 Feb 2026 06:11:49 -0800 Subject: [PATCH 58/63] chore: add release-please-config.json PiperOrigin-RevId: 869697412 --- release-please-config.json | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 release-please-config.json diff --git a/release-please-config.json b/release-please-config.json new file mode 100644 index 000000000..6b3a1302c --- /dev/null +++ b/release-please-config.json @@ -0,0 +1,11 @@ +{ + "packages": { + ".": {} + }, + "include-component-in-tag": false, + "release-type": "maven", + "extra-files": [ + "core/src/main/java/com/google/adk/Version.java", + "README.md" + ] +} From e277ee9850365280e4742990ed9d9597277fa6f3 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Fri, 13 Feb 2026 06:21:08 -0800 Subject: [PATCH 59/63] chore: remove release-please bot configuration PiperOrigin-RevId: 869700933 --- .github/release-please.yml | 6 ------ .github/release-trigger.yml | 1 - 2 files changed, 7 deletions(-) delete mode 100644 .github/release-please.yml delete mode 100644 .github/release-trigger.yml diff --git a/.github/release-please.yml b/.github/release-please.yml deleted file mode 100644 index 168307d7f..000000000 --- a/.github/release-please.yml +++ /dev/null @@ -1,6 +0,0 @@ -releaseType: maven -handleGHRelease: true -bumpMinorPreMajor: true -extraFiles: - - core/src/main/java/com/google/adk/Version.java - - README.md \ No newline at end of file diff --git a/.github/release-trigger.yml b/.github/release-trigger.yml deleted file mode 100644 index 7fe362257..000000000 --- a/.github/release-trigger.yml +++ /dev/null @@ -1 +0,0 @@ -enabled: true \ No newline at end of file From bd1b82e39d8443135141282f7a30e442475f7482 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Fri, 13 Feb 2026 06:31:25 -0800 Subject: [PATCH 60/63] chore: add .release-please-manifest.json PiperOrigin-RevId: 869704217 --- .release-please-manifest.json | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .release-please-manifest.json diff --git a/.release-please-manifest.json b/.release-please-manifest.json new file mode 100644 index 000000000..600835dbf --- /dev/null +++ b/.release-please-manifest.json @@ -0,0 +1,3 @@ +{ + ".": "0.5.0" +} From 284aff6757a6f2f0a4ba015b7dd2c832dcc26d79 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 13 Feb 2026 07:03:11 -0800 Subject: [PATCH 61/63] refactor: Simplifiying Tracing code Refactor the tracing implementation within the Google ADK to simplify how OpenTelemetry spans are managed, especially within RxJava streams. The key changes include: 1. **Introducing `Tracing.TracerProvider`**: A new set of RxJava transformers (`FlowableTransformer`, `SingleTransformer`, `MaybeTransformer`, `CompletableTransformer`) is added in `Tracing.java`. These transformers, created via `Tracing.trace()` methods, handle the lifecycle of OpenTelemetry spans, including span creation, making the span current, and ending the span upon stream completion or error. 2. **Simplifying Tracing Calls**: Instead of manually creating and managing spans with `Tracer`, `Span`, and `Scope`, various parts of the codebase now use the `.compose(Tracing.trace(...))` operator on RxJava streams. This is applied in: * `BaseAgent.java`: For agent invocations using `Tracing.traceAgent`. * `BaseLlmFlow.java`: For LLM calls (`call_llm`) and sending data (`send_data`). * `Functions.java`: For tool responses and tool calls. * `Runner.java`: For overall invocation spans in `runAsync` and `runLive`. 3. **Centralized Attribute Setting**: Helper methods like `getValidCurrentSpan`, `setInvocationAttributes`, `setToolExecutionAttributes`, and `setJsonAttribute` are added to `Tracing.java` to encapsulate and standardize how attributes are set on spans, including handling JSON serialization and checks for valid spans. These changes aim to reduce tracing-related boilerplate, improve consistency, and make the tracing logic more robust by tying span lifetimes to RxJava stream lifecycles. PiperOrigin-RevId: 869714739 --- .../java/com/google/adk/agents/BaseAgent.java | 89 ++-- .../adk/flows/llmflows/BaseLlmFlow.java | 264 +++++------ .../google/adk/flows/llmflows/Functions.java | 197 ++++---- .../java/com/google/adk/runner/Runner.java | 203 ++++---- .../com/google/adk/telemetry/Tracing.java | 444 ++++++++++++------ .../adk/telemetry/ContextPropagationTest.java | 33 ++ 6 files changed, 657 insertions(+), 573 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 72fc5883a..226e61abe 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -29,9 +29,6 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import com.google.genai.types.Content; -import io.opentelemetry.api.trace.Span; -import io.opentelemetry.api.trace.Tracer; -import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -315,44 +312,37 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { - Tracer tracer = Tracing.getTracer(); return Flowable.defer( () -> { InvocationContext invocationContext = createInvocationContext(parentContext); - Span span = - tracer.spanBuilder("invoke_agent " + name()).setParent(Context.current()).startSpan(); - Tracing.traceAgentInvocation(span, name(), description(), invocationContext); - Context spanContext = Context.current().with(span); - - return Tracing.traceFlowable( - spanContext, - span, - () -> - callCallback( - beforeCallbacksToFunctions( - invocationContext.pluginManager(), beforeAgentCallback), - invocationContext) - .flatMapPublisher( - beforeEventOpt -> { - if (invocationContext.endInvocation()) { - return Flowable.fromOptional(beforeEventOpt); - } - - Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); - Flowable mainEvents = - Flowable.defer(() -> runImplementation.apply(invocationContext)); - Flowable afterEvents = - Flowable.defer( - () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), - afterAgentCallback), - invocationContext) - .flatMapPublisher(Flowable::fromOptional)); - - return Flowable.concat(beforeEvents, mainEvents, afterEvents); - })); + + return callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), beforeAgentCallback), + invocationContext) + .flatMapPublisher( + beforeEventOpt -> { + if (invocationContext.endInvocation()) { + return Flowable.fromOptional(beforeEventOpt); + } + + Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); + Flowable mainEvents = + Flowable.defer(() -> runImplementation.apply(invocationContext)); + Flowable afterEvents = + Flowable.defer( + () -> + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), afterAgentCallback), + invocationContext) + .flatMapPublisher(Flowable::fromOptional)); + + return Flowable.concat(beforeEvents, mainEvents, afterEvents); + }) + .compose( + Tracing.traceAgent( + "invoke_agent " + name(), name(), description(), invocationContext)); }); } @@ -364,11 +354,8 @@ private Flowable run( */ private ImmutableList>> beforeCallbacksToFunctions( Plugin pluginManager, List callbacks) { - return Stream.concat( - Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)), - callbacks.stream() - .map(callback -> (Function>) callback::call)) - .collect(toImmutableList()); + return callbacksToFunctions( + ctx -> pluginManager.beforeAgentCallback(this, ctx), callbacks, c -> c::call); } /** @@ -379,10 +366,15 @@ private ImmutableList>> beforeCallbacks */ private ImmutableList>> afterCallbacksToFunctions( Plugin pluginManager, List callbacks) { - return Stream.concat( - Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)), - callbacks.stream() - .map(callback -> (Function>) callback::call)) + return callbacksToFunctions( + ctx -> pluginManager.afterAgentCallback(this, ctx), callbacks, c -> c::call); + } + + private ImmutableList>> callbacksToFunctions( + Function> pluginCallback, + List callbacks, + Function>> mapper) { + return Stream.concat(Stream.of(pluginCallback), callbacks.stream().map(mapper)) .collect(toImmutableList()); } @@ -523,8 +515,7 @@ public B subAgents(List subAgents) { @CanIgnoreReturnValue public B subAgents(BaseAgent... subAgents) { - this.subAgents = ImmutableList.copyOf(subAgents); - return self(); + return subAgents(ImmutableList.copyOf(subAgents)); } @CanIgnoreReturnValue diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fd383baf1..6ca49ee62 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -164,10 +164,7 @@ protected Flowable postprocess( * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ private Flowable callLlm( - InvocationContext context, - LlmRequest llmRequest, - Event eventForCallbackUsage, - Context parentTracingContext) { + InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { LlmAgent agent = (LlmAgent) context.agent(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); @@ -182,45 +179,29 @@ private Flowable callLlm( agent.resolvedModel().model().isPresent() ? agent.resolvedModel().model().get() : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return Flowable.defer( - () -> { - Span llmCallSpan = - Tracing.getTracer() - .spanBuilder("call_llm") - .setParent(parentTracingContext) - .startSpan(); - - try (Scope scope = llmCallSpan.makeCurrent()) { - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, - llmRequestBuilder, - eventForCallbackUsage, - exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnNext( - llmResp -> { - try (Scope innerScope = llmCallSpan.makeCurrent()) { - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp); - } - }) - .doOnError( - error -> { - llmCallSpan.setStatus(StatusCode.ERROR, error.getMessage()); - llmCallSpan.recordException(error); - }) - .doFinally(llmCallSpan::end); - } + return llm.generateContent( + llmRequestBuilder.build(), + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, llmRequestBuilder, eventForCallbackUsage, exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnNext( + llmResp -> + Tracing.traceCallLlm( + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp)) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); }) + .compose(Tracing.trace("call_llm")) .concatMap( llmResp -> handleAfterModelCallback(context, llmResp, eventForCallbackUsage) @@ -343,79 +324,79 @@ private Single handleAfterModelCallback( * @throws IllegalStateException if a transfer agent is specified but not found. */ private Flowable runOneStep(InvocationContext context) { - Context parentContext = Context.current(); AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); - Flowable preprocessEvents = preprocess(context, llmRequestRef); - return preprocessEvents.concatWith( - Flowable.defer( - () -> { - LlmRequest llmRequestAfterPreprocess = llmRequestRef.get(); - if (context.endInvocation()) { - logger.debug("End invocation requested during preprocessing."); - return Flowable.empty(); - } - - try { - context.incrementLlmCallsCount(); - } catch (LlmCallsLimitExceededException e) { - logger.error("LLM calls limit exceeded.", e); - return Flowable.error(e); - } - - final Event mutableEventTemplate = - Event.builder() - .id(Event.generateEventId()) - .invocationId(context.invocationId()) - .author(context.agent().name()) - .branch(context.branch()) - .build(); - // Explicitly set the event timestamp to 0 so the postprocessing logic would generate - // events with fresh timestamp. - mutableEventTemplate.setTimestamp(0L); - - return callLlm( - context, llmRequestAfterPreprocess, mutableEventTemplate, parentContext) - .concatMap( - llmResponse -> { - try (Scope scope = parentContext.makeCurrent()) { - return postprocess( - context, - mutableEventTemplate, - llmRequestAfterPreprocess, - llmResponse) - .doFinally( - () -> { - String oldId = mutableEventTemplate.id(); - mutableEventTemplate.setId(Event.generateEventId()); - logger.debug( - "Updated mutableEventTemplate ID from {} to {} for" - + " next LlmResponse", - oldId, - mutableEventTemplate.id()); - }); + return Flowable.defer( + () -> { + Context currentContext = Context.current(); + return preprocess(context, llmRequestRef) + .concatWith( + Flowable.defer( + () -> { + LlmRequest llmRequestAfterPreprocess = llmRequestRef.get(); + if (context.endInvocation()) { + logger.debug("End invocation requested during preprocessing."); + return Flowable.empty(); } - }) - .concatMap( - event -> { - Flowable postProcessedEvents = Flowable.just(event); - if (event.actions().transferToAgent().isPresent()) { - String agentToTransfer = event.actions().transferToAgent().get(); - logger.debug("Transferring to agent: {}", agentToTransfer); - BaseAgent rootAgent = context.agent().rootAgent(); - Optional nextAgent = rootAgent.findAgent(agentToTransfer); - if (nextAgent.isEmpty()) { - String errorMsg = "Agent not found for transfer: " + agentToTransfer; - logger.error(errorMsg); - return postProcessedEvents.concatWith( - Flowable.error(new IllegalStateException(errorMsg))); - } - return postProcessedEvents.concatWith( - Flowable.defer(() -> nextAgent.get().runAsync(context))); + + try { + context.incrementLlmCallsCount(); + } catch (LlmCallsLimitExceededException e) { + logger.error("LLM calls limit exceeded.", e); + return Flowable.error(e); } - return postProcessedEvents; - }); - })); + + final Event mutableEventTemplate = + Event.builder() + .id(Event.generateEventId()) + .invocationId(context.invocationId()) + .author(context.agent().name()) + .branch(context.branch()) + .build(); + mutableEventTemplate.setTimestamp(0L); + + return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) + .concatMap( + llmResponse -> { + try (Scope postScope = currentContext.makeCurrent()) { + return postprocess( + context, + mutableEventTemplate, + llmRequestAfterPreprocess, + llmResponse) + .doFinally( + () -> { + String oldId = mutableEventTemplate.id(); + String newId = Event.generateEventId(); + logger.debug( + "Resetting event ID from {} to {}", oldId, newId); + mutableEventTemplate.setId(newId); + }); + } + }) + .concatMap( + event -> { + Flowable postProcessedEvents = Flowable.just(event); + if (event.actions().transferToAgent().isPresent()) { + String agentToTransfer = + event.actions().transferToAgent().get(); + BaseAgent rootAgent = context.agent().rootAgent(); + Optional nextAgent = + rootAgent.findAgent(agentToTransfer); + if (nextAgent.isEmpty()) { + logger.error("Agent not found: {}", agentToTransfer); + return postProcessedEvents.concatWith( + Flowable.error( + new IllegalStateException( + "Agent not found: " + agentToTransfer))); + } + return postProcessedEvents.concatWith( + Flowable.defer(() -> nextAgent.get().runAsync(context))); + } + return postProcessedEvents; + }); + })); + }); } /** @@ -436,7 +417,6 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return currentStepEvents; } - Context parentContext = Context.current(); return currentStepEvents.concatWith( currentStepEvents .toList() @@ -451,12 +431,7 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return Flowable.defer( - () -> { - try (Scope scope = parentContext.makeCurrent()) { - return run(invocationContext, stepsCompleted + 1); - } - }); + return run(invocationContext, stepsCompleted + 1); } })); } @@ -492,40 +467,25 @@ public Flowable runLive(InvocationContext invocationContext) { Completable historySent = llmRequestAfterPreprocess.contents().isEmpty() ? Completable.complete() - : Completable.defer( - () -> { - Span sendDataSpan = - Tracing.getTracer() - .spanBuilder("send_data") - .setParent(Context.current()) - .startSpan(); - try (Scope scope = sendDataSpan.makeCurrent()) { - return connection - .sendHistory(llmRequestAfterPreprocess.contents()) - .doOnComplete( - () -> { - try (Scope innerScope = sendDataSpan.makeCurrent()) { - Tracing.traceSendData( - invocationContext, - eventIdForSendData, - llmRequestAfterPreprocess.contents()); - } - }) - .doOnError( - error -> { - sendDataSpan.setStatus( - StatusCode.ERROR, error.getMessage()); - sendDataSpan.recordException(error); - try (Scope innerScope = sendDataSpan.makeCurrent()) { - Tracing.traceSendData( - invocationContext, - eventIdForSendData, - llmRequestAfterPreprocess.contents()); - } - }) - .doFinally(sendDataSpan::end); - } - }); + : connection + .sendHistory(llmRequestAfterPreprocess.contents()) + .doOnComplete( + () -> + Tracing.traceSendData( + invocationContext, + eventIdForSendData, + llmRequestAfterPreprocess.contents())) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + Tracing.traceSendData( + invocationContext, + eventIdForSendData, + llmRequestAfterPreprocess.contents()); + }) + .compose(Tracing.trace("send_data")); Flowable liveRequests = invocationContext diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 26f14d24b..82813defa 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -3,7 +3,6 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 @@ -42,7 +41,6 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; -import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; @@ -150,8 +148,9 @@ public static Maybe handleFunctionCalls( } } + Context parentContext = Context.current(); Function> functionCallMapper = - getFunctionCallMapper(invocationContext, tools, toolConfirmations, false); + getFunctionCallMapper(invocationContext, tools, toolConfirmations, false, parentContext); Observable functionResponseEventsObservable; if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { @@ -177,14 +176,9 @@ public static Maybe handleFunctionCalls( var mergedEvent = maybeMergedEvent.get(); if (events.size() > 1) { - Tracer tracer = Tracing.getTracer(); - Span mergedSpan = - tracer.spanBuilder("tool_response").setParent(Context.current()).startSpan(); - try (Scope scope = mergedSpan.makeCurrent()) { - Tracing.traceToolResponse(mergedEvent.id(), mergedEvent); - } finally { - mergedSpan.end(); - } + return Maybe.just(mergedEvent) + .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)) + .compose(Tracing.trace("tool_response", parentContext)); } return Maybe.just(mergedEvent); }); @@ -216,8 +210,9 @@ public static Maybe handleFunctionCallsLive( } } + Context parentContext = Context.current(); Function> functionCallMapper = - getFunctionCallMapper(invocationContext, tools, toolConfirmations, true); + getFunctionCallMapper(invocationContext, tools, toolConfirmations, true, parentContext); Observable responseEventsObservable; if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { @@ -244,8 +239,8 @@ private static Function> getFunctionCallMapper( InvocationContext invocationContext, Map tools, Map toolConfirmations, - boolean isLive) { - Context parentContext = Context.current(); + boolean isLive, + Context parentContext) { return functionCall -> Maybe.defer( () -> { @@ -272,13 +267,20 @@ private static Function> getFunctionCallMapper( tool, toolContext, functionCall, - functionArgs) - : callTool(tool, functionArgs, toolContext); + functionArgs, + parentContext) + : callTool(tool, functionArgs, toolContext, parentContext); } })); return postProcessFunctionResult( - maybeFunctionResult, invocationContext, tool, functionArgs, toolContext); + maybeFunctionResult, + invocationContext, + tool, + functionArgs, + toolContext, + isLive, + parentContext); } }); } @@ -292,7 +294,8 @@ private static Maybe> processFunctionLive( BaseTool tool, ToolContext toolContext, FunctionCall functionCall, - Map args) { + Map args, + Context parentContext) { // Case 1: Handle a call to stopStreaming if (functionCall.name().get().equals("stopStreaming") && args.containsKey("functionName")) { String functionNameToStop = (String) args.get("functionName"); @@ -360,7 +363,7 @@ private static Maybe> processFunctionLive( } // Case 3: Fallback for regular, non-streaming tools - return callTool(tool, args, toolContext); + return callTool(tool, args, toolContext, parentContext); } public static Set getLongRunningFunctionCalls( @@ -383,47 +386,54 @@ private static Maybe postProcessFunctionResult( InvocationContext invocationContext, BaseTool tool, Map functionArgs, - ToolContext toolContext) { - Context parentContext = Context.current(); + ToolContext toolContext, + boolean isLive, + Context parentContext) { return maybeFunctionResult .map(Optional::of) .defaultIfEmpty(Optional.empty()) .onErrorResumeNext( - t -> - Maybe.defer( - () -> { - try (Scope scope = parentContext.makeCurrent()) { - return handleOnToolErrorCallback( - invocationContext, tool, functionArgs, toolContext, t); - } - }) - .map(Optional::ofNullable) - .switchIfEmpty(Single.error(t))) + t -> { + Maybe> errorCallbackResult = + handleOnToolErrorCallback(invocationContext, tool, functionArgs, toolContext, t); + Maybe>> mappedResult; + if (isLive) { + // In live mode, handle null results from the error callback gracefully. + mappedResult = errorCallbackResult.map(Optional::ofNullable); + } else { + // In non-live mode, a null result from the error callback will cause an NPE + // when wrapped with Optional.of(), potentially matching prior behavior. + mappedResult = errorCallbackResult.map(Optional::of); + } + return mappedResult.switchIfEmpty(Single.error(t)); + }) .flatMapMaybe( optionalInitialResult -> { try (Scope scope = parentContext.makeCurrent()) { Map initialFunctionResult = optionalInitialResult.orElse(null); - Maybe> afterToolResultMaybe = - maybeInvokeAfterToolCall( - invocationContext, tool, functionArgs, toolContext, initialFunctionResult); - - return afterToolResultMaybe + return maybeInvokeAfterToolCall( + invocationContext, tool, functionArgs, toolContext, initialFunctionResult) .map(Optional::of) .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) .flatMapMaybe( finalOptionalResult -> { - try (Scope innerScope = parentContext.makeCurrent()) { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - Event functionResponseEvent = - buildResponseEvent( - tool, finalFunctionResult, toolContext, invocationContext); - return Maybe.just(functionResponseEvent); + Map finalFunctionResult = + finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); } + return Maybe.fromCallable( + () -> + buildResponseEvent( + tool, + finalFunctionResult, + toolContext, + invocationContext)) + .compose( + Tracing.trace( + "tool_response [" + tool.name() + "]", parentContext)) + .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); }); } }); @@ -579,29 +589,21 @@ private static Maybe> maybeInvokeAfterToolCall( } private static Maybe> callTool( - BaseTool tool, Map args, ToolContext toolContext) { - Tracer tracer = Tracing.getTracer(); - Context parentContext = Context.current(); - return Maybe.defer( - () -> { - Span span = - tracer - .spanBuilder("tool_call [" + tool.name() + "]") - .setParent(parentContext) - .startSpan(); - try (Scope scope = span.makeCurrent()) { - Tracing.traceToolCall( - tool.name(), tool.description(), tool.getClass().getSimpleName(), args); - return tool.runAsync(args, toolContext) - .toMaybe() - .doOnError(span::recordException) - .doFinally(span::end); - } catch (RuntimeException e) { - span.recordException(e); - span.end(); - return Maybe.error(new RuntimeException("Failed to call tool: " + tool.name(), e)); - } - }); + BaseTool tool, Map args, ToolContext toolContext, Context parentContext) { + return tool.runAsync(args, toolContext) + .toMaybe() + .doOnSubscribe( + d -> + Tracing.traceToolCall( + tool.name(), tool.description(), tool.getClass().getSimpleName(), args)) + .doOnError(t -> Span.current().recordException(t)) + .compose(Tracing.trace("tool_call [" + tool.name() + "]", parentContext)) + .onErrorResumeNext( + e -> + Maybe.error( + e instanceof RuntimeException runtimeException + ? runtimeException + : new RuntimeException("Failed to call tool: " + tool.name(), e))); } private static Event buildResponseEvent( @@ -609,42 +611,27 @@ private static Event buildResponseEvent( Map response, ToolContext toolContext, InvocationContext invocationContext) { - Tracer tracer = Tracing.getTracer(); - Span span = - tracer - .spanBuilder("tool_response [" + tool.name() + "]") - .setParent(Context.current()) - .startSpan(); - try (Scope scope = span.makeCurrent()) { - // use a empty placeholder response if tool response is null. - if (response == null) { - response = new HashMap<>(); - } - - Part partFunctionResponse = - Part.builder() - .functionResponse( - FunctionResponse.builder() - .id(toolContext.functionCallId().orElse("")) - .name(tool.name()) - .response(response) - .build()) - .build(); - - Event event = - Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(invocationContext.agent().name()) - .branch(invocationContext.branch()) - .content(Content.builder().role("user").parts(partFunctionResponse).build()) - .actions(toolContext.eventActions()) - .build(); - Tracing.traceToolResponse(event.id(), event); - return event; - } finally { - span.end(); - } + // use an empty placeholder response if tool response is null. + Map finalResponse = response != null ? response : new HashMap<>(); + + Part partFunctionResponse = + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(toolContext.functionCallId().orElse("")) + .name(tool.name()) + .response(finalResponse) + .build()) + .build(); + + return Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(invocationContext.agent().name()) + .branch(invocationContext.branch()) + .content(Content.builder().role("user").parts(partFunctionResponse).build()) + .actions(toolContext.eventActions()) + .build(); } /** diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 3ff778011..2f8201ba1 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -51,7 +51,6 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; -import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -432,75 +431,60 @@ protected Flowable runAsyncImpl( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - Span span = - Tracing.getTracer().spanBuilder("invocation").setParent(Context.current()).startSpan(); - Context spanContext = Context.current().with(span); - - try { - BaseAgent rootAgent = this.agent; - String invocationId = InvocationContext.newInvocationContextId(); - - // Create initial context - InvocationContext initialContext = - newInvocationContextBuilder(session) - .invocationId(invocationId) - .runConfig(runConfig) - .userContent(newMessage) - .build(); - - return Tracing.traceFlowable( - spanContext, - span, - () -> - Flowable.defer( - () -> - this.pluginManager - .onUserMessageCallback(initialContext, newMessage) - .defaultIfEmpty(newMessage) - .flatMap( - content -> - (content != null) - ? appendNewMessageToSession( - session, - content, - initialContext, - runConfig.saveInputBlobsAsArtifacts(), - stateDelta) - : Single.just(null)) - .flatMapPublisher( - event -> { - if (event == null) { - return Flowable.empty(); - } - // Get the updated session after the message and state delta are - // applied - return this.sessionService - .getSession( - session.appName(), - session.userId(), - session.id(), - Optional.empty()) - .flatMapPublisher( - updatedSession -> - runAgentWithFreshSession( - session, - updatedSession, - event, - invocationId, - runConfig, - rootAgent)); - })) - .doOnError( - throwable -> { - span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); - span.recordException(throwable); - })); - } catch (Throwable t) { - span.setStatus(StatusCode.ERROR, "Error during runAsync synchronous setup"); - span.recordException(t); - span.end(); - return Flowable.error(t); - } + return Flowable.defer( + () -> { + BaseAgent rootAgent = this.agent; + String invocationId = InvocationContext.newInvocationContextId(); + + // Create initial context + InvocationContext initialContext = + newInvocationContextBuilder(session) + .invocationId(invocationId) + .runConfig(runConfig) + .userContent(newMessage) + .build(); + + return this.pluginManager + .onUserMessageCallback(initialContext, newMessage) + .defaultIfEmpty(newMessage) + .flatMap( + content -> + (content != null) + ? appendNewMessageToSession( + session, + content, + initialContext, + runConfig.saveInputBlobsAsArtifacts(), + stateDelta) + : Single.just(null)) + .flatMapPublisher( + event -> { + if (event == null) { + return Flowable.empty(); + } + // Get the updated session after the message and state delta are + // applied + return this.sessionService + .getSession( + session.appName(), session.userId(), session.id(), Optional.empty()) + .flatMapPublisher( + updatedSession -> + runAgentWithFreshSession( + session, + updatedSession, + event, + invocationId, + runConfig, + rootAgent)); + }); + }) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); + span.recordException(throwable); + }) + .compose(Tracing.trace("invocation")); } private Flowable runAgentWithFreshSession( @@ -629,52 +613,39 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - Span span = - Tracing.getTracer().spanBuilder("invocation").setParent(Context.current()).startSpan(); - Context spanContext = Context.current().with(span); - - try { - InvocationContext invocationContext = - newInvocationContextForLive(session, Optional.of(liveRequestQueue), runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - - return invocationContextSingle.flatMapPublisher( - updatedInvocationContext -> - Tracing.traceFlowable( - spanContext, - span, - () -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .doOnNext(event -> this.sessionService.appendEvent(session, event)) - .onErrorResumeNext( - throwable -> { - span.setStatus( - StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - span.end(); - return Flowable.error(throwable); - }))); - } catch (Throwable t) { - span.setStatus(StatusCode.ERROR, "Error during runLive synchronous setup"); - span.recordException(t); - span.end(); - return Flowable.error(t); - } + return Flowable.defer( + () -> { + InvocationContext invocationContext = + newInvocationContextForLive(session, Optional.of(liveRequestQueue), runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }); + }) + .compose(Tracing.trace("invocation")); } /** diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 36c6e3e58..07a640c37 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -19,7 +19,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; import com.google.adk.JsonBaseModel; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; @@ -27,6 +26,7 @@ import com.google.adk.models.LlmResponse; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; @@ -36,14 +36,27 @@ import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.CompletableSource; +import io.reactivex.rxjava3.core.CompletableTransformer; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.FlowableTransformer; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.MaybeSource; +import io.reactivex.rxjava3.core.MaybeTransformer; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.core.SingleSource; +import io.reactivex.rxjava3.core.SingleTransformer; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.Optional; +import java.util.function.Consumer; import java.util.function.Supplier; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -105,9 +118,6 @@ public class Tracing { private static final AttributeKey ADK_DATA = AttributeKey.stringKey("gcp.vertex.agent.data"); - private static final TypeReference> MAP_TYPE_REFERENCE = - new TypeReference>() {}; - @SuppressWarnings("NonFinalStaticField") private static Tracer tracer = GlobalOpenTelemetry.getTracer("gcp.vertex.agent"); @@ -117,6 +127,54 @@ public class Tracing { private Tracing() {} + private static Optional getValidCurrentSpan(String methodName) { + Span span = Span.current(); + if (!span.getSpanContext().isValid()) { + log.trace("{}: No valid span in current context.", methodName); + return Optional.empty(); + } + return Optional.of(span); + } + + private static void setInvocationAttributes( + Span span, InvocationContext invocationContext, String eventId) { + span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); + if (eventId != null && !eventId.isEmpty()) { + span.setAttribute(ADK_EVENT_ID, eventId); + } + + if (invocationContext.session() != null && invocationContext.session().id() != null) { + span.setAttribute(ADK_SESSION_ID, invocationContext.session().id()); + } else { + log.trace( + "InvocationContext session or session ID is null, cannot set {}", + ADK_SESSION_ID.getKey()); + } + } + + private static void setToolExecutionAttributes(Span span) { + span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); + } + + private static void setJsonAttribute(Span span, AttributeKey key, Object value) { + if (!CAPTURE_MESSAGE_CONTENT_IN_SPANS) { + span.setAttribute(key, "{}"); + return; + } + try { + String json = + (value instanceof String stringValue) + ? stringValue + : JsonBaseModel.getMapper().writeValueAsString(value); + span.setAttribute(key, json); + } catch (JsonProcessingException | RuntimeException e) { + log.warn("Failed to serialize {} to JSON", key.getKey(), e); + span.setAttribute(key, "{\"error\": \"serialization failed\"}"); + } + } + /** Sets the OpenTelemetry instance to be used for tracing. This is for testing purposes only. */ public static void setTracerForTesting(Tracer tracer) { Tracing.tracer = tracer; @@ -148,27 +206,16 @@ public static void traceAgentInvocation( */ public static void traceToolCall( String toolName, String toolDescription, String toolType, Map args) { - Span span = Span.current(); - if (span == null || !span.getSpanContext().isValid()) { - log.trace("traceToolCall: No valid span in current context."); - return; - } + getValidCurrentSpan("traceToolCall") + .ifPresent( + span -> { + setToolExecutionAttributes(span); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); - span.setAttribute(GEN_AI_TOOL_NAME, toolName); - span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); - span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { - try { - span.setAttribute(ADK_TOOL_CALL_ARGS, JsonBaseModel.getMapper().writeValueAsString(args)); - } catch (JsonProcessingException e) { - log.warn("traceToolCall: Failed to serialize tool call args to JSON", e); - } - } else { - span.setAttribute(ADK_TOOL_CALL_ARGS, "{}"); - } - span.setAttribute(ADK_LLM_REQUEST, "{}"); - span.setAttribute(ADK_LLM_RESPONSE, "{}"); + setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); + }); } /** @@ -178,49 +225,33 @@ public static void traceToolCall( * @param functionResponseEvent The function response event. */ public static void traceToolResponse(String eventId, Event functionResponseEvent) { - Span span = Span.current(); - if (span == null || !span.getSpanContext().isValid()) { - log.trace("traceToolResponse: No valid span in current context."); - return; - } - - span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); - span.setAttribute(ADK_EVENT_ID, eventId); - - String toolCallId = ""; - Object toolResponse = ""; - - Optional optionalFunctionResponse = - functionResponseEvent.functionResponses().stream().findFirst(); - - if (optionalFunctionResponse.isPresent()) { - FunctionResponse functionResponse = optionalFunctionResponse.get(); - toolCallId = functionResponse.id().orElse(toolCallId); - if (functionResponse.response().isPresent()) { - toolResponse = functionResponse.response().get(); - } - } - span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); - - if (!(toolResponse instanceof Map)) { - toolResponse = ImmutableMap.of("result", toolResponse); - } - - if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { - try { - span.setAttribute( - ADK_TOOL_RESPONSE, JsonBaseModel.getMapper().writeValueAsString(toolResponse)); - } catch (JsonProcessingException e) { - log.warn("traceToolResponse: Failed to serialize tool response to JSON", e); - span.setAttribute(ADK_TOOL_RESPONSE, "{\"error\": \"serialization failed\"}"); - } - } else { - span.setAttribute(ADK_TOOL_RESPONSE, "{}"); - } - - // Setting empty llm request and response (as the AdkDevServer UI expects these) - span.setAttribute(ADK_LLM_REQUEST, "{}"); - span.setAttribute(ADK_LLM_RESPONSE, "{}"); + getValidCurrentSpan("traceToolResponse") + .ifPresent( + span -> { + setToolExecutionAttributes(span); + span.setAttribute(ADK_EVENT_ID, eventId); + + FunctionResponse functionResponse = + functionResponseEvent.functionResponses().stream().findFirst().orElse(null); + + String toolCallId = ""; + Object toolResponse = ""; + if (functionResponse != null) { + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + + Object finalToolResponse = + (toolResponse instanceof Map) + ? toolResponse + : ImmutableMap.of("result", toolResponse); + + setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); + }); } /** @@ -265,68 +296,58 @@ public static void traceCallLlm( String eventId, LlmRequest llmRequest, LlmResponse llmResponse) { - Span span = Span.current(); - if (span == null || !span.getSpanContext().isValid()) { - log.trace("traceCallLlm: No valid span in current context."); - return; - } + getValidCurrentSpan("traceCallLlm") + .ifPresent( + span -> { + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest + .model() + .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); - llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); - span.setAttribute(ADK_EVENT_ID, eventId); + setInvocationAttributes(span, invocationContext, eventId); - if (invocationContext.session() != null && invocationContext.session().id() != null) { - span.setAttribute(ADK_SESSION_ID, invocationContext.session().id()); - } else { - log.trace( - "traceCallLlm: InvocationContext session or session ID is null, cannot set" - + " gcp.vertex.agent.session_id"); - } + setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); + setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); - if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { - try { - span.setAttribute( - ADK_LLM_REQUEST, - JsonBaseModel.getMapper().writeValueAsString(buildLlmRequestForTrace(llmRequest))); - span.setAttribute(ADK_LLM_RESPONSE, llmResponse.toJson()); - } catch (JsonProcessingException e) { - log.warn("traceCallLlm: Failed to serialize LlmRequest or LlmResponse to JSON", e); - } - } else { - span.setAttribute(ADK_LLM_REQUEST, "{}"); - span.setAttribute(ADK_LLM_RESPONSE, "{}"); - } - llmRequest - .config() - .ifPresent( - config -> { - config - .topP() - .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); - config - .maxOutputTokens() + llmRequest + .config() .ifPresent( - maxTokens -> - span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); - }); - llmResponse - .usageMetadata() - .ifPresent( - usage -> { - usage - .promptTokenCount() - .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); - usage - .candidatesTokenCount() + config -> { + config + .topP() + .ifPresent( + topP -> + span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() + .ifPresent( + maxTokens -> + span.setAttribute( + GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() .ifPresent( - tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + usage -> { + usage + .promptTokenCount() + .ifPresent( + tokens -> + span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() + .ifPresent( + tokens -> + span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + .ifPresent( + reason -> + span.setAttribute( + GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); }); - llmResponse - .finishReason() - .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) - .ifPresent( - reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); } /** @@ -338,34 +359,17 @@ public static void traceCallLlm( */ public static void traceSendData( InvocationContext invocationContext, String eventId, List data) { - Span span = Span.current(); - if (span == null || !span.getSpanContext().isValid()) { - log.trace("traceSendData: No valid span in current context."); - return; - } - - span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); - if (eventId != null && !eventId.isEmpty()) { - span.setAttribute(ADK_EVENT_ID, eventId); - } - - if (invocationContext.session() != null && invocationContext.session().id() != null) { - span.setAttribute(ADK_SESSION_ID, invocationContext.session().id()); - } - if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { - try { - ImmutableList> dataList = - Optional.ofNullable(data).orElse(ImmutableList.of()).stream() - .filter(content -> content != null) - .map(content -> JsonBaseModel.getMapper().convertValue(content, MAP_TYPE_REFERENCE)) - .collect(toImmutableList()); - span.setAttribute(ADK_DATA, JsonBaseModel.toJsonString(dataList)); - } catch (IllegalStateException e) { - log.warn("traceSendData: Failed to serialize data to JSON", e); - } - } else { - span.setAttribute(ADK_DATA, "{}"); - } + getValidCurrentSpan("traceSendData") + .ifPresent( + span -> { + setInvocationAttributes(span, invocationContext, eventId); + + ImmutableList safeData = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + setJsonAttribute(span, ADK_DATA, safeData); + }); } /** @@ -410,4 +414,142 @@ public static Flowable traceFlowable( span.end(); }); } + + /** + * Returns a transformer that traces the execution of an RxJava stream. + * + * @param spanName The name of the span to create. + * @param The type of the stream. + * @return A TracerProvider that can be used with .compose(). + */ + public static TracerProvider trace(String spanName) { + return new TracerProvider<>(spanName); + } + + /** + * Returns a transformer that traces the execution of an RxJava stream with an explicit parent + * context. + * + * @param spanName The name of the span to create. + * @param parentContext The explicit parent context for the span. + * @param The type of the stream. + * @return A TracerProvider that can be used with .compose(). + */ + public static TracerProvider trace(String spanName, Context parentContext) { + return new TracerProvider(spanName).setParent(parentContext); + } + + /** + * Returns a transformer that traces an agent invocation. + * + * @param spanName The name of the span to create. + * @param agentName The name of the agent. + * @param agentDescription The description of the agent. + * @param invocationContext The invocation context. + * @param The type of the stream. + * @return A TracerProvider configured for agent invocation. + */ + public static TracerProvider traceAgent( + String spanName, + String agentName, + String agentDescription, + InvocationContext invocationContext) { + return new TracerProvider(spanName) + .configure( + span -> traceAgentInvocation(span, agentName, agentDescription, invocationContext)); + } + + /** + * A transformer that manages an OpenTelemetry span and scope for RxJava streams. + * + * @param The type of the stream. + */ + public static final class TracerProvider + implements FlowableTransformer, + SingleTransformer, + MaybeTransformer, + CompletableTransformer { + private final String spanName; + private Context explicitParentContext; + private final List> spanConfigurers = new ArrayList<>(); + + private TracerProvider(String spanName) { + this.spanName = spanName; + } + + /** Configures the span created by this transformer. */ + @CanIgnoreReturnValue + public TracerProvider configure(Consumer configurer) { + spanConfigurers.add(configurer); + return this; + } + + /** Sets an explicit parent context for the span created by this transformer. */ + @CanIgnoreReturnValue + public TracerProvider setParent(Context parentContext) { + this.explicitParentContext = parentContext; + return this; + } + + private Context getParentContext() { + return explicitParentContext != null ? explicitParentContext : Context.current(); + } + + private final class TracingLifecycle { + private Span span; + private Scope scope; + + @SuppressWarnings("MustBeClosedChecker") + void start() { + span = tracer.spanBuilder(spanName).setParent(getParentContext()).startSpan(); + spanConfigurers.forEach(c -> c.accept(span)); + scope = span.makeCurrent(); + } + + void end() { + if (scope != null) { + scope.close(); + } + if (span != null) { + span.end(); + } + } + } + + @Override + public Publisher apply(Flowable upstream) { + return Flowable.defer( + () -> { + TracingLifecycle lifecycle = new TracingLifecycle(); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + }); + } + + @Override + public SingleSource apply(Single upstream) { + return Single.defer( + () -> { + TracingLifecycle lifecycle = new TracingLifecycle(); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + }); + } + + @Override + public MaybeSource apply(Maybe upstream) { + return Maybe.defer( + () -> { + TracingLifecycle lifecycle = new TracingLifecycle(); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + }); + } + + @Override + public CompletableSource apply(Completable upstream) { + return Completable.defer( + () -> { + TracingLifecycle lifecycle = new TracingLifecycle(); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + }); + } + } } diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index ece1bdad1..9439fe718 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -184,6 +184,11 @@ public void testNestedSpanHierarchy() { } // Verify complete hierarchy + List spans = openTelemetryRule.getSpans(); + // The 4 spans are: "parent", "invocation", "tool_call [testTool]", and "tool_response + // [testTool]". + assertEquals("Should have 4 spans in the hierarchy", 4, spans.size()); + SpanData parentSpanData = findSpanByName("parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); @@ -311,6 +316,9 @@ public void testCallLlmSpanLinksToAgentRun() { invokeAgentSpan.end(); } + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 2 spans", 2, spans.size()); + SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); SpanData callLlmSpanData = findSpanByName("call_llm"); @@ -372,6 +380,31 @@ public void testTraceFlowable() throws InterruptedException { assertTrue(flowableSpanData.hasEnded()); } + @Test + public void testTraceTransformer() throws InterruptedException { + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + try (Scope s = parentSpan.makeCurrent()) { + Flowable flowable = + Flowable.just(1, 2, 3) + .map( + i -> { + assertTrue(Span.current().getSpanContext().isValid()); + return i * 2; + }) + .compose(Tracing.trace("transformer")); + flowable.test().await().assertComplete(); + } finally { + parentSpan.end(); + } + + SpanData parentSpanData = findSpanByName("parent"); + SpanData transformerSpanData = findSpanByName("transformer"); + assertEquals( + parentSpanData.getSpanContext().getSpanId(), + transformerSpanData.getParentSpanContext().getSpanId()); + assertTrue(transformerSpanData.hasEnded()); + } + @Test public void testTraceAgentInvocation() { Span span = tracer.spanBuilder("test").startSpan(); From fe00ef87f9c7cdf3d1005a411055b90cebdd0c98 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 13 Feb 2026 10:41:15 -0800 Subject: [PATCH 62/63] docs: Update comment in Runner PiperOrigin-RevId: 869797377 --- core/src/main/java/com/google/adk/runner/Runner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 2f8201ba1..0ddfdaea1 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -218,7 +218,7 @@ public Runner( } /** - * Creates a new {@code Runner} with a list of plugins and resumability config. + * Creates a new {@code Runner} with a list of plugins. * * @deprecated Use {@link Runner.Builder} instead. */ From 4ac1dd2b6e480fefd4b0a9198b2e69a9c6334c40 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Feb 2026 00:21:54 -0800 Subject: [PATCH 63/63] feat: Adding TODO files for reaching idiomatic java PiperOrigin-RevId: 871158148 --- dev/INTENRAL_TODOS.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 dev/INTENRAL_TODOS.md diff --git a/dev/INTENRAL_TODOS.md b/dev/INTENRAL_TODOS.md new file mode 100644 index 000000000..ec18f572c --- /dev/null +++ b/dev/INTENRAL_TODOS.md @@ -0,0 +1,22 @@ +# Dev TODOs + +This file contains TODOs for the dev ADK module based on +[Recommendations for making ADK Java more idiomatic](http://go/idiomatic-adk-java). + +## Dev UI + +- [ ] **Conditional UI**: Add a configuration property (e.g., + `adk.web.ui.enabled`) to conditionally enable/disable serving Dev UI static + assets (in `AdkWebServer`). +- [ ] **Integration Tests**: Add E2E tests (Selenium/Playwright/HtmlUnit) for + Dev UI to verify interaction between frontend assets and Spring Boot + backend. +- [ ] **Integration Tests**: Test critical paths like loading UI, WebSocket + connection, sending/receiving messages, and rich content handling (images). + +## Production Readiness + +- [ ] **Actuators**: Enable and configure Spring Boot Actuator endpoints for + monitoring and management. +- [ ] **Actuators**: Configure startup and readiness probes for production + environments.