diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 1abd4a94a..f2f55f079 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -6,30 +6,72 @@ labels: '' assignees: '' --- - -** Please make sure you read the contribution guide and file the issues in the rigth place. ** +**Please make sure you read the contribution guide and file the issues in the + right place. ** [Contribution guide.](https://google.github.io/adk-docs/contributing-guide/) -**Describe the bug** +## 🔴 Required Information +*Please ensure all items in this section are completed to allow for efficient +triaging. Requests without complete information may be rejected / deprioritized. +If an item is not applicable to you - please mark it as N/A + +**Describe the Bug:** A clear and concise description of what the bug is. -**To Reproduce** -Steps to reproduce the behavior: +**Steps to Reproduce:** +Please provide a numbered list of steps to reproduce the behavior: 1. Install '...' 2. Run '....' 3. Open '....' -4. See error +4. Provide error or stacktrace -**Expected behavior** +**Expected Behavior:** A clear and concise description of what you expected to happen. -**Screenshots** -If applicable, add screenshots to help explain your problem. +**Observed Behavior:** +What actually happened? Include error messages or crash stack traces here. + +**Environment Details:** + + - ADK Library Version (see maven dependency): + - OS: [e.g., macOS, Linux, Windows] + - TS Version (tsc --version): + +**Model Information:** + + - Which model is being used: (e.g., gemini-2.5-pro) + +--- + +## 🟡 Optional Information +*Providing this information greatly speeds up the resolution process.* + +**Regression:** +Did this work in a previous version of ADK? (Yes/No) If so, which one? + +**Logs:** +Please attach relevant logs. Wrap them in code blocks (```) or attach a +text file. +```text +// Paste logs here +``` + +**Screenshots / Video:** +If applicable, add screenshots or screen recordings to help explain +your problem. + +**Additional Context:** +Add any other context about the problem here. + +**Minimal Reproduction Code:** +Please provide a code snippet or a link to a Gist/repo that isolates the issue. +``` +// Code snippet here +``` -**Desktop (please complete the following information):** - - OS: [e.g. iOS] - - Java version: - - ADK version(see maven dependency): +**How often has this issue occurred?:** -**Additional context** -Add any other context about the problem here. \ No newline at end of file + - Always (100%) + - Often (50%+) + - Intermittently (<50%) + - Once / Rare \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index facf41f1e..6af25148f 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -10,14 +10,39 @@ assignees: '' ** Please make sure you read the contribution guide and file the issues in the right place. ** [Contribution guide.](https://google.github.io/adk-docs/contributing-guide/) -**Is your feature request related to a problem? Please describe.** -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] +## 🔴 Required Information +*Please ensure all items in this section are completed to allow for efficient +triaging. Requests without complete information may be rejected / deprioritized. +If an item is not applicable to you - please mark it as N/A* -**Describe the solution you'd like** -A clear and concise description of what you want to happen. +### Is your feature request related to a specific problem? +Please describe the problem you are trying to solve. (Ex: "I'm always frustrated +when I have to manually handle X...") -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. +### Describe the Solution You'd Like +A clear and concise description of the feature or API change you want. +Be specific about input/outputs if this involves an API change. -**Additional context** +### Impact on your work +How does this feature impact your work and what are you trying to achieve? +If this is critical for you, tell us if there is a timeline by when you need +this feature. + +### Willingness to contribute +Are you interested in implementing this feature yourself or submitting a PR? +(Yes/No) + +--- + +## 🟡 Recommended Information + +### Describe Alternatives You've Considered +A clear and concise description of any alternative solutions or workarounds +you've considered and why they didn't work for you. + +### Proposed API / Implementation +If you have ideas on how this should look in code, please share a +pseudo-code example. + +### Additional Context Add any other context or screenshots about the feature request here. \ No newline at end of file diff --git a/.github/release-please.yml b/.github/release-please.yml deleted file mode 100644 index 168307d7f..000000000 --- a/.github/release-please.yml +++ /dev/null @@ -1,6 +0,0 @@ -releaseType: maven -handleGHRelease: true -bumpMinorPreMajor: true -extraFiles: - - core/src/main/java/com/google/adk/Version.java - - README.md \ No newline at end of file diff --git a/.github/release-trigger.yml b/.github/release-trigger.yml deleted file mode 100644 index 7fe362257..000000000 --- a/.github/release-trigger.yml +++ /dev/null @@ -1 +0,0 @@ -enabled: true \ No newline at end of file diff --git a/.github/workflows/release-please.yaml b/.github/workflows/release-please.yaml new file mode 100644 index 000000000..6d3142907 --- /dev/null +++ b/.github/workflows/release-please.yaml @@ -0,0 +1,17 @@ +'on': + push: + branches: + - main + workflow_dispatch: {} +permissions: + contents: write + issues: write + pull-requests: write +name: release-please +jobs: + release-please: + runs-on: ubuntu-latest + steps: + - uses: googleapis/release-please-action@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/validation.yml b/.github/workflows/validation.yml index eeb16e1ff..d9035a579 100644 --- a/.github/workflows/validation.yml +++ b/.github/workflows/validation.yml @@ -37,8 +37,8 @@ jobs: ${{ runner.os }}-maven-${{ matrix.java-version }}- ${{ runner.os }}-maven- - - name: Compile and test (all) modules with Java ${{ matrix.java-version }} - run: ./mvnw clean test + - name: Package and test (all) modules with Java ${{ matrix.java-version }} + run: ./mvnw -Prelease clean package - name: Detected wrongly formatted files run: git status && git diff --exit-code diff --git a/.release-please-manifest.json b/.release-please-manifest.json new file mode 100644 index 000000000..600835dbf --- /dev/null +++ b/.release-please-manifest.json @@ -0,0 +1,3 @@ +{ + ".": "0.5.0" +} diff --git a/README.md b/README.md index a0bb6b181..1edddff3c 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ debugging, versioning, and deployment anywhere – from your laptop to the cloud If you're using Maven, add the following to your dependencies: - + ```xml @@ -154,7 +154,7 @@ If you're using Maven, add the following to your dependencies: ``` - + To instead use an unreleased version, you could use ; see for an example illustrating this. diff --git a/a2a/src/main/java/com/google/adk/a2a/A2ASendMessageExecutor.java b/a2a/src/main/java/com/google/adk/a2a/A2ASendMessageExecutor.java deleted file mode 100644 index bd345ab22..000000000 --- a/a2a/src/main/java/com/google/adk/a2a/A2ASendMessageExecutor.java +++ /dev/null @@ -1,307 +0,0 @@ -package com.google.adk.a2a; - -import static com.google.common.base.Strings.isNullOrEmpty; -import static java.util.concurrent.TimeUnit.MILLISECONDS; - -import com.google.adk.a2a.converters.ConversationPreprocessor; -import com.google.adk.a2a.converters.RequestConverter; -import com.google.adk.a2a.converters.ResponseConverter; -import com.google.adk.agents.BaseAgent; -import com.google.adk.agents.RunConfig; -import com.google.adk.artifacts.InMemoryArtifactService; -import com.google.adk.events.Event; -import com.google.adk.memory.InMemoryMemoryService; -import com.google.adk.runner.Runner; -import com.google.adk.sessions.InMemorySessionService; -import com.google.adk.sessions.Session; -import com.google.common.collect.ImmutableList; -import com.google.genai.types.Content; -import io.a2a.spec.Message; -import io.a2a.spec.TextPart; -import io.reactivex.rxjava3.core.Completable; -import io.reactivex.rxjava3.core.Single; -import java.time.Duration; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; -import org.jspecify.annotations.Nullable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Shared SendMessage execution between HTTP service and other integrations. - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ -public final class A2ASendMessageExecutor { - private static final Logger logger = LoggerFactory.getLogger(A2ASendMessageExecutor.class); - - @FunctionalInterface - public interface AgentExecutionStrategy { - Single> execute( - String userId, - String sessionId, - Content userContent, - RunConfig runConfig, - String invocationId); - } - - private final InMemorySessionService sessionService; - private final String appName; - @Nullable private final Runner runner; - @Nullable private final Duration agentTimeout; - private static final RunConfig DEFAULT_RUN_CONFIG = - RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build(); - - public A2ASendMessageExecutor(InMemorySessionService sessionService, String appName) { - this.sessionService = sessionService; - this.appName = appName; - this.runner = null; - this.agentTimeout = null; - } - - public A2ASendMessageExecutor(BaseAgent agent, String appName, Duration agentTimeout) { - InMemorySessionService sessionService = new InMemorySessionService(); - Runner runnerInstance = - new Runner( - agent, - appName, - new InMemoryArtifactService(), - sessionService, - new InMemoryMemoryService()); - this.sessionService = sessionService; - this.appName = appName; - this.runner = runnerInstance; - this.agentTimeout = agentTimeout; - } - - public Single execute( - @Nullable Message request, AgentExecutionStrategy agentExecutionStrategy) { - final String invocationId = UUID.randomUUID().toString(); - final String contextId = resolveContextId(request); - final ImmutableList inputEvents = buildInputEvents(request, invocationId); - - ConversationPreprocessor.PreparedInput prepared = - ConversationPreprocessor.extractHistoryAndUserContent(inputEvents); - - String userId = buildUserId(contextId); - String sessionId = contextId; - - return ensureSessionExistsSingle(userId, sessionId, contextId) - .flatMap( - session -> - processEventsSingle( - session, prepared, userId, sessionId, invocationId, agentExecutionStrategy)) - .map( - resultEvents -> { - final String taskId = resolveTaskId(request); - return ResponseConverter.eventsToMessage(resultEvents, contextId, taskId); - }) - .onErrorReturn( - throwable -> { - logger.error("Error processing A2A request", throwable); - return errorResponse("Internal error: " + throwable.getMessage(), contextId); - }); - } - - public Single execute(@Nullable Message request) { - if (runner == null || agentTimeout == null) { - throw new IllegalStateException( - "Runner-based handle invoked without configured runner or timeout"); - } - return execute(request, this::executeAgentWithTimeout); - } - - private Single ensureSessionExistsSingle( - String userId, String sessionId, String contextId) { - return sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - ConcurrentHashMap initialState = new ConcurrentHashMap<>(); - return sessionService.createSession(appName, userId, initialState, sessionId); - })); - } - - private Completable appendHistoryEvents( - Session session, ConversationPreprocessor.PreparedInput prepared, String invocationId) { - ImmutableList eventsToAppend = - filterNewHistoryEvents(session, prepared.historyEvents, invocationId); - return appendEvents(session, eventsToAppend); - } - - private ImmutableList filterNewHistoryEvents( - Session session, List historyEvents, String invocationId) { - Set existingEventIds = new HashSet<>(); - for (Event existing : session.events()) { - if (existing.id() != null) { - existingEventIds.add(existing.id()); - } - } - - ImmutableList.Builder eventsToAppend = ImmutableList.builder(); - for (Event historyEvent : historyEvents) { - ensureIdentifiers(historyEvent, invocationId); - if (existingEventIds.add(historyEvent.id())) { - eventsToAppend.add(historyEvent); - } - } - return eventsToAppend.build(); - } - - private Completable appendEvents(Session session, ImmutableList events) { - Completable chain = Completable.complete(); - for (Event event : events) { - chain = chain.andThen(sessionService.appendEvent(session, event).ignoreElement()); - } - return chain; - } - - private Single> processEventsSingle( - Session session, - ConversationPreprocessor.PreparedInput prepared, - String userId, - String sessionId, - String invocationId, - AgentExecutionStrategy agentExecutionStrategy) { - Content userContent = - prepared.userContent.orElseGet(A2ASendMessageExecutor::defaultUserContent); - return appendHistoryEvents(session, prepared, invocationId) - .andThen( - agentExecutionStrategy.execute( - userId, sessionId, userContent, DEFAULT_RUN_CONFIG, invocationId)); - } - - private static ImmutableList defaultHelloEvent(String invocationId) { - Event e = - Event.builder() - .id(UUID.randomUUID().toString()) - .invocationId(invocationId) - .author("user") - .content(defaultUserContent()) - .build(); - return ImmutableList.of(e); - } - - private static Content defaultUserContent() { - return Content.builder() - .role("user") - .parts(ImmutableList.of(com.google.genai.types.Part.builder().text("Hello").build())) - .build(); - } - - private static Message errorResponse(String msg, String contextId) { - Message error = - new Message.Builder() - .messageId(UUID.randomUUID().toString()) - .role(Message.Role.AGENT) - .parts(ImmutableList.of(new TextPart("Error: " + msg))) - .build(); - if (contextId != null && !contextId.isEmpty()) { - error.setContextId(contextId); - } - return error; - } - - private Single> executeAgentWithTimeout( - String userId, - String sessionId, - Content userContent, - RunConfig runConfig, - String invocationId) { - if (runner == null || agentTimeout == null) { - throw new IllegalStateException("Runner-based execution invoked without configuration"); - } - - Single> agentResultSingle = - runner - .runAsync(userId, sessionId, userContent, runConfig) - .toList() - .map(events -> ImmutableList.copyOf(events)); - - return agentResultSingle - .timeout(agentTimeout.toMillis(), MILLISECONDS) - .onErrorResumeNext( - throwable -> { - if (isTimeout(throwable)) { - logger.warn( - "Agent execution exceeded {}; returning timeout event", - agentTimeout, - throwable); - return Single.just(ImmutableList.of(createTimeoutEvent(invocationId))); - } - return Single.error(throwable); - }); - } - - private static String resolveContextId(@Nullable Message inbound) { - if (inbound == null || inbound.getContextId() == null || inbound.getContextId().isEmpty()) { - return UUID.randomUUID().toString(); - } - return inbound.getContextId(); - } - - private static String resolveTaskId(@Nullable Message inbound) { - if (inbound != null && inbound.getTaskId() != null && !inbound.getTaskId().isEmpty()) { - return inbound.getTaskId(); - } - return UUID.randomUUID().toString(); - } - - private static ImmutableList buildInputEvents( - @Nullable Message inbound, String invocationId) { - if (inbound == null) { - return defaultHelloEvent(invocationId); - } - return RequestConverter.convertAggregatedA2aMessageToAdkEvents(inbound, invocationId); - } - - private static String buildUserId(String contextId) { - return "user-" + contextId; - } - - private static void ensureIdentifiers(Event event, String invocationId) { - if (isNullOrEmpty(event.id())) { - event.setId(Event.generateEventId()); - } - if (isNullOrEmpty(event.invocationId())) { - event.setInvocationId(invocationId); - } - } - - private static Event createTimeoutEvent(String invocationId) { - return Event.builder() - .id(UUID.randomUUID().toString()) - .invocationId(invocationId) - .author("agent") - .content( - Content.builder() - .role("model") - .parts( - ImmutableList.of( - com.google.genai.types.Part.builder() - .text("Agent execution timed out.") - .build())) - .build()) - .build(); - } - - private static boolean isTimeout(@Nullable Throwable throwable) { - while (throwable != null) { - if (throwable instanceof TimeoutException) { - return true; - } - if (throwable.getClass().getName().endsWith("TimeoutException")) { - return true; - } - throwable = throwable.getCause(); - } - return false; - } -} diff --git a/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java new file mode 100644 index 000000000..6df01694a --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java @@ -0,0 +1,185 @@ +package com.google.adk.a2a; + +import com.google.adk.a2a.converters.EventConverter; +import com.google.adk.a2a.converters.PartConverter; +import com.google.adk.agents.RunConfig; +import com.google.adk.events.Event; +import com.google.adk.runner.Runner; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.genai.types.Content; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventQueue; +import io.a2a.server.tasks.TaskUpdater; +import io.a2a.spec.InvalidAgentResponseError; +import io.a2a.spec.Message; +import io.a2a.spec.Part; +import io.a2a.spec.TextPart; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.disposables.CompositeDisposable; +import io.reactivex.rxjava3.disposables.Disposable; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Implementation of the A2A AgentExecutor interface that uses ADK to execute agent tasks. + * + *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not + * use in production code. + */ +public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor { + + private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class); + private static final String USER_ID_PREFIX = "A2A_USER_"; + private static final RunConfig DEFAULT_RUN_CONFIG = + RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build(); + + private final Runner runner; + private final Map activeTasks = new ConcurrentHashMap<>(); + + private AgentExecutor(Runner runner) { + this.runner = runner; + } + + /** Builder for {@link AgentExecutor}. */ + public static class Builder { + private Runner runner; + + @CanIgnoreReturnValue + public Builder runner(Runner runner) { + this.runner = runner; + return this; + } + + @CanIgnoreReturnValue + public AgentExecutor build() { + if (runner == null) { + throw new IllegalStateException("Runner must be provided."); + } + return new AgentExecutor(runner); + } + } + + @Override + public void cancel(RequestContext ctx, EventQueue eventQueue) { + TaskUpdater updater = new TaskUpdater(ctx, eventQueue); + updater.cancel(); + cleanupTask(ctx.getTaskId()); + } + + @Override + public void execute(RequestContext ctx, EventQueue eventQueue) { + TaskUpdater updater = new TaskUpdater(ctx, eventQueue); + Message message = ctx.getMessage(); + if (message == null) { + throw new IllegalArgumentException("Message cannot be null"); + } + + // Submits a new task if there is no active task. + if (ctx.getTask() == null) { + updater.submit(); + } + + // Group all reactive work for this task into one container + CompositeDisposable taskDisposables = new CompositeDisposable(); + // Check if the task with the task id is already running, put if absent. + if (activeTasks.putIfAbsent(ctx.getTaskId(), taskDisposables) != null) { + throw new IllegalStateException(String.format("Task %s already running", ctx.getTaskId())); + } + + EventProcessor p = new EventProcessor(); + Content content = PartConverter.messageToContent(message); + + taskDisposables.add( + prepareSession(ctx, runner.sessionService()) + .flatMapPublisher( + session -> { + updater.startWork(); + return runner.runAsync(getUserId(ctx), session.id(), content, DEFAULT_RUN_CONFIG); + }) + .subscribe( + event -> { + p.process(event, updater); + }, + error -> { + logger.error("Runner failed with {}", error); + updater.fail(failedMessage(ctx, error)); + cleanupTask(ctx.getTaskId()); + }, + () -> { + updater.complete(); + cleanupTask(ctx.getTaskId()); + })); + } + + private void cleanupTask(String taskId) { + Disposable d = activeTasks.remove(taskId); + if (d != null) { + d.dispose(); // Stops all streams in the CompositeDisposable + } + } + + private String getUserId(RequestContext ctx) { + return USER_ID_PREFIX + ctx.getContextId(); + } + + private Maybe prepareSession(RequestContext ctx, BaseSessionService service) { + return service + .getSession(runner.appName(), getUserId(ctx), ctx.getContextId(), Optional.empty()) + .switchIfEmpty( + Maybe.defer( + () -> { + return service.createSession(runner.appName(), getUserId(ctx)).toMaybe(); + })); + } + + private static Message failedMessage(RequestContext context, Throwable e) { + return new Message.Builder() + .messageId(UUID.randomUUID().toString()) + .contextId(context.getContextId()) + .taskId(context.getTaskId()) + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart(e.getMessage()))) + .build(); + } + + // Processor that will process all events related to the one runner invocation. + private static class EventProcessor { + + // All artifacts related to the invocation should have the same artifact id. + private EventProcessor() { + artifactId = UUID.randomUUID().toString(); + } + + private final String artifactId; + + private void process(Event event, TaskUpdater updater) { + if (event.errorCode().isPresent()) { + throw new InvalidAgentResponseError( + null, // Uses default code -32006 + "Agent returned an error: " + event.errorCode().get(), + null); + } + + ImmutableList> parts = EventConverter.contentToParts(event.content()); + + // Mark all parts as partial if the event is partial. + if (event.partial().orElse(false)) { + parts.forEach( + part -> { + Map metadata = part.getMetadata(); + metadata.put("adk_partial", true); + }); + } + + updater.addArtifact(parts, artifactId, null, ImmutableMap.of()); + } + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java index cd8bcefb0..f5b1178c0 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java @@ -1,7 +1,10 @@ package com.google.adk.a2a.converters; +import static com.google.common.collect.ImmutableList.toImmutableList; + import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; +import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; import io.a2a.spec.Message; @@ -37,6 +40,16 @@ public enum AggregationMode { EXTERNAL_HANDOFF } + public static ImmutableList> contentToParts(Optional content) { + if (content.isPresent() && content.get().parts().isPresent()) { + return content.get().parts().get().stream() + .map(PartConverter::fromGenaiPart) + .flatMap(Optional::stream) + .collect(toImmutableList()); + } + return ImmutableList.of(); + } + public static Optional convertEventsToA2AMessage(InvocationContext context) { return convertEventsToA2AMessage(context, AggregationMode.AS_IS); } diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 0b5ea5503..c6ef06400 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Blob; +import com.google.genai.types.Content; import com.google.genai.types.FileData; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; @@ -16,6 +17,7 @@ import io.a2a.spec.FilePart; import io.a2a.spec.FileWithBytes; import io.a2a.spec.FileWithUri; +import io.a2a.spec.Message; import io.a2a.spec.TextPart; import java.util.Base64; import java.util.HashMap; @@ -181,6 +183,17 @@ private static Optional convertDataPartToGenAiPart( } } + /** + * Converts an A2A Message to a Google GenAI Content object. + * + * @param message The A2A Message to convert. + * @return The converted Google GenAI Content object. + */ + public static Content messageToContent(Message message) { + ImmutableList parts = toGenaiParts(message.getParts()); + return Content.builder().role("user").parts(parts).build(); + } + /** * Creates an A2A DataPart from a Google GenAI FunctionResponse. * @@ -227,7 +240,7 @@ public static Optional> fromGenaiPart(Part part) { } if (part.text().isPresent()) { - return Optional.of(new TextPart(part.text().get())); + return Optional.of(new TextPart(part.text().get(), new HashMap<>())); } if (part.fileData().isPresent()) { @@ -235,7 +248,7 @@ public static Optional> fromGenaiPart(Part part) { String uri = fileData.fileUri().orElse(null); String mime = fileData.mimeType().orElse(null); String name = fileData.displayName().orElse(null); - return Optional.of(new FilePart(new FileWithUri(mime, name, uri))); + return Optional.of(new FilePart(new FileWithUri(mime, name, uri), new HashMap<>())); } if (part.inlineData().isPresent()) { @@ -244,7 +257,7 @@ public static Optional> fromGenaiPart(Part part) { String encoded = bytes != null ? Base64.getEncoder().encodeToString(bytes) : null; String mime = blob.mimeType().orElse(null); String name = blob.displayName().orElse(null); - return Optional.of(new FilePart(new FileWithBytes(mime, name, encoded))); + return Optional.of(new FilePart(new FileWithBytes(mime, name, encoded), new HashMap<>())); } if (part.functionCall().isPresent() || part.functionResponse().isPresent()) { diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java index 785ce6f38..61ab84c90 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java @@ -163,7 +163,7 @@ public static Message eventToMessage(Event event, String contextId) { * empty optional if the event should be ignored (e.g. if the event is not a final update for * TaskArtifactUpdateEvent or if the message is empty for TaskStatusUpdateEvent). * - * @throws an {@link IllegalArgumentException} if the event type is not supported. + * @throws IllegalArgumentException if the event type is not supported. */ public static Optional clientEventToEvent( ClientEvent event, InvocationContext invocationContext) { @@ -182,7 +182,7 @@ public static Optional clientEventToEvent( * the event is not a final update for TaskArtifactUpdateEvent or if the message is empty for * TaskStatusUpdateEvent. * - * @throws an {@link IllegalArgumentException} if the task update type is not supported. + * @throws IllegalArgumentException if the task update type is not supported. */ private static Optional handleTaskUpdate( TaskUpdateEvent event, InvocationContext context) { diff --git a/a2a/webservice/pom.xml b/a2a/webservice/pom.xml deleted file mode 100644 index deb03fd27..000000000 --- a/a2a/webservice/pom.xml +++ /dev/null @@ -1,67 +0,0 @@ - - - 4.0.0 - - - com.google.adk - google-adk-parent - 0.5.1-SNAPSHOT - ../../pom.xml - - - google-adk-a2a-webservice - jar - - Google ADK A2A Webservice - - - 17 - ${java.version} - - - - - com.google.adk - google-adk-a2a - ${project.version} - - - org.springframework.boot - spring-boot-starter-web - - - org.slf4j - slf4j-api - - - - - - - org.apache.maven.plugins - maven-compiler-plugin - 3.13.0 - - ${java.version} - - - - org.springframework.boot - spring-boot-maven-plugin - ${spring-boot.version} - - - - repackage - - - exec - - - - - - - \ No newline at end of file diff --git a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteApplication.java b/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteApplication.java deleted file mode 100644 index 93e321eb1..000000000 --- a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteApplication.java +++ /dev/null @@ -1,20 +0,0 @@ -package com.google.adk.webservice; - -import org.springframework.boot.SpringApplication; -import org.springframework.boot.autoconfigure.SpringBootApplication; -import org.springframework.context.annotation.Import; - -/** - * Entry point for the standalone Spring Boot A2A service. - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ -@SpringBootApplication -@Import(A2ARemoteConfiguration.class) -public class A2ARemoteApplication { - - public static void main(String[] args) { - SpringApplication.run(A2ARemoteApplication.class, args); - } -} diff --git a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteConfiguration.java b/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteConfiguration.java deleted file mode 100644 index a3f9b48ac..000000000 --- a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteConfiguration.java +++ /dev/null @@ -1,49 +0,0 @@ -package com.google.adk.webservice; - -import com.google.adk.a2a.A2ASendMessageExecutor; -import com.google.adk.agents.BaseAgent; -import java.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.ComponentScan; -import org.springframework.context.annotation.Configuration; - -/** - * Registers the transport-only A2A webservice stack. - * - *

Importers must supply a {@link BaseAgent} bean. The agent remains opaque to this module so the - * transport can be reused across applications. - * - *

TODO: - * - *

- * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ -@Configuration -@ComponentScan(basePackages = "com.google.adk.webservice") -public class A2ARemoteConfiguration { - - private static final Logger logger = LoggerFactory.getLogger(A2ARemoteConfiguration.class); - private static final String DEFAULT_APP_NAME = "a2a-remote-service"; - private static final long DEFAULT_TIMEOUT_SECONDS = 15L; - - @Bean - public A2ASendMessageExecutor a2aSendMessageExecutor( - BaseAgent agent, - @Value("${a2a.remote.appName:" + DEFAULT_APP_NAME + "}") String appName, - @Value("${a2a.remote.timeoutSeconds:" + DEFAULT_TIMEOUT_SECONDS + "}") long timeoutSeconds) { - logger.info( - "Initializing A2A send message executor for appName {} with timeout {}s", - appName, - timeoutSeconds); - return new A2ASendMessageExecutor(agent, appName, Duration.ofSeconds(timeoutSeconds)); - } -} diff --git a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteController.java b/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteController.java deleted file mode 100644 index a0fe5b0cc..000000000 --- a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteController.java +++ /dev/null @@ -1,40 +0,0 @@ -package com.google.adk.webservice; - -import io.a2a.spec.SendMessageRequest; -import io.a2a.spec.SendMessageResponse; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; - -/** - * REST controller exposing an A2A-compliant JSON-RPC endpoint backed by a local ADK runner. - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ -@RestController -@RequestMapping("/a2a/remote") -public class A2ARemoteController { - - private static final Logger logger = LoggerFactory.getLogger(A2ARemoteController.class); - - private final A2ARemoteService service; - - public A2ARemoteController(A2ARemoteService service) { - this.service = service; - } - - @PostMapping( - path = "/v1/message:send", - consumes = "application/json", - produces = "application/json") - public SendMessageResponse sendMessage(@RequestBody SendMessageRequest request) { - logger.debug("Received remote A2A request: {}", request); - SendMessageResponse response = service.handle(request); - logger.debug("Responding with remote A2A payload: {}", response); - return response; - } -} diff --git a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteService.java b/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteService.java deleted file mode 100644 index 803774568..000000000 --- a/a2a/webservice/src/main/java/com/google/adk/webservice/A2ARemoteService.java +++ /dev/null @@ -1,93 +0,0 @@ -package com.google.adk.webservice; - -import com.google.adk.a2a.A2ASendMessageExecutor; -import com.google.adk.a2a.converters.ResponseConverter; -import io.a2a.spec.JSONRPCError; -import io.a2a.spec.Message; -import io.a2a.spec.MessageSendParams; -import io.a2a.spec.SendMessageRequest; -import io.a2a.spec.SendMessageResponse; -import java.util.List; -import java.util.UUID; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.stereotype.Service; - -/** - * Core service that bridges the A2A JSON-RPC sendMessage API to a local ADK runner. - * - *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not - * use in production code. - */ -@Service -public class A2ARemoteService { - - private static final Logger logger = LoggerFactory.getLogger(A2ARemoteService.class); - private static final int ERROR_CODE_INVALID_PARAMS = -32602; - private static final int ERROR_CODE_INTERNAL_ERROR = -32603; - - private final A2ASendMessageExecutor executor; - - public A2ARemoteService(A2ASendMessageExecutor executor) { - this.executor = executor; - } - - public SendMessageResponse handle(SendMessageRequest request) { - if (request == null) { - logger.warn("Received null SendMessageRequest"); - return invalidParamsResponse(null, "Request body is missing"); - } - - MessageSendParams params = request.getParams(); - if (params == null) { - logger.warn("SendMessageRequest {} missing params", request.getId()); - return invalidParamsResponse(request, "Request params are missing"); - } - - Message inbound = params.message(); - if (inbound == null) { - logger.warn("SendMessageRequest {} missing message payload", request.getId()); - return invalidParamsResponse(request, "Request message payload is missing"); - } - - boolean generatedContext = inbound.getContextId() == null || inbound.getContextId().isEmpty(); - Message normalized = ensureContextId(inbound); - if (generatedContext) { - logger.debug("Incoming request lacked contextId; generated {}", normalized.getContextId()); - } - - try { - Message result = executor.execute(normalized).blockingGet(); - if (result == null) { - result = - ResponseConverter.eventsToMessage( - List.of(), normalized.getContextId(), normalized.getTaskId()); - } - - logger.debug("Returning A2A response for context {}", normalized.getContextId()); - return new SendMessageResponse(request.getId(), result); - } catch (RuntimeException e) { - logger.error("Failed to process remote A2A request", e); - return errorResponse(request, e); - } - } - - private static Message ensureContextId(Message message) { - if (message.getContextId() != null && !message.getContextId().isEmpty()) { - return message; - } - return new Message.Builder(message).contextId(UUID.randomUUID().toString()).build(); - } - - private static SendMessageResponse invalidParamsResponse( - SendMessageRequest request, String reason) { - JSONRPCError error = new JSONRPCError(ERROR_CODE_INVALID_PARAMS, reason, null); - return new SendMessageResponse(request != null ? request.getId() : null, error); - } - - private static SendMessageResponse errorResponse(SendMessageRequest request, Throwable error) { - String message = "Internal error processing sendMessage request"; - JSONRPCError jsonrpcError = new JSONRPCError(ERROR_CODE_INTERNAL_ERROR, message, null); - return new SendMessageResponse(request != null ? request.getId() : null, jsonrpcError); - } -} diff --git a/contrib/samples/a2a_basic/A2AAgent.java b/contrib/samples/a2a_basic/A2AAgent.java index fa4932cf9..e4e79a4eb 100644 --- a/contrib/samples/a2a_basic/A2AAgent.java +++ b/contrib/samples/a2a_basic/A2AAgent.java @@ -1,8 +1,5 @@ package com.example.a2a_basic; -import java.util.ArrayList; -import java.util.Random; - import com.google.adk.a2a.RemoteA2AAgent; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.LlmAgent; @@ -10,7 +7,6 @@ import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - import io.a2a.client.Client; import io.a2a.client.config.ClientConfig; import io.a2a.client.http.A2ACardResolver; @@ -18,6 +14,8 @@ import io.a2a.client.transport.jsonrpc.JSONRPCTransport; import io.a2a.client.transport.jsonrpc.JSONRPCTransportConfig; import io.a2a.spec.AgentCard; +import java.util.ArrayList; +import java.util.Random; /** Provides local roll logic plus a remote A2A agent for the demo. */ public final class A2AAgent { diff --git a/contrib/samples/a2a_basic/A2AAgentRun.java b/contrib/samples/a2a_basic/A2AAgentRun.java index 12d515f62..b4a63b026 100644 --- a/contrib/samples/a2a_basic/A2AAgentRun.java +++ b/contrib/samples/a2a_basic/A2AAgentRun.java @@ -1,11 +1,5 @@ package com.example.a2a_basic; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.TimeUnit; - import com.google.adk.agents.BaseAgent; import com.google.adk.agents.RunConfig; import com.google.adk.artifacts.InMemoryArtifactService; @@ -15,8 +9,12 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; - import io.reactivex.rxjava3.core.Flowable; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; /** Main class to demonstrate running the A2A agent with sequential inputs. */ public final class A2AAgentRun { diff --git a/contrib/samples/a2a_remote/README.md b/contrib/samples/a2a_remote/README.md deleted file mode 100644 index d1d2601ca..000000000 --- a/contrib/samples/a2a_remote/README.md +++ /dev/null @@ -1,70 +0,0 @@ -# A2A Remote Prime Service Sample - -This sample starts a standalone Spring Boot service that exposes the -`remote_prime_agent` via the shared A2A webservice module -(`google-adk-a2a-webservice`). It behaves like a third‑party service that -implements the A2A JSON‑RPC contract and can be used by the ADK client (for -example, the `a2a_basic` demo) as its remote endpoint. - -## Running the service - -```bash -cd google_adk -mvn -f contrib/samples/a2a_remote/pom.xml package - -GOOGLE_GENAI_USE_VERTEXAI=FALSE \ -GOOGLE_API_KEY= \ -mvn -f contrib/samples/a2a_remote/pom.xml exec:java -``` - -`RemoteA2AApplication` imports the reusable controller/service from -`google-adk-a2a-webservice`, so the server listens on -`http://localhost:8080/a2a/remote/v1/message:send` by default. Override the -port with `-Dspring-boot.run.arguments=--server.port=` when running via -`spring-boot:run` if you need to avoid collisions. - -``` -POST /a2a/remote/v1/message:send -Content-Type: application/json -``` - -and accepts standard A2A JSON‑RPC payloads (`SendMessageRequest`). The -response is a `SendMessageResponse` that contains either a `Message` or a -`Task` in the `result` field. Spring Boot logs the request/response lifecycle -to the console; add your preferred logging configuration if you need -persistent logs. - -## Agent implementation - -- `remote_prime_agent/Agent.java` hosts the LLM agent that checks whether - numbers are prime (lifted from the Stubby demo). The model name defaults - to `gemini-2.5-pro`; set `GOOGLE_API_KEY` before running. -- `RemoteA2AApplication` bootstraps the service by importing - `A2ARemoteConfiguration` and publishing the prime `BaseAgent` bean. The shared - configuration consumes that bean to create the `A2ASendMessageExecutor`. - -## Sample request - -```bash -curl -X POST http://localhost:8080/a2a/remote/v1/message:send \ - -H 'Content-Type: application/json' \ - -d '{ - "jsonrpc": "2.0", - "id": "demo-123", - "method": "message/send", - "params": { - "message": { - "role": "user", - "messageId": "msg-1", - "contextId": "ctx-1", - "parts": [ - {"kind": "text", "text": "Check if 17 is prime"} - ] - }, - "metadata": {} - } - }' -``` - -The response contains the prime check result, and the interaction is logged in -the application console. diff --git a/contrib/samples/a2a_remote/pom.xml b/contrib/samples/a2a_remote/pom.xml deleted file mode 100644 index 59d9cf01e..000000000 --- a/contrib/samples/a2a_remote/pom.xml +++ /dev/null @@ -1,139 +0,0 @@ - - - 4.0.0 - - - com.google.adk - google-adk-parent - 0.5.1-SNAPSHOT - ../../../pom.xml - - - google-adk-sample-a2a-remote - Google ADK - Sample - A2A Remote Prime Service - Spring Boot service that exposes the remote prime-check agent over the A2A REST interface. - jar - - - 3.3.4 - 17 - 0.8 - com.google.adk.samples.a2a_remote.RemoteA2AApplication - - - - - - org.springframework.boot - spring-boot-dependencies - ${spring-boot.version} - pom - import - - - - - - - org.springframework.boot - spring-boot-starter-web - - - - com.google.adk - google-adk - ${project.version} - - - - com.google.adk - google-adk-a2a - ${project.version} - - - - com.google.adk - google-adk-a2a-webservice - ${project.version} - - - - com.google.flogger - flogger - ${flogger.version} - - - com.google.flogger - google-extensions - ${flogger.version} - - - com.google.flogger - flogger-system-backend - ${flogger.version} - - - - org.springframework.boot - spring-boot-starter-test - test - - - - com.google.truth - truth - ${truth.version} - test - - - - - - - org.springframework.boot - spring-boot-maven-plugin - ${spring-boot.version} - - - org.codehaus.mojo - build-helper-maven-plugin - 3.6.0 - - - add-source - generate-sources - - add-source - - - - . - - - - - - - org.apache.maven.plugins - maven-source-plugin - - - **/*.jar - target/** - - - - - org.codehaus.mojo - exec-maven-plugin - 3.2.0 - - ${exec.mainClass} - runtime - - - - - \ No newline at end of file diff --git a/contrib/samples/a2a_remote/remote_prime_agent/Agent.java b/contrib/samples/a2a_remote/remote_prime_agent/Agent.java deleted file mode 100644 index a0072e8e3..000000000 --- a/contrib/samples/a2a_remote/remote_prime_agent/Agent.java +++ /dev/null @@ -1,101 +0,0 @@ -package com.google.adk.samples.a2a_remote.remote_prime_agent; - -import static java.util.stream.Collectors.joining; - -import com.google.adk.agents.LlmAgent; -import com.google.adk.tools.FunctionTool; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.flogger.GoogleLogger; -import io.reactivex.rxjava3.core.Maybe; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -/** Agent that can check whether numbers are prime. */ -public final class Agent { - - private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); - - public static ImmutableMap checkPrime(List nums) { - logger.atInfo().log("checkPrime called with nums=%s", nums); - Set primes = new HashSet<>(); - for (int num : nums) { - if (num <= 1) { - continue; - } - boolean isPrime = true; - for (int i = 2; i <= Math.sqrt(num); i++) { - if (num % i == 0) { - isPrime = false; - break; - } - } - if (isPrime) { - primes.add(num); - } - } - String result; - if (primes.isEmpty()) { - result = "No prime numbers found."; - } else if (primes.size() == 1) { - int only = primes.iterator().next(); - // Per request: singular phrasing without article - result = only + " is prime number."; - } else { - result = primes.stream().map(String::valueOf).collect(joining(", ")) + " are prime numbers."; - } - logger.atInfo().log("checkPrime result=%s", result); - return ImmutableMap.of("result", result); - } - - public static final LlmAgent ROOT_AGENT = - LlmAgent.builder() - .model("gemini-2.5-pro") - .name("check_prime_agent") - .description("check prime agent that can check whether numbers are prime.") - .instruction( - """ - You check whether numbers are prime. - - If the last user message contains numbers, call checkPrime exactly once with exactly - those integers as a list (e.g., [2]). Never add other numbers. Do not ask for - clarification. Return only the tool's result. - - Always pass a list of integers to the tool (use a single-element list for one - number). Never pass strings. - """) - // Log the exact contents passed to the LLM request for verification - .beforeModelCallback( - (callbackContext, llmRequest) -> { - try { - logger.atInfo().log( - "Invocation events (count=%d): %s", - callbackContext.events().size(), callbackContext.events()); - } catch (Throwable t) { - logger.atWarning().withCause(t).log("BeforeModel logging error"); - } - return Maybe.empty(); - }) - .afterModelCallback( - (callbackContext, llmResponse) -> { - try { - String content = - llmResponse.content().map(Object::toString).orElse(""); - logger.atInfo().log("AfterModel content=%s", content); - llmResponse - .errorMessage() - .ifPresent( - error -> - logger.atInfo().log( - "AfterModel errorMessage=%s", error.replace("\n", "\\n"))); - } catch (Throwable t) { - logger.atWarning().withCause(t).log("AfterModel logging error"); - } - return Maybe.empty(); - }) - .tools(ImmutableList.of(FunctionTool.create(Agent.class, "checkPrime"))) - .build(); - - private Agent() {} -} diff --git a/contrib/samples/a2a_remote/remote_prime_agent/agent.json b/contrib/samples/a2a_remote/remote_prime_agent/agent.json deleted file mode 100644 index 87f2d9ecc..000000000 --- a/contrib/samples/a2a_remote/remote_prime_agent/agent.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "capabilities": {}, - "defaultInputModes": ["text/plain"], - "defaultOutputModes": ["application/json"], - "description": "An agent specialized in checking whether numbers are prime. It can efficiently determine the primality of individual numbers or lists of numbers.", - "name": "check_prime_agent", - "skills": [ - { - "id": "prime_checking", - "name": "Prime Number Checking", - "description": "Check if numbers in a list are prime using efficient mathematical algorithms", - "tags": ["mathematical", "computation", "prime", "numbers"] - } - ], - "url": "http://localhost:8080/a2a/prime_agent", - "version": "1.0.0" -} diff --git a/contrib/samples/a2a_remote/src/main/java/com/google/adk/samples/a2a_remote/RemoteA2AApplication.java b/contrib/samples/a2a_remote/src/main/java/com/google/adk/samples/a2a_remote/RemoteA2AApplication.java deleted file mode 100644 index 53be8d1d0..000000000 --- a/contrib/samples/a2a_remote/src/main/java/com/google/adk/samples/a2a_remote/RemoteA2AApplication.java +++ /dev/null @@ -1,24 +0,0 @@ -package com.google.adk.samples.a2a_remote; - -import com.google.adk.agents.BaseAgent; -import com.google.adk.samples.a2a_remote.remote_prime_agent.Agent; -import com.google.adk.webservice.A2ARemoteConfiguration; -import org.springframework.boot.SpringApplication; -import org.springframework.boot.autoconfigure.SpringBootApplication; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Import; - -/** Spring Boot entry point that wires the shared A2A webservice with the prime demo agent. */ -@SpringBootApplication -@Import(A2ARemoteConfiguration.class) -public class RemoteA2AApplication { - - public static void main(String[] args) { - SpringApplication.run(RemoteA2AApplication.class, args); - } - - @Bean - public BaseAgent primeAgent() { - return Agent.ROOT_AGENT; - } -} diff --git a/contrib/samples/pom.xml b/contrib/samples/pom.xml index fa5d6dfae..580b10de9 100644 --- a/contrib/samples/pom.xml +++ b/contrib/samples/pom.xml @@ -17,7 +17,6 @@ a2a_basic - a2a_remote configagent helloworld mcpfilesystem diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java index a0db38550..6843c8eaa 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/SpringAIIntegrationTest.java @@ -123,7 +123,8 @@ public ChatResponse call(Prompt prompt) { // Subsequent calls - provide final answer message = new AssistantMessage( - "The weather in Paris is beautiful and sunny with temperatures from 10°C in the morning up to 24°C in the afternoon."); + "The weather in Paris is beautiful and sunny with temperatures from 10°C in" + + " the morning up to 24°C in the afternoon."); } Generation generation = new Generation(message); diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java index 1e2ac82f6..f21b07ae9 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java @@ -145,9 +145,9 @@ void testAgentWithToolsAndRealApi() { .model(new SpringAI(anthropicModel, CLAUDE_MODEL)) .instruction( """ - You are a helpful assistant. - When asked about weather, you MUST use the getWeatherInfo function to get current conditions. - """) + You are a helpful assistant. + When asked about weather, you MUST use the getWeatherInfo function to get current conditions. + """) .tools(FunctionTool.create(WeatherTools.class, "getWeatherInfo")) .build(); diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/GeminiApiIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/GeminiApiIntegrationTest.java index 47f3d3b11..5414cdf99 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/GeminiApiIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/GeminiApiIntegrationTest.java @@ -157,9 +157,9 @@ void testAgentWithToolsAndRealApi() { .model(new SpringAI(geminiModel, GEMINI_MODEL)) .instruction( """ - You are a helpful assistant. - When asked about weather, you MUST use the getWeatherInfo function to get current conditions. - """) + You are a helpful assistant. + When asked about weather, you MUST use the getWeatherInfo function to get current conditions. + """) .tools(FunctionTool.create(WeatherTools.class, "getWeatherInfo")) .build(); diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/OpenAiApiIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/OpenAiApiIntegrationTest.java index 894ffaba6..eb956f828 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/OpenAiApiIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/OpenAiApiIntegrationTest.java @@ -133,9 +133,9 @@ void testAgentWithToolsAndRealApi() { .model(new SpringAI(openAiModel, GPT_MODEL)) .instruction( """ - You are a helpful assistant. - When asked about weather, use the getWeatherInfo function to get current conditions. - """) + You are a helpful assistant. + When asked about weather, use the getWeatherInfo function to get current conditions. + """) .tools(FunctionTool.create(WeatherTools.class, "getWeatherInfo")) .build(); diff --git a/core/src/main/java/com/google/adk/Version.java b/core/src/main/java/com/google/adk/Version.java index fbc73039e..8b10341ac 100644 --- a/core/src/main/java/com/google/adk/Version.java +++ b/core/src/main/java/com/google/adk/Version.java @@ -22,7 +22,7 @@ */ public final class Version { // Don't touch this, release-please should keep it up to date. - public static final String JAVA_ADK_VERSION = "0.5.0"; + public static final String JAVA_ADK_VERSION = "0.5.0"; // x-release-please-released-version private Version() {} } diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 948d5ebac..226e61abe 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -16,7 +16,9 @@ package com.google.adk.agents; +import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.String.format; import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; @@ -27,21 +29,26 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import com.google.genai.types.Content; -import io.opentelemetry.api.trace.Span; -import io.opentelemetry.api.trace.Tracer; -import io.opentelemetry.context.Context; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.function.Function; +import java.util.regex.Pattern; import java.util.stream.Stream; import org.jspecify.annotations.Nullable; /** Base class for all agents. */ public abstract class BaseAgent { + // Pattern for valid agent names. + private static final String IDENTIFIER_REGEX = "^_?[a-zA-Z0-9]*([. _-][a-zA-Z0-9]+)*$"; + private static final Pattern IDENTIFIER_PATTERN = Pattern.compile(IDENTIFIER_REGEX); + /** The agent's name. Must be a unique identifier within the agent tree. */ private final String name; @@ -57,10 +64,10 @@ public abstract class BaseAgent { */ private BaseAgent parentAgent; - private final List subAgents; + private final ImmutableList subAgents; - private final Optional> beforeAgentCallback; - private final Optional> afterAgentCallback; + private final ImmutableList beforeAgentCallback; + private final ImmutableList afterAgentCallback; /** * Creates a new BaseAgent. @@ -76,15 +83,23 @@ public abstract class BaseAgent { public BaseAgent( String name, String description, - List subAgents, + @Nullable List subAgents, @Nullable List beforeAgentCallback, @Nullable List afterAgentCallback) { + validateAgentName(name); this.name = name; this.description = description; this.parentAgent = null; - this.subAgents = subAgents != null ? subAgents : ImmutableList.of(); - this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback); - this.afterAgentCallback = Optional.ofNullable(afterAgentCallback); + this.subAgents = (subAgents != null) ? ImmutableList.copyOf(subAgents) : ImmutableList.of(); + validateSubAgents(this.name, this.subAgents); + this.beforeAgentCallback = + (beforeAgentCallback != null) + ? ImmutableList.copyOf(beforeAgentCallback) + : ImmutableList.of(); + this.afterAgentCallback = + (afterAgentCallback != null) + ? ImmutableList.copyOf(afterAgentCallback) + : ImmutableList.of(); // Establish parent relationships for all sub-agents if needed. for (BaseAgent subAgent : this.subAgents) { @@ -92,6 +107,68 @@ public BaseAgent( } } + /** + * Closes all sub-agents. + * + * @return a {@link Completable} that completes when all sub-agents are closed. + */ + public Completable close() { + List completables = new ArrayList<>(); + this.subAgents.forEach(subAgent -> completables.add(subAgent.close())); + return Completable.mergeDelayError(completables); + } + + /** + * Validates the agent name. + * + * @param name The agent name to validate. + * @throws IllegalArgumentException if the agent name is null, empty, or does not match the + * identifier pattern. + */ + private static void validateAgentName(String name) { + if (isNullOrEmpty(name)) { + throw new IllegalArgumentException("Agent name cannot be null or empty."); + } + if (!IDENTIFIER_PATTERN.matcher(name).matches()) { + throw new IllegalArgumentException( + format("Agent name '%s' does not match regex '%s'.", name, IDENTIFIER_REGEX)); + } + if (name.equals("user")) { + throw new IllegalArgumentException( + "Agent name cannot be 'user'; reserved for end-user input."); + } + } + + /** + * Validates the sub-agents. + * + * @param name The name of the parent agent. + * @param subAgents The list of sub-agents to validate. + * @throws IllegalArgumentException if the sub-agents have duplicate names. + */ + private static void validateSubAgents( + String name, @Nullable List subAgents) { + if (subAgents == null) { + return; + } + HashSet subAgentNames = new HashSet<>(); + HashSet duplicateSubAgentNames = new HashSet<>(); + for (BaseAgent subAgent : subAgents) { + String subAgentName = subAgent.name(); + // NOTE: Mocked agents have null names because BaseAgent.name() is a final method that + // cannot be mocked. + if (subAgentName != null && !subAgentNames.add(subAgentName)) { + duplicateSubAgentNames.add(subAgentName); + } + } + if (!duplicateSubAgentNames.isEmpty()) { + throw new IllegalArgumentException( + format( + "Agent named '%s' has sub-agents with duplicate names: %s. Sub-agents: %s", + name, duplicateSubAgentNames, subAgents)); + } + } + /** * Gets the agent's unique name. * @@ -144,38 +221,38 @@ public BaseAgent rootAgent() { /** * Finds an agent (this or descendant) by name. * - * @return the agent or descendant with the given name, or {@code null} if not found. + * @return an {@link Optional} containing the agent or descendant with the given name, or {@link + * Optional#empty()} if not found. */ - public BaseAgent findAgent(String name) { + public Optional findAgent(String name) { if (this.name().equals(name)) { - return this; + return Optional.of(this); } return findSubAgent(name); } - /** Recursively search sub agent by name. */ - public @Nullable BaseAgent findSubAgent(String name) { - for (BaseAgent subAgent : subAgents) { - if (subAgent.name().equals(name)) { - return subAgent; - } - BaseAgent result = subAgent.findSubAgent(name); - if (result != null) { - return result; - } - } - return null; + /** + * Recursively search sub agent by name. + * + * @return an {@link Optional} containing the sub agent with the given name, or {@link + * Optional#empty()} if not found. + */ + public Optional findSubAgent(String name) { + return subAgents.stream() + .map(subAgent -> subAgent.findAgent(name)) + .flatMap(Optional::stream) + .findFirst(); } public List subAgents() { return subAgents; } - public Optional> beforeAgentCallback() { + public ImmutableList beforeAgentCallback() { return beforeAgentCallback; } - public Optional> afterAgentCallback() { + public ImmutableList afterAgentCallback() { return afterAgentCallback; } @@ -184,8 +261,8 @@ public Optional> afterAgentCallback() { * *

This method is only for use by Agent Development Kit. */ - public List canonicalBeforeAgentCallbacks() { - return beforeAgentCallback.orElse(ImmutableList.of()); + public ImmutableList canonicalBeforeAgentCallbacks() { + return beforeAgentCallback; } /** @@ -193,8 +270,8 @@ public List canonicalBeforeAgentCallbacks() { * *

This method is only for use by Agent Development Kit. */ - public List canonicalAfterAgentCallbacks() { - return afterAgentCallback.orElse(ImmutableList.of()); + public ImmutableList canonicalAfterAgentCallbacks() { + return afterAgentCallback; } /** @@ -208,9 +285,10 @@ private InvocationContext createInvocationContext(InvocationContext parentContex InvocationContext.Builder builder = parentContext.toBuilder(); builder.agent(this); // Check for branch to be truthy (not None, not empty string), - if (parentContext.branch().filter(s -> !s.isEmpty()).isPresent()) { - builder.branch(parentContext.branch().get() + "." + name()); - } + parentContext + .branch() + .filter(s -> !s.isEmpty()) + .ifPresent(branch -> builder.branch(branch + "." + name())); return builder.build(); } @@ -221,48 +299,50 @@ private InvocationContext createInvocationContext(InvocationContext parentContex * @return stream of agent-generated events. */ public Flowable runAsync(InvocationContext parentContext) { - Tracer tracer = Tracing.getTracer(); + return run(parentContext, this::runAsyncImpl); + } + + /** + * Runs the agent with the given implementation. + * + * @param parentContext Parent context to inherit. + * @param runImplementation The agent-specific logic to run. + * @return stream of agent-generated events. + */ + private Flowable run( + InvocationContext parentContext, + Function> runImplementation) { return Flowable.defer( () -> { - Span span = - tracer - .spanBuilder("agent_run [" + name() + "]") - .setParent(Context.current()) - .startSpan(); - Context spanContext = Context.current().with(span); - InvocationContext invocationContext = createInvocationContext(parentContext); - return Tracing.traceFlowable( - spanContext, - span, - () -> - callCallback( - beforeCallbacksToFunctions( - invocationContext.pluginManager(), - beforeAgentCallback.orElse(ImmutableList.of())), - invocationContext) - .flatMapPublisher( - beforeEventOpt -> { - if (invocationContext.endInvocation()) { - return Flowable.fromOptional(beforeEventOpt); - } - - Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); - Flowable mainEvents = - Flowable.defer(() -> runAsyncImpl(invocationContext)); - Flowable afterEvents = - Flowable.defer( - () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), - afterAgentCallback.orElse(ImmutableList.of())), - invocationContext) - .flatMapPublisher(Flowable::fromOptional)); - - return Flowable.concat(beforeEvents, mainEvents, afterEvents); - })); + return callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), beforeAgentCallback), + invocationContext) + .flatMapPublisher( + beforeEventOpt -> { + if (invocationContext.endInvocation()) { + return Flowable.fromOptional(beforeEventOpt); + } + + Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); + Flowable mainEvents = + Flowable.defer(() -> runImplementation.apply(invocationContext)); + Flowable afterEvents = + Flowable.defer( + () -> + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), afterAgentCallback), + invocationContext) + .flatMapPublisher(Flowable::fromOptional)); + + return Flowable.concat(beforeEvents, mainEvents, afterEvents); + }) + .compose( + Tracing.traceAgent( + "invoke_agent " + name(), name(), description(), invocationContext)); }); } @@ -274,11 +354,8 @@ public Flowable runAsync(InvocationContext parentContext) { */ private ImmutableList>> beforeCallbacksToFunctions( Plugin pluginManager, List callbacks) { - return Stream.concat( - Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)), - callbacks.stream() - .map(callback -> (Function>) callback::call)) - .collect(toImmutableList()); + return callbacksToFunctions( + ctx -> pluginManager.beforeAgentCallback(this, ctx), callbacks, c -> c::call); } /** @@ -289,10 +366,15 @@ private ImmutableList>> beforeCallbacks */ private ImmutableList>> afterCallbacksToFunctions( Plugin pluginManager, List callbacks) { - return Stream.concat( - Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)), - callbacks.stream() - .map(callback -> (Function>) callback::call)) + return callbacksToFunctions( + ctx -> pluginManager.afterAgentCallback(this, ctx), callbacks, c -> c::call); + } + + private ImmutableList>> callbacksToFunctions( + Function> pluginCallback, + List callbacks, + Function>> mapper) { + return Stream.concat(Stream.of(pluginCallback), callbacks.stream().map(mapper)) .collect(toImmutableList()); } @@ -306,7 +388,7 @@ private ImmutableList>> afterCallbacksT private Single> callCallback( List>> agentCallbacks, InvocationContext invocationContext) { - if (agentCallbacks == null || agentCallbacks.isEmpty()) { + if (agentCallbacks.isEmpty()) { return Single.just(Optional.empty()); } @@ -361,20 +443,7 @@ private Single> callCallback( * @return stream of agent-generated events. */ public Flowable runLive(InvocationContext parentContext) { - Tracer tracer = Tracing.getTracer(); - return Flowable.defer( - () -> { - Span span = - tracer - .spanBuilder("agent_run [" + name() + "]") - .setParent(Context.current()) - .startSpan(); - Context spanContext = Context.current().with(span); - - InvocationContext invocationContext = createInvocationContext(parentContext); - - return Tracing.traceFlowable(spanContext, span, () -> runLiveImpl(invocationContext)); - }); + return run(parentContext, this::runLiveImpl); } /** @@ -446,8 +515,7 @@ public B subAgents(List subAgents) { @CanIgnoreReturnValue public B subAgents(BaseAgent... subAgents) { - this.subAgents = ImmutableList.copyOf(subAgents); - return self(); + return subAgents(ImmutableList.copyOf(subAgents)); } @CanIgnoreReturnValue diff --git a/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java b/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java index e38895afb..40ed58937 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgentConfig.java @@ -16,6 +16,7 @@ package com.google.adk.agents; +import com.google.common.collect.ImmutableList; import java.util.List; /** @@ -27,11 +28,11 @@ public class BaseAgentConfig { private String name; private String description = ""; private String agentClass; - private List subAgents; + private ImmutableList subAgents = ImmutableList.of(); // Callback configuration (names resolved via ComponentRegistry) - private List beforeAgentCallbacks; - private List afterAgentCallbacks; + private ImmutableList beforeAgentCallbacks = ImmutableList.of(); + private ImmutableList afterAgentCallbacks = ImmutableList.of(); /** Reference to a callback stored in the ComponentRegistry. */ public static class CallbackRef { @@ -131,27 +132,33 @@ public String agentClass() { return agentClass; } - public List subAgents() { + public ImmutableList subAgents() { return subAgents; } public void setSubAgents(List subAgents) { - this.subAgents = subAgents; + this.subAgents = subAgents == null ? ImmutableList.of() : ImmutableList.copyOf(subAgents); } - public List beforeAgentCallbacks() { + public ImmutableList beforeAgentCallbacks() { return beforeAgentCallbacks; } public void setBeforeAgentCallbacks(List beforeAgentCallbacks) { - this.beforeAgentCallbacks = beforeAgentCallbacks; + this.beforeAgentCallbacks = + beforeAgentCallbacks == null + ? ImmutableList.of() + : ImmutableList.copyOf(beforeAgentCallbacks); } - public List afterAgentCallbacks() { + public ImmutableList afterAgentCallbacks() { return afterAgentCallbacks; } public void setAfterAgentCallbacks(List afterAgentCallbacks) { - this.afterAgentCallbacks = afterAgentCallbacks; + this.afterAgentCallbacks = + afterAgentCallbacks == null + ? ImmutableList.of() + : ImmutableList.copyOf(afterAgentCallbacks); } } diff --git a/core/src/main/java/com/google/adk/agents/CallbackContext.java b/core/src/main/java/com/google/adk/agents/CallbackContext.java index f7bbdcdbe..49298451b 100644 --- a/core/src/main/java/com/google/adk/agents/CallbackContext.java +++ b/core/src/main/java/com/google/adk/agents/CallbackContext.java @@ -31,6 +31,7 @@ public class CallbackContext extends ReadonlyContext { protected EventActions eventActions; private final State state; + private final String eventId; /** * Initializes callback context. @@ -39,9 +40,22 @@ public class CallbackContext extends ReadonlyContext { * @param eventActions Callback event actions. */ public CallbackContext(InvocationContext invocationContext, EventActions eventActions) { + this(invocationContext, eventActions, null); + } + + /** + * Initializes callback context. + * + * @param invocationContext Current invocation context. + * @param eventActions Callback event actions. + * @param eventId The ID of the event associated with this context. + */ + public CallbackContext( + InvocationContext invocationContext, EventActions eventActions, String eventId) { super(invocationContext); this.eventActions = eventActions != null ? eventActions : EventActions.builder().build(); this.state = new State(invocationContext.session().state(), this.eventActions.stateDelta()); + this.eventId = eventId; } /** Returns the delta-aware state of the current callback. */ @@ -55,6 +69,11 @@ public EventActions eventActions() { return eventActions; } + /** Returns the ID of the event associated with this context. */ + public String eventId() { + return eventId; + } + /** * Lists the filenames of the artifacts attached to the current session. * @@ -115,7 +134,7 @@ public Completable saveArtifact(String filename, Part artifact) { invocationContext.session().id(), filename, artifact) - .doOnSuccess(unusedVersion -> this.eventActions.artifactDelta().put(filename, artifact)) + .doOnSuccess(version -> this.eventActions.artifactDelta().put(filename, version)) .ignoreElement(); } } diff --git a/core/src/main/java/com/google/adk/agents/CallbackUtil.java b/core/src/main/java/com/google/adk/agents/CallbackUtil.java index 728fd1d5a..11740ae9c 100644 --- a/core/src/main/java/com/google/adk/agents/CallbackUtil.java +++ b/core/src/main/java/com/google/adk/agents/CallbackUtil.java @@ -26,7 +26,8 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Maybe; import java.util.List; -import org.jspecify.annotations.Nullable; +import java.util.function.Function; +import java.util.stream.Stream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,65 +38,62 @@ public final class CallbackUtil { /** * Normalizes before-agent callbacks. * - * @param beforeAgentCallback Callback list (sync or async). - * @return normalized async callbacks, or null if input is null. + * @param beforeAgentCallbacks Callback list (sync or async). + * @return normalized async callbacks, or empty list if input is null. */ @CanIgnoreReturnValue - public static @Nullable ImmutableList getBeforeAgentCallbacks( - List beforeAgentCallback) { - if (beforeAgentCallback == null) { - return null; - } else if (beforeAgentCallback.isEmpty()) { - return ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (BeforeAgentCallbackBase callback : beforeAgentCallback) { - if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) { - builder.add(beforeAgentCallbackInstance); - } else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) { - builder.add( - (callbackContext) -> - Maybe.fromOptional(beforeAgentCallbackSyncInstance.call(callbackContext))); - } else { - logger.warn( - "Invalid beforeAgentCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - } - return builder.build(); - } + public static ImmutableList getBeforeAgentCallbacks( + List beforeAgentCallbacks) { + return getCallbacks( + beforeAgentCallbacks, + BeforeAgentCallback.class, + BeforeAgentCallbackSync.class, + sync -> (callbackContext -> Maybe.fromOptional(sync.call(callbackContext))), + "beforeAgentCallbacks"); } /** * Normalizes after-agent callbacks. * * @param afterAgentCallback Callback list (sync or async). - * @return normalized async callbacks, or null if input is null. + * @return normalized async callbacks, or empty list if input is null. */ @CanIgnoreReturnValue - public static @Nullable ImmutableList getAfterAgentCallbacks( + public static ImmutableList getAfterAgentCallbacks( List afterAgentCallback) { - if (afterAgentCallback == null) { - return null; - } else if (afterAgentCallback.isEmpty()) { + return getCallbacks( + afterAgentCallback, + AfterAgentCallback.class, + AfterAgentCallbackSync.class, + sync -> (callbackContext -> Maybe.fromOptional(sync.call(callbackContext))), + "afterAgentCallback"); + } + + private static ImmutableList getCallbacks( + List callbacks, + Class asyncClass, + Class syncClass, + Function converter, + String callbackTypeForLogging) { + if (callbacks == null) { return ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (AfterAgentCallbackBase callback : afterAgentCallback) { - if (callback instanceof AfterAgentCallback afterAgentCallbackInstance) { - builder.add(afterAgentCallbackInstance); - } else if (callback instanceof AfterAgentCallbackSync afterAgentCallbackSyncInstance) { - builder.add( - (callbackContext) -> - Maybe.fromOptional(afterAgentCallbackSyncInstance.call(callbackContext))); - } else { - logger.warn( - "Invalid afterAgentCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - } - return builder.build(); } + return callbacks.stream() + .flatMap( + callback -> { + if (asyncClass.isInstance(callback)) { + return Stream.of(asyncClass.cast(callback)); + } else if (syncClass.isInstance(callback)) { + return Stream.of(converter.apply(syncClass.cast(callback))); + } else { + logger.warn( + "Invalid {} callback type: {}. Ignoring this callback.", + callbackTypeForLogging, + callback.getClass().getName()); + return Stream.empty(); + } + }) + .collect(ImmutableList.toImmutableList()); } private CallbackUtil() {} diff --git a/core/src/main/java/com/google/adk/agents/ContextCacheConfig.java b/core/src/main/java/com/google/adk/agents/ContextCacheConfig.java new file mode 100644 index 000000000..084700d54 --- /dev/null +++ b/core/src/main/java/com/google/adk/agents/ContextCacheConfig.java @@ -0,0 +1,59 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.agents; + +import java.time.Duration; + +/** + * Configuration for context caching across all agents in an app. + * + *

This configuration enables and controls context caching behavior for all LLM agents in an app. + * When this config is present on an app, context caching is enabled for all agents. When absent + * (null), context caching is disabled. + * + *

Context caching can significantly reduce costs and improve response times by reusing + * previously processed context across multiple requests. + * + * @param maxInvocations Maximum number of invocations to reuse the same cache before refreshing it. + * Defaults to 10. + * @param ttl Time-to-live for cache. Defaults to 1800 seconds (30 minutes). + * @param minTokens Minimum estimated request tokens required to enable caching. This compares + * against the estimated total tokens of the request (system instruction + tools + contents). + * Context cache storage may have cost. Set higher to avoid caching small requests where + * overhead may exceed benefits. Defaults to 0. + */ +public record ContextCacheConfig(int maxInvocations, Duration ttl, int minTokens) { + + public ContextCacheConfig() { + this(10, Duration.ofSeconds(1800), 0); + } + + /** Returns TTL as string format for cache creation. */ + public String getTtlString() { + return ttl.getSeconds() + "s"; + } + + @Override + public String toString() { + return "ContextCacheConfig(maxInvocations=" + + maxInvocations + + ", ttl=" + + ttl.getSeconds() + + "s, minTokens=" + + minTokens + + ")"; + } +} diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index ace00db4c..6457a8ca4 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -16,20 +16,19 @@ package com.google.adk.agents; -import com.google.adk.apps.ResumabilityConfig; +import static com.google.common.base.Strings.isNullOrEmpty; + import com.google.adk.artifacts.BaseArtifactService; -import com.google.adk.events.Event; import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.Plugin; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; -import com.google.common.collect.ImmutableSet; +import com.google.adk.summarizer.EventsCompactionConfig; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.InlineMe; import com.google.genai.types.Content; -import com.google.genai.types.FunctionCall; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -50,10 +49,10 @@ public class InvocationContext { private final Session session; private final Optional userContent; private final RunConfig runConfig; - private final Map agentStates; - private final Map endOfAgents; - private final ResumabilityConfig resumabilityConfig; + @Nullable private final EventsCompactionConfig eventsCompactionConfig; + @Nullable private final ContextCacheConfig contextCacheConfig; private final InvocationCostManager invocationCostManager; + private final Map callbackContextData; private Optional branch; private BaseAgent agent; @@ -73,10 +72,10 @@ protected InvocationContext(Builder builder) { this.userContent = builder.userContent; this.runConfig = builder.runConfig; this.endInvocation = builder.endInvocation; - this.agentStates = builder.agentStates; - this.endOfAgents = builder.endOfAgents; - this.resumabilityConfig = builder.resumabilityConfig; + this.eventsCompactionConfig = builder.eventsCompactionConfig; + this.contextCacheConfig = builder.contextCacheConfig; this.invocationCostManager = builder.invocationCostManager; + this.callbackContextData = new ConcurrentHashMap<>(builder.callbackContextData); } /** @@ -257,10 +256,7 @@ public String invocationId() { /** * Sets the [branch] ID for the current invocation. A branch represents a fork in the conversation * history. - * - * @deprecated Use {@link #toBuilder()} and {@link Builder#branch(String)} instead. */ - @Deprecated(forRemoval = true) public void branch(@Nullable String branch) { this.branch = Optional.ofNullable(branch); } @@ -303,14 +299,12 @@ public RunConfig runConfig() { return runConfig; } - /** Returns agent-specific state saved within this invocation. */ - public Map agentStates() { - return agentStates; - } - - /** Returns map of agents that ended during this invocation. */ - public Map endOfAgents() { - return endOfAgents; + /** + * Returns a map for storing temporary context data that can be shared between different parts of + * the invocation (e.g., before/on/after model callbacks). + */ + public Map callbackContextData() { + return callbackContextData; } /** @@ -351,26 +345,14 @@ public void incrementLlmCallsCount() throws LlmCallsLimitExceededException { this.invocationCostManager.incrementAndEnforceLlmCallsLimit(this.runConfig); } - /** Returns whether the current invocation is resumable. */ - public boolean isResumable() { - return resumabilityConfig.isResumable(); + /** Returns the events compaction configuration for the current agent run. */ + public Optional eventsCompactionConfig() { + return Optional.ofNullable(eventsCompactionConfig); } - /** Returns whether to pause the invocation right after this [event]. */ - public boolean shouldPauseInvocation(Event event) { - if (!isResumable()) { - return false; - } - - var longRunningToolIds = event.longRunningToolIds().orElse(ImmutableSet.of()); - if (longRunningToolIds.isEmpty()) { - return false; - } - - return event.functionCalls().stream() - .map(FunctionCall::id) - .flatMap(Optional::stream) - .anyMatch(functionCallId -> longRunningToolIds.contains(functionCallId)); + /** Returns the context cache configuration for the current agent run. */ + public Optional contextCacheConfig() { + return Optional.ofNullable(contextCacheConfig); } private static class InvocationCostManager { @@ -424,10 +406,10 @@ private Builder(InvocationContext context) { this.userContent = context.userContent; this.runConfig = context.runConfig; this.endInvocation = context.endInvocation; - this.agentStates = new ConcurrentHashMap<>(context.agentStates); - this.endOfAgents = new ConcurrentHashMap<>(context.endOfAgents); - this.resumabilityConfig = context.resumabilityConfig; + this.eventsCompactionConfig = context.eventsCompactionConfig; + this.contextCacheConfig = context.contextCacheConfig; this.invocationCostManager = context.invocationCostManager; + this.callbackContextData = new ConcurrentHashMap<>(context.callbackContextData); } private BaseSessionService sessionService; @@ -443,10 +425,10 @@ private Builder(InvocationContext context) { private Optional userContent = Optional.empty(); private RunConfig runConfig = RunConfig.builder().build(); private boolean endInvocation = false; - private Map agentStates = new ConcurrentHashMap<>(); - private Map endOfAgents = new ConcurrentHashMap<>(); - private ResumabilityConfig resumabilityConfig = new ResumabilityConfig(); + @Nullable private EventsCompactionConfig eventsCompactionConfig; + @Nullable private ContextCacheConfig contextCacheConfig; private InvocationCostManager invocationCostManager = new InvocationCostManager(); + private Map callbackContextData = new ConcurrentHashMap<>(); /** * Sets the session service for managing session state. @@ -635,38 +617,38 @@ public Builder endInvocation(boolean endInvocation) { } /** - * Sets agent-specific state saved within this invocation. + * Sets the events compaction configuration for the current agent run. * - * @param agentStates agent-specific state saved within this invocation. + * @param eventsCompactionConfig the events compaction configuration. * @return this builder instance for chaining. */ @CanIgnoreReturnValue - public Builder agentStates(Map agentStates) { - this.agentStates = agentStates; + public Builder eventsCompactionConfig(@Nullable EventsCompactionConfig eventsCompactionConfig) { + this.eventsCompactionConfig = eventsCompactionConfig; return this; } /** - * Sets agent end-of-invocation status. + * Sets the context cache configuration for the current agent run. * - * @param endOfAgents agent end-of-invocation status. + * @param contextCacheConfig the context cache configuration. * @return this builder instance for chaining. */ @CanIgnoreReturnValue - public Builder endOfAgents(Map endOfAgents) { - this.endOfAgents = endOfAgents; + public Builder contextCacheConfig(@Nullable ContextCacheConfig contextCacheConfig) { + this.contextCacheConfig = contextCacheConfig; return this; } /** - * Sets the resumability configuration for the current agent run. + * Sets the callback context data for the invocation. * - * @param resumabilityConfig the resumability configuration. + * @param callbackContextData the callback context data. * @return this builder instance for chaining. */ @CanIgnoreReturnValue - public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) { - this.resumabilityConfig = resumabilityConfig; + public Builder callbackContextData(Map callbackContextData) { + this.callbackContextData = callbackContextData; return this; } @@ -675,12 +657,33 @@ public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) { * * @throws IllegalStateException if any required parameters are missing. */ - // TODO: b/462183912 - Add validation for required parameters. public InvocationContext build() { + validate(this); return new InvocationContext(this); } } + /** + * Validates the required parameters fields: invocationId, agent, session, and sessionService. + * + * @param builder the builder to validate. + * @throws IllegalStateException if any required parameters are missing. + */ + private static void validate(Builder builder) { + if (isNullOrEmpty(builder.invocationId)) { + throw new IllegalStateException("Invocation ID must be non-empty."); + } + if (builder.agent == null) { + throw new IllegalStateException("Agent must be set."); + } + if (builder.session == null) { + throw new IllegalStateException("Session must be set."); + } + if (builder.sessionService == null) { + throw new IllegalStateException("Session service must be set."); + } + } + @Override public boolean equals(Object o) { if (this == o) { @@ -702,10 +705,10 @@ public boolean equals(Object o) { && Objects.equals(session, that.session) && Objects.equals(userContent, that.userContent) && Objects.equals(runConfig, that.runConfig) - && Objects.equals(agentStates, that.agentStates) - && Objects.equals(endOfAgents, that.endOfAgents) - && Objects.equals(resumabilityConfig, that.resumabilityConfig) - && Objects.equals(invocationCostManager, that.invocationCostManager); + && Objects.equals(eventsCompactionConfig, that.eventsCompactionConfig) + && Objects.equals(contextCacheConfig, that.contextCacheConfig) + && Objects.equals(invocationCostManager, that.invocationCostManager) + && Objects.equals(callbackContextData, that.callbackContextData); } @Override @@ -724,9 +727,9 @@ public int hashCode() { userContent, runConfig, endInvocation, - agentStates, - endOfAgents, - resumabilityConfig, - invocationCostManager); + eventsCompactionConfig, + contextCacheConfig, + invocationCostManager, + callbackContextData); } } diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 444985971..1893fb162 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -17,6 +17,7 @@ package com.google.adk.agents; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNullElse; import static java.util.stream.Collectors.joining; import com.fasterxml.jackson.core.JsonProcessingException; @@ -61,6 +62,7 @@ import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.Part; import com.google.genai.types.Schema; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; @@ -103,12 +105,12 @@ public enum IncludeContents { private final Optional maxSteps; private final boolean disallowTransferToParent; private final boolean disallowTransferToPeers; - private final Optional> beforeModelCallback; - private final Optional> afterModelCallback; - private final Optional> onModelErrorCallback; - private final Optional> beforeToolCallback; - private final Optional> afterToolCallback; - private final Optional> onToolErrorCallback; + private final ImmutableList beforeModelCallback; + private final ImmutableList afterModelCallback; + private final ImmutableList onModelErrorCallback; + private final ImmutableList beforeToolCallback; + private final ImmutableList afterToolCallback; + private final ImmutableList onToolErrorCallback; private final Optional inputSchema; private final Optional outputSchema; private final Optional executor; @@ -126,29 +128,28 @@ protected LlmAgent(Builder builder) { builder.beforeAgentCallback, builder.afterAgentCallback); this.model = Optional.ofNullable(builder.model); - this.instruction = - builder.instruction == null ? new Instruction.Static("") : builder.instruction; + this.instruction = requireNonNullElse(builder.instruction, new Instruction.Static("")); this.globalInstruction = - builder.globalInstruction == null ? new Instruction.Static("") : builder.globalInstruction; + requireNonNullElse(builder.globalInstruction, new Instruction.Static("")); this.generateContentConfig = Optional.ofNullable(builder.generateContentConfig); this.exampleProvider = Optional.ofNullable(builder.exampleProvider); - this.includeContents = - builder.includeContents != null ? builder.includeContents : IncludeContents.DEFAULT; + this.includeContents = requireNonNullElse(builder.includeContents, IncludeContents.DEFAULT); this.planning = builder.planning != null && builder.planning; this.maxSteps = Optional.ofNullable(builder.maxSteps); this.disallowTransferToParent = builder.disallowTransferToParent; this.disallowTransferToPeers = builder.disallowTransferToPeers; - this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback); - this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback); - this.onModelErrorCallback = Optional.ofNullable(builder.onModelErrorCallback); - this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback); - this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback); - this.onToolErrorCallback = Optional.ofNullable(builder.onToolErrorCallback); + this.beforeModelCallback = requireNonNullElse(builder.beforeModelCallback, ImmutableList.of()); + this.afterModelCallback = requireNonNullElse(builder.afterModelCallback, ImmutableList.of()); + this.onModelErrorCallback = + requireNonNullElse(builder.onModelErrorCallback, ImmutableList.of()); + this.beforeToolCallback = requireNonNullElse(builder.beforeToolCallback, ImmutableList.of()); + this.afterToolCallback = requireNonNullElse(builder.afterToolCallback, ImmutableList.of()); + this.onToolErrorCallback = requireNonNullElse(builder.onToolErrorCallback, ImmutableList.of()); this.inputSchema = Optional.ofNullable(builder.inputSchema); this.outputSchema = Optional.ofNullable(builder.outputSchema); this.executor = Optional.ofNullable(builder.executor); this.outputKey = Optional.ofNullable(builder.outputKey); - this.toolsUnion = builder.toolsUnion != null ? builder.toolsUnion : ImmutableList.of(); + this.toolsUnion = requireNonNullElse(builder.toolsUnion, ImmutableList.of()); this.toolsets = extractToolsets(this.toolsUnion); this.codeExecutor = Optional.ofNullable(builder.codeExecutor); @@ -841,27 +842,27 @@ public boolean disallowTransferToPeers() { return disallowTransferToPeers; } - public Optional> beforeModelCallback() { + public List beforeModelCallback() { return beforeModelCallback; } - public Optional> afterModelCallback() { + public List afterModelCallback() { return afterModelCallback; } - public Optional> beforeToolCallback() { + public List beforeToolCallback() { return beforeToolCallback; } - public Optional> afterToolCallback() { + public List afterToolCallback() { return afterToolCallback; } - public Optional> onModelErrorCallback() { + public List onModelErrorCallback() { return onModelErrorCallback; } - public Optional> onToolErrorCallback() { + public List onToolErrorCallback() { return onToolErrorCallback; } @@ -871,7 +872,7 @@ public Optional> onToolErrorCallback() { *

This method is only for use by Agent Development Kit. */ public List canonicalBeforeModelCallbacks() { - return beforeModelCallback.orElse(ImmutableList.of()); + return beforeModelCallback; } /** @@ -880,7 +881,7 @@ public List canonicalBeforeModelCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalAfterModelCallbacks() { - return afterModelCallback.orElse(ImmutableList.of()); + return afterModelCallback; } /** @@ -889,7 +890,7 @@ public List canonicalAfterModelCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalOnModelErrorCallbacks() { - return onModelErrorCallback.orElse(ImmutableList.of()); + return onModelErrorCallback; } /** @@ -898,7 +899,7 @@ public List canonicalOnModelErrorCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalBeforeToolCallbacks() { - return beforeToolCallback.orElse(ImmutableList.of()); + return beforeToolCallback; } /** @@ -907,7 +908,7 @@ public List canonicalBeforeToolCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalAfterToolCallbacks() { - return afterToolCallback.orElse(ImmutableList.of()); + return afterToolCallback; } /** @@ -916,7 +917,7 @@ public List canonicalAfterToolCallbacks() { *

This method is only for use by Agent Development Kit. */ public List canonicalOnToolErrorCallbacks() { - return onToolErrorCallback.orElse(ImmutableList.of()); + return onToolErrorCallback; } public Optional inputSchema() { @@ -935,9 +936,8 @@ public Optional outputKey() { return outputKey; } - @Nullable - public BaseCodeExecutor codeExecutor() { - return codeExecutor.orElse(null); + public Optional codeExecutor() { + return codeExecutor; } public Model resolvedModel() { @@ -1056,6 +1056,26 @@ public static LlmAgent fromConfig(LlmAgentConfig config, String configAbsPath) return agent; } + @Override + public Completable close() { + List completables = new ArrayList<>(); + toolsets() + .forEach( + toolset -> + completables.add( + Completable.fromAction( + () -> { + try { + toolset.close(); + } catch (Exception e) { + logger.error("Failed to close toolset", e); + throw e; + } + }))); + completables.add(super.close()); + return Completable.mergeDelayError(completables); + } + private static void setCallbacksFromConfig(LlmAgentConfig config, Builder builder) throws ConfigurationException { ConfigAgentUtils.resolveAndSetCallback( diff --git a/core/src/main/java/com/google/adk/agents/RunConfig.java b/core/src/main/java/com/google/adk/agents/RunConfig.java index 308169e36..1ca203eaf 100644 --- a/core/src/main/java/com/google/adk/agents/RunConfig.java +++ b/core/src/main/java/com/google/adk/agents/RunConfig.java @@ -134,6 +134,9 @@ public abstract Builder setInputAudioTranscription( public RunConfig build() { RunConfig runConfig = autoBuild(); + if (runConfig.maxLlmCalls() == Integer.MAX_VALUE) { + throw new IllegalArgumentException("maxLlmCalls should be less than Integer.MAX_VALUE."); + } if (runConfig.maxLlmCalls() < 0) { logger.warn( "maxLlmCalls is negative. This will result in no enforcement on total" diff --git a/core/src/main/java/com/google/adk/apps/App.java b/core/src/main/java/com/google/adk/apps/App.java index 3b1f0613a..18e8753c7 100644 --- a/core/src/main/java/com/google/adk/apps/App.java +++ b/core/src/main/java/com/google/adk/apps/App.java @@ -17,7 +17,8 @@ package com.google.adk.apps; import com.google.adk.agents.BaseAgent; -import com.google.adk.plugins.BasePlugin; +import com.google.adk.agents.ContextCacheConfig; +import com.google.adk.plugins.Plugin; import com.google.adk.summarizer.EventsCompactionConfig; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -38,21 +39,21 @@ public class App { private final String name; private final BaseAgent rootAgent; - private final ImmutableList plugins; + private final ImmutableList plugins; @Nullable private final EventsCompactionConfig eventsCompactionConfig; - @Nullable private final ResumabilityConfig resumabilityConfig; + @Nullable private final ContextCacheConfig contextCacheConfig; private App( String name, BaseAgent rootAgent, - List plugins, + List plugins, @Nullable EventsCompactionConfig eventsCompactionConfig, - @Nullable ResumabilityConfig resumabilityConfig) { + @Nullable ContextCacheConfig contextCacheConfig) { this.name = name; this.rootAgent = rootAgent; this.plugins = ImmutableList.copyOf(plugins); this.eventsCompactionConfig = eventsCompactionConfig; - this.resumabilityConfig = resumabilityConfig; + this.contextCacheConfig = contextCacheConfig; } public String name() { @@ -63,7 +64,7 @@ public BaseAgent rootAgent() { return rootAgent; } - public ImmutableList plugins() { + public ImmutableList plugins() { return plugins; } @@ -73,17 +74,17 @@ public EventsCompactionConfig eventsCompactionConfig() { } @Nullable - public ResumabilityConfig resumabilityConfig() { - return resumabilityConfig; + public ContextCacheConfig contextCacheConfig() { + return contextCacheConfig; } /** Builder for {@link App}. */ public static class Builder { private String name; private BaseAgent rootAgent; - private List plugins = ImmutableList.of(); + private List plugins = ImmutableList.of(); @Nullable private EventsCompactionConfig eventsCompactionConfig; - @Nullable private ResumabilityConfig resumabilityConfig; + @Nullable private ContextCacheConfig contextCacheConfig; @CanIgnoreReturnValue public Builder name(String name) { @@ -98,7 +99,7 @@ public Builder rootAgent(BaseAgent rootAgent) { } @CanIgnoreReturnValue - public Builder plugins(List plugins) { + public Builder plugins(List plugins) { this.plugins = plugins; return this; } @@ -110,8 +111,8 @@ public Builder eventsCompactionConfig(EventsCompactionConfig eventsCompactionCon } @CanIgnoreReturnValue - public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) { - this.resumabilityConfig = resumabilityConfig; + public Builder contextCacheConfig(ContextCacheConfig contextCacheConfig) { + this.contextCacheConfig = contextCacheConfig; return this; } @@ -123,7 +124,7 @@ public App build() { throw new IllegalStateException("Root agent must be provided."); } validateAppName(name); - return new App(name, rootAgent, plugins, eventsCompactionConfig, resumabilityConfig); + return new App(name, rootAgent, plugins, eventsCompactionConfig, contextCacheConfig); } } diff --git a/core/src/main/java/com/google/adk/apps/ResumabilityConfig.java b/core/src/main/java/com/google/adk/apps/ResumabilityConfig.java deleted file mode 100644 index b80ce709c..000000000 --- a/core/src/main/java/com/google/adk/apps/ResumabilityConfig.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS language governing permissions and - * limitations under the License. - */ -package com.google.adk.apps; - -/** - * An app contains Resumability configuration for the agents. - * - * @param isResumable Whether the app is resumable. - */ -public record ResumabilityConfig(boolean isResumable) { - - /** Creates a new {@code ResumabilityConfig} with resumability disabled. */ - public ResumabilityConfig() { - this(false); - } -} diff --git a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java index 847e88dd9..b6a3cee23 100644 --- a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java @@ -39,6 +39,28 @@ public interface BaseArtifactService { Single saveArtifact( String appName, String userId, String sessionId, String filename, Part artifact); + /** + * Saves an artifact and returns it with fileData if available. + * + *

Implementations should override this default method for efficiency, as the default performs + * two I/O operations (save then load). + * + * @param appName the app name + * @param userId the user ID + * @param sessionId the session ID + * @param filename the filename + * @param artifact the artifact to save + * @return the saved artifact with fileData if available. + */ + default Single saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifact(appName, userId, sessionId, filename, artifact) + .flatMap( + version -> + loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) + .toSingle()); + } + /** * Gets an artifact. * diff --git a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java index 1bfef8cf8..b9bc49a02 100644 --- a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java @@ -18,6 +18,7 @@ import static java.util.Collections.max; +import com.google.auto.value.AutoValue; import com.google.cloud.storage.Blob; import com.google.cloud.storage.BlobId; import com.google.cloud.storage.BlobInfo; @@ -27,6 +28,7 @@ import com.google.common.base.Splitter; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import com.google.genai.types.FileData; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; @@ -108,34 +110,8 @@ private String getBlobName( @Override public Single saveArtifact( String appName, String userId, String sessionId, String filename, Part artifact) { - return listVersions(appName, userId, sessionId, filename) - .map(versions -> versions.isEmpty() ? 0 : max(versions) + 1) - .map( - nextVersion -> { - String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion); - BlobId blobId = BlobId.of(bucketName, blobName); - - BlobInfo blobInfo = - BlobInfo.newBuilder(blobId) - .setContentType(artifact.inlineData().get().mimeType().orElse(null)) - .build(); - - try { - byte[] dataToSave = - artifact - .inlineData() - .get() - .data() - .orElseThrow( - () -> - new IllegalArgumentException( - "Saveable artifact data must be non-empty.")); - storageClient.create(blobInfo, dataToSave); - return nextVersion; - } catch (StorageException e) { - throw new VerifyException("Failed to save artifact to GCS", e); - } - }); + return saveArtifactAndReturnBlob(appName, userId, sessionId, filename, artifact) + .map(SaveResult::version); } /** @@ -275,4 +251,75 @@ public Single> listVersions( return Single.just(ImmutableList.of()); } } + + @Override + public Single saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifactAndReturnBlob(appName, userId, sessionId, filename, artifact) + .flatMap( + blob -> { + Blob savedBlob = blob.blob(); + String resultMimeType = + Optional.ofNullable(savedBlob.getContentType()) + .or( + () -> + artifact.inlineData().flatMap(com.google.genai.types.Blob::mimeType)) + .orElse("application/octet-stream"); + return Single.just( + Part.builder() + .fileData( + FileData.builder() + .fileUri("gs://" + savedBlob.getBucket() + "/" + savedBlob.getName()) + .mimeType(resultMimeType) + .build()) + .build()); + }); + } + + @AutoValue + abstract static class SaveResult { + static SaveResult create(Blob blob, int version) { + return new AutoValue_GcsArtifactService_SaveResult(blob, version); + } + + abstract Blob blob(); + + abstract int version(); + } + + private Single saveArtifactAndReturnBlob( + String appName, String userId, String sessionId, String filename, Part artifact) { + return listVersions(appName, userId, sessionId, filename) + .map(versions -> versions.isEmpty() ? 0 : max(versions) + 1) + .map( + nextVersion -> { + if (artifact.inlineData().isEmpty()) { + throw new IllegalArgumentException("Saveable artifact must have inline data."); + } + + String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion); + BlobId blobId = BlobId.of(bucketName, blobName); + + BlobInfo blobInfo = + BlobInfo.newBuilder(blobId) + .setContentType(artifact.inlineData().get().mimeType().orElse(null)) + .build(); + + try { + byte[] dataToSave = + artifact + .inlineData() + .get() + .data() + .orElseThrow( + () -> + new IllegalArgumentException( + "Saveable artifact data must be non-empty.")); + Blob blob = storageClient.create(blobInfo, dataToSave); + return SaveResult.create(blob, nextVersion); + } catch (StorageException e) { + throw new VerifyException("Failed to save artifact to GCS", e); + } + }); + } } diff --git a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java index 890820196..5808f7083 100644 --- a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java @@ -48,11 +48,8 @@ public InMemoryArtifactService() { public Single saveArtifact( String appName, String userId, String sessionId, String filename, Part artifact) { List versions = - artifacts - .computeIfAbsent(appName, k -> new HashMap<>()) - .computeIfAbsent(userId, k -> new HashMap<>()) - .computeIfAbsent(sessionId, k -> new HashMap<>()) - .computeIfAbsent(filename, k -> new ArrayList<>()); + getArtifactsMap(appName, userId, sessionId) + .computeIfAbsent(filename, unused -> new ArrayList<>()); versions.add(artifact); return Single.just(versions.size() - 1); } @@ -66,11 +63,8 @@ public Single saveArtifact( public Maybe loadArtifact( String appName, String userId, String sessionId, String filename, Optional version) { List versions = - artifacts - .getOrDefault(appName, new HashMap<>()) - .getOrDefault(userId, new HashMap<>()) - .getOrDefault(sessionId, new HashMap<>()) - .getOrDefault(filename, new ArrayList<>()); + getArtifactsMap(appName, userId, sessionId) + .computeIfAbsent(filename, unused -> new ArrayList<>()); if (versions.isEmpty()) { return Maybe.empty(); @@ -97,13 +91,7 @@ public Single listArtifactKeys( String appName, String userId, String sessionId) { return Single.just( ListArtifactsResponse.builder() - .filenames( - ImmutableList.copyOf( - artifacts - .getOrDefault(appName, new HashMap<>()) - .getOrDefault(userId, new HashMap<>()) - .getOrDefault(sessionId, new HashMap<>()) - .keySet())) + .filenames(ImmutableList.copyOf(getArtifactsMap(appName, userId, sessionId).keySet())) .build()); } @@ -115,11 +103,7 @@ public Single listArtifactKeys( @Override public Completable deleteArtifact( String appName, String userId, String sessionId, String filename) { - artifacts - .getOrDefault(appName, new HashMap<>()) - .getOrDefault(userId, new HashMap<>()) - .getOrDefault(sessionId, new HashMap<>()) - .remove(filename); + getArtifactsMap(appName, userId, sessionId).remove(filename); return Completable.complete(); } @@ -132,15 +116,29 @@ public Completable deleteArtifact( public Single> listVersions( String appName, String userId, String sessionId, String filename) { int size = - artifacts - .getOrDefault(appName, new HashMap<>()) - .getOrDefault(userId, new HashMap<>()) - .getOrDefault(sessionId, new HashMap<>()) - .getOrDefault(filename, new ArrayList<>()) + getArtifactsMap(appName, userId, sessionId) + .computeIfAbsent(filename, unused -> new ArrayList<>()) .size(); if (size == 0) { return Single.just(ImmutableList.of()); } return Single.just(IntStream.range(0, size).boxed().collect(toImmutableList())); } + + @Override + public Single saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifact(appName, userId, sessionId, filename, artifact) + .flatMap( + version -> + loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) + .toSingle()); + } + + private Map> getArtifactsMap(String appName, String userId, String sessionId) { + return artifacts + .computeIfAbsent(appName, unused -> new HashMap<>()) + .computeIfAbsent(userId, unused -> new HashMap<>()) + .computeIfAbsent(sessionId, unused -> new HashMap<>()); + } } diff --git a/core/src/main/java/com/google/adk/codeexecutors/CodeExecutorContext.java b/core/src/main/java/com/google/adk/codeexecutors/CodeExecutorContext.java index bad83ebc0..a34102225 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/CodeExecutorContext.java +++ b/core/src/main/java/com/google/adk/codeexecutors/CodeExecutorContext.java @@ -89,7 +89,8 @@ public void setExecutionId(String sessionId) { * @return A list of processed file names in the code executor context. */ public List getProcessedFileNames() { - return (List) this.context.getOrDefault(PROCESSED_FILE_NAMES_KEY, new ArrayList<>()); + return (List) + this.context.computeIfAbsent(PROCESSED_FILE_NAMES_KEY, unused -> new ArrayList<>()); } /** @@ -100,7 +101,7 @@ public List getProcessedFileNames() { public void addProcessedFileNames(List fileNames) { List processedFileNames = (List) - this.context.computeIfAbsent(PROCESSED_FILE_NAMES_KEY, k -> new ArrayList<>()); + this.context.computeIfAbsent(PROCESSED_FILE_NAMES_KEY, unused -> new ArrayList<>()); processedFileNames.addAll(fileNames); } @@ -126,7 +127,7 @@ public List getInputFiles() { public void addInputFiles(List inputFiles) { List> fileMaps = (List>) - this.sessionState.computeIfAbsent(INPUT_FILE_KEY, k -> new ArrayList<>()); + this.sessionState.computeIfAbsent(INPUT_FILE_KEY, unused -> new ArrayList<>()); for (File inputFile : inputFiles) { fileMaps.add( objectMapper.convertValue(inputFile, new TypeReference>() {})); @@ -166,7 +167,7 @@ public int getErrorCount(String invocationId) { public void incrementErrorCount(String invocationId) { Map errorCounts = (Map) - this.sessionState.computeIfAbsent(ERROR_COUNT_KEY, k -> new HashMap<>()); + this.sessionState.computeIfAbsent(ERROR_COUNT_KEY, unused -> new HashMap<>()); errorCounts.put(invocationId, getErrorCount(invocationId) + 1); } @@ -176,9 +177,6 @@ public void incrementErrorCount(String invocationId) { * @param invocationId The invocation ID to reset the error count for. */ public void resetErrorCount(String invocationId) { - if (!this.sessionState.containsKey(ERROR_COUNT_KEY)) { - return; - } Map errorCounts = (Map) this.sessionState.get(ERROR_COUNT_KEY); if (errorCounts != null) { @@ -198,9 +196,10 @@ public void updateCodeExecutionResult( String invocationId, String code, String resultStdout, String resultStderr) { Map>> codeExecutionResults = (Map>>) - this.sessionState.computeIfAbsent(CODE_EXECUTION_RESULTS_KEY, k -> new HashMap<>()); + this.sessionState.computeIfAbsent( + CODE_EXECUTION_RESULTS_KEY, unused -> new HashMap<>()); List> resultsForInvocation = - codeExecutionResults.computeIfAbsent(invocationId, k -> new ArrayList<>()); + codeExecutionResults.computeIfAbsent(invocationId, unused -> new ArrayList<>()); Map newResult = new HashMap<>(); newResult.put("code", code); newResult.put("result_stdout", resultStdout); @@ -210,6 +209,7 @@ public void updateCodeExecutionResult( } private Map getCodeExecutorContext(Map sessionState) { - return (Map) sessionState.computeIfAbsent(CONTEXT_KEY, k -> new HashMap<>()); + return (Map) + sessionState.computeIfAbsent(CONTEXT_KEY, unused -> new HashMap<>()); } } diff --git a/core/src/main/java/com/google/adk/events/Event.java b/core/src/main/java/com/google/adk/events/Event.java index 9e05918be..d968efa53 100644 --- a/core/src/main/java/com/google/adk/events/Event.java +++ b/core/src/main/java/com/google/adk/events/Event.java @@ -294,8 +294,7 @@ public final boolean hasTrailingCodeExecutionResult() { /** Returns true if this is a final response. */ @JsonIgnore public final boolean finalResponse() { - if (actions().skipSummarization().orElse(false) - || (longRunningToolIds().isPresent() && !longRunningToolIds().get().isEmpty())) { + if (actions().skipSummarization().orElse(false)) { return true; } return functionCalls().isEmpty() diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 63909ee1a..6d8c698dd 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -18,11 +18,13 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.google.adk.agents.BaseAgentState; +import com.google.adk.JsonBaseModel; +import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; -import com.google.genai.types.Part; +import java.util.HashSet; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import javax.annotation.Nullable; @@ -30,50 +32,44 @@ /** Represents the actions attached to an event. */ // TODO - b/414081262 make json wire camelCase @JsonDeserialize(builder = EventActions.Builder.class) -public class EventActions { +public class EventActions extends JsonBaseModel { private Optional skipSummarization; private ConcurrentMap stateDelta; - private ConcurrentMap artifactDelta; + private ConcurrentMap artifactDelta; + private Set deletedArtifactIds; private Optional transferToAgent; private Optional escalate; private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent; - private ConcurrentMap agentState; - private Optional endInvocation; private Optional compaction; - private Optional rewindBeforeInvocationId; /** Default constructor for Jackson. */ public EventActions() { this.skipSummarization = Optional.empty(); this.stateDelta = new ConcurrentHashMap<>(); this.artifactDelta = new ConcurrentHashMap<>(); + this.deletedArtifactIds = new HashSet<>(); this.transferToAgent = Optional.empty(); this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); this.requestedToolConfirmations = new ConcurrentHashMap<>(); this.endOfAgent = false; - this.endInvocation = Optional.empty(); this.compaction = Optional.empty(); - this.agentState = new ConcurrentHashMap<>(); - this.rewindBeforeInvocationId = Optional.empty(); } private EventActions(Builder builder) { this.skipSummarization = builder.skipSummarization; this.stateDelta = builder.stateDelta; this.artifactDelta = builder.artifactDelta; + this.deletedArtifactIds = builder.deletedArtifactIds; this.transferToAgent = builder.transferToAgent; this.escalate = builder.escalate; this.requestedAuthConfigs = builder.requestedAuthConfigs; this.requestedToolConfirmations = builder.requestedToolConfirmations; this.endOfAgent = builder.endOfAgent; - this.endInvocation = builder.endInvocation; this.compaction = builder.compaction; - this.agentState = builder.agentState; - this.rewindBeforeInvocationId = builder.rewindBeforeInvocationId; } @JsonProperty("skipSummarization") @@ -98,19 +94,39 @@ public ConcurrentMap stateDelta() { return stateDelta; } + @Deprecated // Use stateDelta(), addState() and removeStateByKey() instead. public void setStateDelta(ConcurrentMap stateDelta) { this.stateDelta = stateDelta; } + /** + * Removes a key from the state delta. + * + * @param key The key to remove. + */ + public void removeStateByKey(String key) { + stateDelta.put(key, State.REMOVED); + } + @JsonProperty("artifactDelta") - public ConcurrentMap artifactDelta() { + public ConcurrentMap artifactDelta() { return artifactDelta; } - public void setArtifactDelta(ConcurrentMap artifactDelta) { + public void setArtifactDelta(ConcurrentMap artifactDelta) { this.artifactDelta = artifactDelta; } + @JsonProperty("deletedArtifactIds") + @JsonInclude(JsonInclude.Include.NON_EMPTY) + public Set deletedArtifactIds() { + return deletedArtifactIds; + } + + public void setDeletedArtifactIds(Set deletedArtifactIds) { + this.deletedArtifactIds = deletedArtifactIds; + } + @JsonProperty("transferToAgent") public Optional transferToAgent() { return transferToAgent; @@ -167,17 +183,28 @@ public void setEndOfAgent(boolean endOfAgent) { this.endOfAgent = endOfAgent; } - @JsonProperty("endInvocation") + /** + * @deprecated Use {@link #endOfAgent()} instead. + */ + @Deprecated public Optional endInvocation() { - return endInvocation; + return endOfAgent ? Optional.of(true) : Optional.empty(); } + /** + * @deprecated Use {@link #setEndOfAgent(boolean)} instead. + */ + @Deprecated public void setEndInvocation(Optional endInvocation) { - this.endInvocation = endInvocation; + this.endOfAgent = endInvocation.orElse(false); } + /** + * @deprecated Use {@link #setEndOfAgent(boolean)} instead. + */ + @Deprecated public void setEndInvocation(boolean endInvocation) { - this.endInvocation = Optional.of(endInvocation); + this.endOfAgent = endInvocation; } @JsonProperty("compaction") @@ -189,25 +216,6 @@ public void setCompaction(Optional compaction) { this.compaction = compaction; } - @JsonProperty("agentState") - @JsonInclude(JsonInclude.Include.NON_EMPTY) - public ConcurrentMap agentState() { - return agentState; - } - - public void setAgentState(ConcurrentMap agentState) { - this.agentState = agentState; - } - - @JsonProperty("rewindBeforeInvocationId") - public Optional rewindBeforeInvocationId() { - return rewindBeforeInvocationId; - } - - public void setRewindBeforeInvocationId(@Nullable String rewindBeforeInvocationId) { - this.rewindBeforeInvocationId = Optional.ofNullable(rewindBeforeInvocationId); - } - public static Builder builder() { return new Builder(); } @@ -227,15 +235,13 @@ public boolean equals(Object o) { return Objects.equals(skipSummarization, that.skipSummarization) && Objects.equals(stateDelta, that.stateDelta) && Objects.equals(artifactDelta, that.artifactDelta) + && Objects.equals(deletedArtifactIds, that.deletedArtifactIds) && Objects.equals(transferToAgent, that.transferToAgent) && Objects.equals(escalate, that.escalate) && Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs) && Objects.equals(requestedToolConfirmations, that.requestedToolConfirmations) && (endOfAgent == that.endOfAgent) - && Objects.equals(endInvocation, that.endInvocation) - && Objects.equals(compaction, that.compaction) - && Objects.equals(agentState, that.agentState) - && Objects.equals(rewindBeforeInvocationId, that.rewindBeforeInvocationId); + && Objects.equals(compaction, that.compaction); } @Override @@ -244,60 +250,52 @@ public int hashCode() { skipSummarization, stateDelta, artifactDelta, + deletedArtifactIds, transferToAgent, escalate, requestedAuthConfigs, requestedToolConfirmations, endOfAgent, - endInvocation, - compaction, - agentState, - rewindBeforeInvocationId); + compaction); } /** Builder for {@link EventActions}. */ public static class Builder { private Optional skipSummarization; private ConcurrentMap stateDelta; - private ConcurrentMap artifactDelta; + private ConcurrentMap artifactDelta; + private Set deletedArtifactIds; private Optional transferToAgent; private Optional escalate; private ConcurrentMap> requestedAuthConfigs; private ConcurrentMap requestedToolConfirmations; private boolean endOfAgent = false; - private Optional endInvocation; private Optional compaction; - private ConcurrentMap agentState; - private Optional rewindBeforeInvocationId; public Builder() { this.skipSummarization = Optional.empty(); this.stateDelta = new ConcurrentHashMap<>(); this.artifactDelta = new ConcurrentHashMap<>(); + this.deletedArtifactIds = new HashSet<>(); this.transferToAgent = Optional.empty(); this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); this.requestedToolConfirmations = new ConcurrentHashMap<>(); - this.endInvocation = Optional.empty(); this.compaction = Optional.empty(); - this.agentState = new ConcurrentHashMap<>(); - this.rewindBeforeInvocationId = Optional.empty(); } private Builder(EventActions eventActions) { this.skipSummarization = eventActions.skipSummarization(); this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta()); this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta()); + this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds()); this.transferToAgent = eventActions.transferToAgent(); this.escalate = eventActions.escalate(); this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs()); this.requestedToolConfirmations = new ConcurrentHashMap<>(eventActions.requestedToolConfirmations()); this.endOfAgent = eventActions.endOfAgent(); - this.endInvocation = eventActions.endInvocation(); this.compaction = eventActions.compaction(); - this.agentState = new ConcurrentHashMap<>(eventActions.agentState()); - this.rewindBeforeInvocationId = eventActions.rewindBeforeInvocationId(); } @CanIgnoreReturnValue @@ -316,11 +314,18 @@ public Builder stateDelta(ConcurrentMap value) { @CanIgnoreReturnValue @JsonProperty("artifactDelta") - public Builder artifactDelta(ConcurrentMap value) { + public Builder artifactDelta(ConcurrentMap value) { this.artifactDelta = value; return this; } + @CanIgnoreReturnValue + @JsonProperty("deletedArtifactIds") + public Builder deletedArtifactIds(Set value) { + this.deletedArtifactIds = value; + return this; + } + @CanIgnoreReturnValue @JsonProperty("transferToAgent") public Builder transferToAgent(String agentId) { @@ -357,10 +362,14 @@ public Builder endOfAgent(boolean endOfAgent) { return this; } + /** + * @deprecated Use {@link #endOfAgent(boolean)} instead. + */ @CanIgnoreReturnValue @JsonProperty("endInvocation") + @Deprecated public Builder endInvocation(boolean endInvocation) { - this.endInvocation = Optional.of(endInvocation); + this.endOfAgent = endInvocation; return this; } @@ -371,34 +380,18 @@ public Builder compaction(EventCompaction value) { return this; } - @CanIgnoreReturnValue - @JsonProperty("agentState") - public Builder agentState(ConcurrentMap agentState) { - this.agentState = agentState; - return this; - } - - @CanIgnoreReturnValue - @JsonProperty("rewindBeforeInvocationId") - public Builder rewindBeforeInvocationId(String rewindBeforeInvocationId) { - this.rewindBeforeInvocationId = Optional.ofNullable(rewindBeforeInvocationId); - return this; - } - @CanIgnoreReturnValue public Builder merge(EventActions other) { other.skipSummarization().ifPresent(this::skipSummarization); this.stateDelta.putAll(other.stateDelta()); this.artifactDelta.putAll(other.artifactDelta()); + this.deletedArtifactIds.addAll(other.deletedArtifactIds()); other.transferToAgent().ifPresent(this::transferToAgent); other.escalate().ifPresent(this::escalate); this.requestedAuthConfigs.putAll(other.requestedAuthConfigs()); this.requestedToolConfirmations.putAll(other.requestedToolConfirmations()); this.endOfAgent = other.endOfAgent(); - other.endInvocation().ifPresent(this::endInvocation); other.compaction().ifPresent(this::compaction); - this.agentState.putAll(other.agentState()); - other.rewindBeforeInvocationId().ifPresent(this::rewindBeforeInvocationId); return this; } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 46b3f1952..6ca49ee62 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -144,11 +144,15 @@ protected Flowable postprocess( }) .map(ResponseProcessingResult::updatedResponse); } + Context parentContext = Context.current(); return currentLlmResponse.flatMapPublisher( - updatedResponse -> - buildPostprocessingEvents( - updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest)); + updatedResponse -> { + try (Scope scope = parentContext.makeCurrent()) { + return buildPostprocessingEvents( + updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); + } + }); } /** @@ -175,45 +179,29 @@ private Flowable callLlm( agent.resolvedModel().model().isPresent() ? agent.resolvedModel().model().get() : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return Flowable.defer( - () -> { - Span llmCallSpan = - Tracing.getTracer() - .spanBuilder("call_llm") - .setParent(Context.current()) - .startSpan(); - - try (Scope scope = llmCallSpan.makeCurrent()) { - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, - llmRequestBuilder, - eventForCallbackUsage, - exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnNext( - llmResp -> { - try (Scope innerScope = llmCallSpan.makeCurrent()) { - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp); - } - }) - .doOnError( - error -> { - llmCallSpan.setStatus(StatusCode.ERROR, error.getMessage()); - llmCallSpan.recordException(error); - }) - .doFinally(llmCallSpan::end); - } + return llm.generateContent( + llmRequestBuilder.build(), + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, llmRequestBuilder, eventForCallbackUsage, exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnNext( + llmResp -> + Tracing.traceCallLlm( + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp)) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); }) + .compose(Tracing.trace("call_llm")) .concatMap( llmResp -> handleAfterModelCallback(context, llmResp, eventForCallbackUsage) @@ -230,7 +218,8 @@ private Flowable callLlm( private Single> handleBeforeModelCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); Maybe pluginResult = context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); @@ -267,7 +256,8 @@ private Maybe handleOnModelErrorCallback( Event modelResponseEvent, Throwable throwable) { Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); Exception ex = throwable instanceof Exception e ? e : new Exception(throwable); Maybe pluginResult = @@ -301,7 +291,8 @@ private Maybe handleOnModelErrorCallback( private Single handleAfterModelCallback( InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); Maybe pluginResult = context.pluginManager().afterModelCallback(callbackContext, llmResponse); @@ -334,73 +325,78 @@ private Single handleAfterModelCallback( */ private Flowable runOneStep(InvocationContext context) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); - Flowable preprocessEvents = preprocess(context, llmRequestRef); - - return preprocessEvents.concatWith( - Flowable.defer( - () -> { - LlmRequest llmRequestAfterPreprocess = llmRequestRef.get(); - if (context.endInvocation()) { - logger.debug("End invocation requested during preprocessing."); - return Flowable.empty(); - } - try { - context.incrementLlmCallsCount(); - } catch (LlmCallsLimitExceededException e) { - logger.error("LLM calls limit exceeded.", e); - return Flowable.error(e); - } + return Flowable.defer( + () -> { + Context currentContext = Context.current(); + return preprocess(context, llmRequestRef) + .concatWith( + Flowable.defer( + () -> { + LlmRequest llmRequestAfterPreprocess = llmRequestRef.get(); + if (context.endInvocation()) { + logger.debug("End invocation requested during preprocessing."); + return Flowable.empty(); + } - final Event mutableEventTemplate = - Event.builder() - .id(Event.generateEventId()) - .invocationId(context.invocationId()) - .author(context.agent().name()) - .branch(context.branch()) - .build(); - // Explicitly set the event timestamp to 0 so the postprocessing logic would generate - // events with fresh timestamp. - mutableEventTemplate.setTimestamp(0L); - - return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) - .concatMap( - llmResponse -> - postprocess( - context, - mutableEventTemplate, - llmRequestAfterPreprocess, - llmResponse) - .doFinally( - () -> { - String oldId = mutableEventTemplate.id(); - mutableEventTemplate.setId(Event.generateEventId()); - logger.debug( - "Updated mutableEventTemplate ID from {} to {} for" - + " next LlmResponse", - oldId, - mutableEventTemplate.id()); - })) - .concatMap( - event -> { - Flowable postProcessedEvents = Flowable.just(event); - if (event.actions().transferToAgent().isPresent()) { - String agentToTransfer = event.actions().transferToAgent().get(); - logger.debug("Transferring to agent: {}", agentToTransfer); - BaseAgent rootAgent = context.agent().rootAgent(); - BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer); - if (nextAgent == null) { - String errorMsg = "Agent not found for transfer: " + agentToTransfer; - logger.error(errorMsg); - return postProcessedEvents.concatWith( - Flowable.error(new IllegalStateException(errorMsg))); - } - return postProcessedEvents.concatWith( - Flowable.defer(() -> nextAgent.runAsync(context))); + try { + context.incrementLlmCallsCount(); + } catch (LlmCallsLimitExceededException e) { + logger.error("LLM calls limit exceeded.", e); + return Flowable.error(e); } - return postProcessedEvents; - }); - })); + + final Event mutableEventTemplate = + Event.builder() + .id(Event.generateEventId()) + .invocationId(context.invocationId()) + .author(context.agent().name()) + .branch(context.branch()) + .build(); + mutableEventTemplate.setTimestamp(0L); + + return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) + .concatMap( + llmResponse -> { + try (Scope postScope = currentContext.makeCurrent()) { + return postprocess( + context, + mutableEventTemplate, + llmRequestAfterPreprocess, + llmResponse) + .doFinally( + () -> { + String oldId = mutableEventTemplate.id(); + String newId = Event.generateEventId(); + logger.debug( + "Resetting event ID from {} to {}", oldId, newId); + mutableEventTemplate.setId(newId); + }); + } + }) + .concatMap( + event -> { + Flowable postProcessedEvents = Flowable.just(event); + if (event.actions().transferToAgent().isPresent()) { + String agentToTransfer = + event.actions().transferToAgent().get(); + BaseAgent rootAgent = context.agent().rootAgent(); + Optional nextAgent = + rootAgent.findAgent(agentToTransfer); + if (nextAgent.isEmpty()) { + logger.error("Agent not found: {}", agentToTransfer); + return postProcessedEvents.concatWith( + Flowable.error( + new IllegalStateException( + "Agent not found: " + agentToTransfer))); + } + return postProcessedEvents.concatWith( + Flowable.defer(() -> nextAgent.get().runAsync(context))); + } + return postProcessedEvents; + }); + })); + }); } /** @@ -435,7 +431,7 @@ private Flowable run(InvocationContext invocationContext, int stepsComple return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return Flowable.defer(() -> run(invocationContext, stepsCompleted + 1)); + return run(invocationContext, stepsCompleted + 1); } })); } @@ -471,40 +467,25 @@ public Flowable runLive(InvocationContext invocationContext) { Completable historySent = llmRequestAfterPreprocess.contents().isEmpty() ? Completable.complete() - : Completable.defer( - () -> { - Span sendDataSpan = - Tracing.getTracer() - .spanBuilder("send_data") - .setParent(Context.current()) - .startSpan(); - try (Scope scope = sendDataSpan.makeCurrent()) { - return connection - .sendHistory(llmRequestAfterPreprocess.contents()) - .doOnComplete( - () -> { - try (Scope innerScope = sendDataSpan.makeCurrent()) { - Tracing.traceSendData( - invocationContext, - eventIdForSendData, - llmRequestAfterPreprocess.contents()); - } - }) - .doOnError( - error -> { - sendDataSpan.setStatus( - StatusCode.ERROR, error.getMessage()); - sendDataSpan.recordException(error); - try (Scope innerScope = sendDataSpan.makeCurrent()) { - Tracing.traceSendData( - invocationContext, - eventIdForSendData, - llmRequestAfterPreprocess.contents()); - } - }) - .doFinally(sendDataSpan::end); - } - }); + : connection + .sendHistory(llmRequestAfterPreprocess.contents()) + .doOnComplete( + () -> + Tracing.traceSendData( + invocationContext, + eventIdForSendData, + llmRequestAfterPreprocess.contents())) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + Tracing.traceSendData( + invocationContext, + eventIdForSendData, + llmRequestAfterPreprocess.contents()); + }) + .compose(Tracing.trace("send_data")); Flowable liveRequests = invocationContext @@ -574,14 +555,14 @@ public void onError(Throwable e) { Flowable events = Flowable.just(event); if (event.actions().transferToAgent().isPresent()) { BaseAgent rootAgent = invocationContext.agent().rootAgent(); - BaseAgent nextAgent = + Optional nextAgent = rootAgent.findAgent(event.actions().transferToAgent().get()); - if (nextAgent == null) { + if (nextAgent.isEmpty()) { throw new IllegalStateException( "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = - nextAgent.runLive(invocationContext); + nextAgent.get().runLive(invocationContext); events = Flowable.concat(events, nextAgentEvents); } return events; @@ -677,13 +658,13 @@ private Event buildModelResponseEvent( Event event = eventBuilder.build(); - logger.info("event: {} functionCalls: {}", event, event.functionCalls()); + logger.debug("event: {} functionCalls: {}", event, event.functionCalls()); if (!event.functionCalls().isEmpty()) { Functions.populateClientFunctionCallId(event); Set longRunningToolIds = Functions.getLongRunningFunctionCalls(event.functionCalls(), llmRequest.tools()); - logger.info("longRunningToolIds: {}", longRunningToolIds); + logger.debug("longRunningToolIds: {}", longRunningToolIds); if (!longRunningToolIds.isEmpty()) { event.setLongRunningToolIds(Optional.of(longRunningToolIds)); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java index 64d95cef3..1f99cf4a2 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/CodeExecution.java @@ -22,7 +22,6 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; -import com.google.adk.codeexecutors.BaseCodeExecutor; import com.google.adk.codeexecutors.BuiltInCodeExecutor; import com.google.adk.codeexecutors.CodeExecutionUtils; import com.google.adk.codeexecutors.CodeExecutionUtils.CodeExecutionInput; @@ -108,12 +107,12 @@ private static class CodeExecutionRequestProcessor implements RequestProcessor { public Single processRequest( InvocationContext invocationContext, LlmRequest llmRequest) { if (!(invocationContext.agent() instanceof LlmAgent llmAgent) - || llmAgent.codeExecutor() == null) { + || llmAgent.codeExecutor().isEmpty()) { return Single.just( RequestProcessor.RequestProcessingResult.create(llmRequest, ImmutableList.of())); } - if (llmAgent.codeExecutor() instanceof BuiltInCodeExecutor builtInCodeExecutor) { + if (llmAgent.codeExecutor().get() instanceof BuiltInCodeExecutor builtInCodeExecutor) { var llmRequestBuilder = llmRequest.toBuilder(); builtInCodeExecutor.processLlmRequest(llmRequestBuilder); LlmRequest updatedLlmRequest = llmRequestBuilder.build(); @@ -124,21 +123,27 @@ public Single processRequest( Flowable preprocessorEvents = runPreProcessor(invocationContext, llmRequest); // Convert the code execution parts to text parts. - if (llmAgent.codeExecutor() != null) { - BaseCodeExecutor baseCodeExecutor = llmAgent.codeExecutor(); - List updatedContents = new ArrayList<>(); - for (Content content : llmRequest.contents()) { - List delimiters = - !baseCodeExecutor.codeBlockDelimiters().isEmpty() - ? baseCodeExecutor.codeBlockDelimiters().get(0) - : ImmutableList.of("", ""); - updatedContents.add( - CodeExecutionUtils.convertCodeExecutionParts( - content, delimiters, baseCodeExecutor.executionResultDelimiters())); - } - llmRequest = llmRequest.toBuilder().contents(updatedContents).build(); - } - final LlmRequest finalLlmRequest = llmRequest; + final LlmRequest finalLlmRequest = + llmAgent + .codeExecutor() + .map( + baseCodeExecutor -> { + List delimiters = + !baseCodeExecutor.codeBlockDelimiters().isEmpty() + ? baseCodeExecutor.codeBlockDelimiters().get(0) + : ImmutableList.of("", ""); + ImmutableList updatedContents = + llmRequest.contents().stream() + .map( + content -> + CodeExecutionUtils.convertCodeExecutionParts( + content, + delimiters, + baseCodeExecutor.executionResultDelimiters())) + .collect(toImmutableList()); + return llmRequest.toBuilder().contents(updatedContents).build(); + }) + .orElse(llmRequest); return preprocessorEvents .toList() .map( @@ -173,10 +178,11 @@ private static Flowable runPreProcessor( return Flowable.empty(); } - var codeExecutor = llmAgent.codeExecutor(); - if (codeExecutor == null) { + var codeExecutorOptional = llmAgent.codeExecutor(); + if (codeExecutorOptional.isEmpty()) { return Flowable.empty(); } + var codeExecutor = codeExecutorOptional.get(); if (codeExecutor instanceof BuiltInCodeExecutor) { return Flowable.empty(); @@ -268,10 +274,11 @@ private static Flowable runPostProcessor( if (!(invocationContext.agent() instanceof LlmAgent llmAgent)) { return Flowable.empty(); } - var codeExecutor = llmAgent.codeExecutor(); - if (codeExecutor == null) { + var codeExecutorOptional = llmAgent.codeExecutor(); + if (codeExecutorOptional.isEmpty()) { return Flowable.empty(); } + var codeExecutor = codeExecutorOptional.get(); if (llmResponse.content().isEmpty()) { return Flowable.empty(); } @@ -387,8 +394,8 @@ private static List extractAndReplaceInlineFiles( private static Optional getOrSetExecutionId( InvocationContext invocationContext, CodeExecutorContext codeExecutorContext) { if (!(invocationContext.agent() instanceof LlmAgent llmAgent) - || llmAgent.codeExecutor() == null - || !llmAgent.codeExecutor().stateful()) { + || llmAgent.codeExecutor().isEmpty() + || !llmAgent.codeExecutor().get().stateful()) { return Optional.empty(); } @@ -441,11 +448,9 @@ private static Single postProcessCodeExecutionResult( .toList() .map( versions -> { - ConcurrentMap artifactDelta = new ConcurrentHashMap<>(); + ConcurrentMap artifactDelta = new ConcurrentHashMap<>(); for (int i = 0; i < versions.size(); i++) { - artifactDelta.put( - codeExecutionResult.outputFiles().get(i).name(), - Part.fromText(String.valueOf(versions.get(i)))); + artifactDelta.put(codeExecutionResult.outputFiles().get(i).name(), versions.get(i)); } eventActionsBuilder.artifactDelta(artifactDelta); return Event.builder() diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Compaction.java b/core/src/main/java/com/google/adk/flows/llmflows/Compaction.java new file mode 100644 index 000000000..6646f0ff7 --- /dev/null +++ b/core/src/main/java/com/google/adk/flows/llmflows/Compaction.java @@ -0,0 +1,59 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.flows.llmflows; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.summarizer.EventsCompactionConfig; +import com.google.adk.summarizer.TailRetentionEventCompactor; +import com.google.common.collect.ImmutableList; +import io.reactivex.rxjava3.core.Single; +import java.util.Optional; + +/** Request processor that performs event compaction. */ +public class Compaction implements RequestProcessor { + + @Override + public Single processRequest( + InvocationContext context, LlmRequest request) { + Optional configOpt = context.eventsCompactionConfig(); + + if (configOpt.isEmpty()) { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())); + } + + EventsCompactionConfig config = configOpt.get(); + + if (config.tokenThreshold() == null || config.eventRetentionSize() == null) { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())); + } + + // Extract out the retention size and token threshold from the new config. + int retentionSize = config.eventRetentionSize(); + int tokenThreshold = config.tokenThreshold(); + + // Summarizer will not be missing since the runner will always add a default one if missing. + TailRetentionEventCompactor compactor = + new TailRetentionEventCompactor(config.summarizer(), retentionSize, tokenThreshold); + + return compactor + .compact(context.session(), context.sessionService()) + .andThen( + Single.just( + RequestProcessor.RequestProcessingResult.create(request, ImmutableList.of()))); + } +} diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 0c415f1a8..f45461626 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -38,7 +38,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.ListIterator; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -175,9 +174,18 @@ private boolean isEmptyContent(Event event) { /** * Filters events that are covered by compaction events by identifying compacted ranges and - * filters out events that are covered by compaction summaries + * filters out events that are covered by compaction summaries. Also filters out redundant + * compaction events (i.e., those fully covered by a later compaction event). * - *

Example of input + *

Compaction events are inserted into the stream relative to the events they cover. + * Specifically, a compaction event is placed immediately before the first retained event that + * follows the compaction range (or at the end of the covered range if no events are retained). + * This ensures a logical flow of "Summary of History" -> "Recent/Retained Events". + * + *

Case 1: Sliding Window + Retention + * + *

Compaction events have some overlap but do not fully cover each other. Therefore, all + * compaction events are preserved, as well as the final retained events. * *

    * [
@@ -185,7 +193,7 @@ private boolean isEmptyContent(Event event) {
    *   event_2(timestamp=2),
    *   compaction_1(event_1, event_2, timestamp=3, content=summary_1_2, startTime=1, endTime=2),
    *   event_3(timestamp=4),
-   *   compaction_2(event_2, event_3, timestamp=5, content=summary_2_3, startTime=2, endTime=3),
+   *   compaction_2(event_2, event_3, timestamp=5, content=summary_2_3, startTime=2, endTime=4),
    *   event_4(timestamp=6)
    * ]
    * 
@@ -200,50 +208,123 @@ private boolean isEmptyContent(Event event) { * ] * * - * Compaction events are always strictly in order based on event timestamp. + *

Case 2: Rolling Summary + Retention + * + *

The newer compaction event fully covers the older one. Therefore, the older compaction event + * is removed, leaving only the latest summary and the final retained events. + * + *

+   * [
+   *   event_1(timestamp=1),
+   *   event_2(timestamp=2),
+   *   event_3(timestamp=3),
+   *   event_4(timestamp=4),
+   *   compaction_1(event_1, timestamp=5, content=summary_1, startTime=1, endTime=1),
+   *   event_6(timestamp=6),
+   *   event_7(timestamp=7),
+   *   compaction_2(compaction_1, event_2, event_3, timestamp=8, content=summary_1_3, startTime=1, endTime=3),
+   *   event_9(timestamp=9)
+   * ]
+   * 
+ * + * Will result in the following events output + * + *
+   * [
+   *   compaction_2,
+   *   event_4,
+   *   event_6,
+   *   event_7,
+   *   event_9
+   * ]
+   * 
* * @param events the list of event to filter. * @return a new list with compaction applied. */ private List processCompactionEvent(List events) { + // Step 1: Split events into compaction events and regular events. + List compactionEvents = new ArrayList<>(); + List regularEvents = new ArrayList<>(); + for (Event event : events) { + if (event.actions().compaction().isPresent()) { + compactionEvents.add(event); + } else { + regularEvents.add(event); + } + } + + // Step 2: Remove redundant compaction events (overlapping ones). + compactionEvents = removeOverlappingCompactions(compactionEvents); + + // Step 3: Merge regular events and compaction events based on timestamps. + // We iterate backwards from the latest to the earliest event. List result = new ArrayList<>(); - ListIterator iter = events.listIterator(events.size()); - Long lastCompactionStartTime = null; - Long lastCompactionEndTime = null; - - while (iter.hasPrevious()) { - Event event = iter.previous(); - EventCompaction compaction = event.actions().compaction().orElse(null); - if (compaction == null) { - if (lastCompactionStartTime == null - || event.timestamp() < lastCompactionStartTime - || (lastCompactionEndTime != null && event.timestamp() > lastCompactionEndTime)) { - result.add(event); - } - continue; + int c = compactionEvents.size() - 1; + int e = regularEvents.size() - 1; + while (e >= 0 && c >= 0) { + Event event = regularEvents.get(e); + EventCompaction compaction = compactionEvents.get(c).actions().compaction().get(); + + if (event.timestamp() >= compaction.startTimestamp() + && event.timestamp() <= compaction.endTimestamp()) { + // If the event is covered by compaction, skip it. + e--; + } else if (event.timestamp() > compaction.endTimestamp()) { + // If the event is after compaction, keep it. + result.add(event); + e--; + } else { + // Otherwise the event is before the compaction, let's move to the next compaction event; + result.add(createCompactionEvent(compactionEvents.get(c))); + c--; + } + } + // Flush any remaining compactions. + while (c >= 0) { + result.add(createCompactionEvent(compactionEvents.get(c))); + c--; + } + // Flush any remaining regular events. + while (e >= 0) { + result.add(regularEvents.get(e)); + e--; + } + return Lists.reverse(result); + } + + private static List removeOverlappingCompactions(List events) { + List result = new ArrayList<>(); + // Iterate backwards to prioritize later compactions + for (int i = events.size() - 1; i >= 0; i--) { + Event current = events.get(i); + EventCompaction c = current.actions().compaction().get(); + + // Check if this compaction is covered by the last compaction we've already kept. + boolean covered = false; + if (!result.isEmpty()) { + EventCompaction lastKept = Iterables.getLast(result).actions().compaction().get(); + covered = + c.startTimestamp() >= lastKept.startTimestamp() + && c.endTimestamp() <= lastKept.endTimestamp(); + } + + if (!covered) { + result.add(current); } - // Create a new event for the compaction event in the result. - result.add( - Event.builder() - .timestamp(compaction.endTimestamp()) - .author("model") - .content(compaction.compactedContent()) - .branch(event.branch()) - .invocationId(event.invocationId()) - .actions(event.actions()) - .build()); - lastCompactionStartTime = - lastCompactionStartTime == null - ? compaction.startTimestamp() - : Long.min(lastCompactionStartTime, compaction.startTimestamp()); - lastCompactionEndTime = - lastCompactionEndTime == null - ? compaction.endTimestamp() - : Long.max(lastCompactionEndTime, compaction.endTimestamp()); } return Lists.reverse(result); } + private static Event createCompactionEvent(Event event) { + EventCompaction compaction = event.actions().compaction().get(); + return event.toBuilder() + .timestamp(compaction.endTimestamp()) + .author("model") + .content(compaction.compactedContent()) + .build(); + } + /** Whether the event is a reply from another agent. */ private static boolean isOtherAgentReply(String agentName, Event event) { return !agentName.isEmpty() @@ -483,8 +564,7 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( for (int i = 0; i < events.size(); i++) { Event event = events.get(i); - // Skip response events that will be processed via responseEventsBuffer - if (processedResponseIndices.contains(i)) { + if (!event.functionResponses().isEmpty()) { continue; } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index ce7687c3d..82813defa 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -3,7 +3,6 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 @@ -42,7 +41,6 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; -import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; @@ -52,7 +50,6 @@ import io.reactivex.rxjava3.disposables.Disposable; import io.reactivex.rxjava3.functions.Function; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -151,8 +148,9 @@ public static Maybe handleFunctionCalls( } } + Context parentContext = Context.current(); Function> functionCallMapper = - getFunctionCallMapper(invocationContext, tools, toolConfirmations, false); + getFunctionCallMapper(invocationContext, tools, toolConfirmations, false, parentContext); Observable functionResponseEventsObservable; if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { @@ -178,14 +176,9 @@ public static Maybe handleFunctionCalls( var mergedEvent = maybeMergedEvent.get(); if (events.size() > 1) { - Tracer tracer = Tracing.getTracer(); - Span mergedSpan = - tracer.spanBuilder("tool_response").setParent(Context.current()).startSpan(); - try (Scope scope = mergedSpan.makeCurrent()) { - Tracing.traceToolResponse(invocationContext, mergedEvent.id(), mergedEvent); - } finally { - mergedSpan.end(); - } + return Maybe.just(mergedEvent) + .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)) + .compose(Tracing.trace("tool_response", parentContext)); } return Maybe.just(mergedEvent); }); @@ -217,8 +210,9 @@ public static Maybe handleFunctionCallsLive( } } + Context parentContext = Context.current(); Function> functionCallMapper = - getFunctionCallMapper(invocationContext, tools, toolConfirmations, true); + getFunctionCallMapper(invocationContext, tools, toolConfirmations, true, parentContext); Observable responseEventsObservable; if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { @@ -245,30 +239,50 @@ private static Function> getFunctionCallMapper( InvocationContext invocationContext, Map tools, Map toolConfirmations, - boolean isLive) { - return functionCall -> { - BaseTool tool = tools.get(functionCall.name().get()); - ToolContext toolContext = - ToolContext.builder(invocationContext) - .functionCallId(functionCall.id().orElse("")) - .toolConfirmation(functionCall.id().map(toolConfirmations::get).orElse(null)) - .build(); - - Map functionArgs = functionCall.args().orElse(new HashMap<>()); - - Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) - .switchIfEmpty( - Maybe.defer( - () -> - isLive - ? processFunctionLive( - invocationContext, tool, toolContext, functionCall, functionArgs) - : callTool(tool, functionArgs, toolContext))); - - return postProcessFunctionResult( - maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive); - }; + boolean isLive, + Context parentContext) { + return functionCall -> + Maybe.defer( + () -> { + try (Scope scope = parentContext.makeCurrent()) { + BaseTool tool = tools.get(functionCall.name().get()); + ToolContext toolContext = + ToolContext.builder(invocationContext) + .functionCallId(functionCall.id().orElse("")) + .toolConfirmation( + functionCall.id().map(toolConfirmations::get).orElse(null)) + .build(); + + Map functionArgs = functionCall.args().orElse(new HashMap<>()); + + Maybe> maybeFunctionResult = + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + .switchIfEmpty( + Maybe.defer( + () -> { + try (Scope innerScope = parentContext.makeCurrent()) { + return isLive + ? processFunctionLive( + invocationContext, + tool, + toolContext, + functionCall, + functionArgs, + parentContext) + : callTool(tool, functionArgs, toolContext, parentContext); + } + })); + + return postProcessFunctionResult( + maybeFunctionResult, + invocationContext, + tool, + functionArgs, + toolContext, + isLive, + parentContext); + } + }); } /** @@ -280,7 +294,8 @@ private static Maybe> processFunctionLive( BaseTool tool, ToolContext toolContext, FunctionCall functionCall, - Map args) { + Map args, + Context parentContext) { // Case 1: Handle a call to stopStreaming if (functionCall.name().get().equals("stopStreaming") && args.containsKey("functionName")) { String functionNameToStop = (String) args.get("functionName"); @@ -332,7 +347,7 @@ private static Maybe> processFunctionLive( ActiveStreamingTool activeTool = invocationContext .activeStreamingTools() - .getOrDefault(tool.name(), new ActiveStreamingTool(subscription)); + .computeIfAbsent(tool.name(), unused -> new ActiveStreamingTool(subscription)); activeTool.task(subscription); invocationContext.activeStreamingTools().put(tool.name(), activeTool); @@ -348,7 +363,7 @@ private static Maybe> processFunctionLive( } // Case 3: Fallback for regular, non-streaming tools - return callTool(tool, args, toolContext); + return callTool(tool, args, toolContext, parentContext); } public static Set getLongRunningFunctionCalls( @@ -372,37 +387,55 @@ private static Maybe postProcessFunctionResult( BaseTool tool, Map functionArgs, ToolContext toolContext, - boolean isLive) { + boolean isLive, + Context parentContext) { return maybeFunctionResult .map(Optional::of) .defaultIfEmpty(Optional.empty()) .onErrorResumeNext( - t -> - handleOnToolErrorCallback(invocationContext, tool, functionArgs, toolContext, t) - .map(isLive ? Optional::ofNullable : Optional::of) - .switchIfEmpty(Single.error(t))) + t -> { + Maybe> errorCallbackResult = + handleOnToolErrorCallback(invocationContext, tool, functionArgs, toolContext, t); + Maybe>> mappedResult; + if (isLive) { + // In live mode, handle null results from the error callback gracefully. + mappedResult = errorCallbackResult.map(Optional::ofNullable); + } else { + // In non-live mode, a null result from the error callback will cause an NPE + // when wrapped with Optional.of(), potentially matching prior behavior. + mappedResult = errorCallbackResult.map(Optional::of); + } + return mappedResult.switchIfEmpty(Single.error(t)); + }) .flatMapMaybe( optionalInitialResult -> { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - Maybe> afterToolResultMaybe = - maybeInvokeAfterToolCall( - invocationContext, tool, functionArgs, toolContext, initialFunctionResult); - - return afterToolResultMaybe - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - Event functionResponseEvent = - buildResponseEvent( - tool, finalFunctionResult, toolContext, invocationContext); - return Maybe.just(functionResponseEvent); - }); + try (Scope scope = parentContext.makeCurrent()) { + Map initialFunctionResult = optionalInitialResult.orElse(null); + + return maybeInvokeAfterToolCall( + invocationContext, tool, functionArgs, toolContext, initialFunctionResult) + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = + finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + return Maybe.fromCallable( + () -> + buildResponseEvent( + tool, + finalFunctionResult, + toolContext, + invocationContext)) + .compose( + Tracing.trace( + "tool_response [" + tool.name() + "]", parentContext)) + .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); + }); + } }); } @@ -449,8 +482,12 @@ private static Maybe> maybeInvokeBeforeToolCall( if (invocationContext.agent() instanceof LlmAgent) { LlmAgent agent = (LlmAgent) invocationContext.agent(); + HashMap mutableFunctionArgs = new HashMap<>(functionArgs); + Maybe> pluginResult = - invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); + invocationContext + .pluginManager() + .beforeToolCallback(tool, mutableFunctionArgs, toolContext); List callbacks = agent.canonicalBeforeToolCallbacks(); if (callbacks.isEmpty()) { @@ -463,7 +500,8 @@ private static Maybe> maybeInvokeBeforeToolCall( Flowable.fromIterable(callbacks) .concatMapMaybe( callback -> - callback.call(invocationContext, tool, functionArgs, toolContext)) + callback.call( + invocationContext, tool, mutableFunctionArgs, toolContext)) .firstElement()); return pluginResult.switchIfEmpty(callbackResult); @@ -551,27 +589,21 @@ private static Maybe> maybeInvokeAfterToolCall( } private static Maybe> callTool( - BaseTool tool, Map args, ToolContext toolContext) { - Tracer tracer = Tracing.getTracer(); - return Maybe.defer( - () -> { - Span span = - tracer - .spanBuilder("tool_call [" + tool.name() + "]") - .setParent(Context.current()) - .startSpan(); - try (Scope scope = span.makeCurrent()) { - Tracing.traceToolCall(args); - return tool.runAsync(args, toolContext) - .toMaybe() - .doOnError(span::recordException) - .doFinally(span::end); - } catch (RuntimeException e) { - span.recordException(e); - span.end(); - return Maybe.error(new RuntimeException("Failed to call tool: " + tool.name(), e)); - } - }); + BaseTool tool, Map args, ToolContext toolContext, Context parentContext) { + return tool.runAsync(args, toolContext) + .toMaybe() + .doOnSubscribe( + d -> + Tracing.traceToolCall( + tool.name(), tool.description(), tool.getClass().getSimpleName(), args)) + .doOnError(t -> Span.current().recordException(t)) + .compose(Tracing.trace("tool_call [" + tool.name() + "]", parentContext)) + .onErrorResumeNext( + e -> + Maybe.error( + e instanceof RuntimeException runtimeException + ? runtimeException + : new RuntimeException("Failed to call tool: " + tool.name(), e))); } private static Event buildResponseEvent( @@ -579,47 +611,27 @@ private static Event buildResponseEvent( Map response, ToolContext toolContext, InvocationContext invocationContext) { - Tracer tracer = Tracing.getTracer(); - Span span = - tracer - .spanBuilder("tool_response [" + tool.name() + "]") - .setParent(Context.current()) - .startSpan(); - try (Scope scope = span.makeCurrent()) { - // use a empty placeholder response if tool response is null. - if (response == null) { - response = new HashMap<>(); - } - - Part partFunctionResponse = - Part.builder() - .functionResponse( - FunctionResponse.builder() - .id(toolContext.functionCallId().orElse("")) - .name(tool.name()) - .response(response) - .build()) - .build(); - - Event event = - Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(invocationContext.agent().name()) - .branch(invocationContext.branch()) - .content( - Optional.of( - Content.builder() - .role("user") - .parts(Collections.singletonList(partFunctionResponse)) - .build())) - .actions(toolContext.eventActions()) - .build(); - Tracing.traceToolResponse(invocationContext, event.id(), event); - return event; - } finally { - span.end(); - } + // use an empty placeholder response if tool response is null. + Map finalResponse = response != null ? response : new HashMap<>(); + + Part partFunctionResponse = + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(toolContext.functionCallId().orElse("")) + .name(tool.name()) + .response(finalResponse) + .build()) + .build(); + + return Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(invocationContext.agent().name()) + .branch(invocationContext.branch()) + .content(Content.builder().role("user").parts(partFunctionResponse).build()) + .actions(toolContext.eventActions()) + .build(); } /** diff --git a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java index cc2fb443a..de45ba702 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java @@ -30,6 +30,7 @@ public class SingleFlow extends BaseLlmFlow { new RequestConfirmationLlmRequestProcessor(), new Instructions(), new Identity(), + new Compaction(), new Contents(), new Examples(), CodeExecution.requestProcessor); diff --git a/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java b/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java index 643d0e9aa..45d81b420 100644 --- a/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java +++ b/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java @@ -34,8 +34,11 @@ import com.google.genai.types.LiveServerMessage; import com.google.genai.types.LiveServerToolCall; import com.google.genai.types.Part; +import com.google.genai.types.UsageMetadata; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.disposables.CompositeDisposable; import io.reactivex.rxjava3.processors.PublishProcessor; import java.net.SocketException; import java.util.List; @@ -65,6 +68,7 @@ public final class GeminiLlmConnection implements BaseLlmConnection { private final CompletableFuture sessionFuture; private final PublishProcessor responseProcessor = PublishProcessor.create(); private final Flowable responseFlowable = responseProcessor.serialize(); + private final CompositeDisposable disposables = new CompositeDisposable(); private final AtomicBoolean closed = new AtomicBoolean(false); /** @@ -120,65 +124,116 @@ private void handleServerMessage(LiveServerMessage message) { logger.debug("Received server message: {}", message.toJson()); - Optional llmResponse = convertToServerResponse(message); - llmResponse.ifPresent(responseProcessor::onNext); + Observable llmResponse = convertToServerResponse(message); + if (!disposables.add( + llmResponse.subscribe(responseProcessor::onNext, responseProcessor::onError))) { + logger.warn( + "disposables container already disposed, the subscription will be disposed immediately"); + } } /** Converts a server message into the standardized LlmResponse format. */ - static Optional convertToServerResponse(LiveServerMessage message) { + static Observable convertToServerResponse(LiveServerMessage message) { + return Observable.create( + emitter -> { + // AtomicBoolean is used to modify state from within lambdas, which + // require captured variables to be effectively final. + final AtomicBoolean handled = new AtomicBoolean(false); + message + .serverContent() + .ifPresent( + serverContent -> { + emitter.onNext(createServerContentResponse(serverContent)); + handled.set(true); + }); + message + .toolCall() + .ifPresent( + toolCall -> { + emitter.onNext(createToolCallResponse(toolCall)); + handled.set(true); + }); + message + .usageMetadata() + .ifPresent( + usageMetadata -> { + logger.debug("Received usage metadata: {}", usageMetadata); + emitter.onNext(createUsageMetadataResponse(usageMetadata)); + handled.set(true); + }); + message + .toolCallCancellation() + .ifPresent( + toolCallCancellation -> { + logger.debug("Received tool call cancellation: {}", toolCallCancellation); + // TODO: implement proper CFC and thus tool call cancellation handling. + handled.set(true); + }); + message + .setupComplete() + .ifPresent( + setupComplete -> { + logger.debug("Received setup complete."); + handled.set(true); + }); + + if (!handled.get()) { + logger.warn("Received unknown or empty server message: {}", message.toJson()); + emitter.onNext(createUnknownMessageResponse()); + } + emitter.onComplete(); + }); + } + + private static LlmResponse createServerContentResponse(LiveServerContent serverContent) { LlmResponse.Builder builder = LlmResponse.builder(); + serverContent.modelTurn().ifPresent(builder::content); - if (message.serverContent().isPresent()) { - LiveServerContent serverContent = message.serverContent().get(); - serverContent.modelTurn().ifPresent(builder::content); - builder - .partial(serverContent.turnComplete().map(completed -> !completed).orElse(false)) - .turnComplete(serverContent.turnComplete().orElse(false)) - .interrupted(serverContent.interrupted()); - if (serverContent.outputTranscription().isPresent()) { - Part part = - Part.builder() - .text(serverContent.outputTranscription().get().text().toString()) - .build(); - builder.content(Content.builder().role("model").parts(ImmutableList.of(part)).build()); - } - if (serverContent.inputTranscription().isPresent()) { - Part part = - Part.builder().text(serverContent.inputTranscription().get().text().toString()).build(); - builder.content(Content.builder().role("user").parts(ImmutableList.of(part)).build()); - } - } else if (message.toolCall().isPresent()) { - LiveServerToolCall toolCall = message.toolCall().get(); - toolCall - .functionCalls() - .ifPresent( - calls -> { - for (FunctionCall call : calls) { - builder.content( - Content.builder() - .parts(ImmutableList.of(Part.builder().functionCall(call).build())) - .build()); - } - }); - builder.partial(false).turnComplete(false); - } else if (message.usageMetadata().isPresent()) { - logger.debug("Received usage metadata: {}", message.usageMetadata().get()); - return Optional.empty(); - } else if (message.toolCallCancellation().isPresent()) { - logger.debug("Received tool call cancellation: {}", message.toolCallCancellation().get()); - // TODO: implement proper CFC and thus tool call cancellation handling. - return Optional.empty(); - } else if (message.setupComplete().isPresent()) { - logger.debug("Received setup complete."); - return Optional.empty(); - } else { - logger.warn("Received unknown or empty server message: {}", message.toJson()); - builder - .errorCode(new FinishReason("Unknown server message.")) - .errorMessage("Received unknown server message."); + if (serverContent.outputTranscription().isPresent()) { + Part part = + Part.builder().text(serverContent.outputTranscription().get().text().toString()).build(); + builder.content(Content.builder().role("model").parts(ImmutableList.of(part)).build()); } + if (serverContent.inputTranscription().isPresent()) { + Part part = + Part.builder().text(serverContent.inputTranscription().get().text().toString()).build(); + builder.content(Content.builder().role("user").parts(ImmutableList.of(part)).build()); + } + + return builder + .partial(serverContent.turnComplete().map(completed -> !completed).orElse(false)) + .turnComplete(serverContent.turnComplete().orElse(false)) + .interrupted(serverContent.interrupted()) + .build(); + } + + private static LlmResponse createToolCallResponse(LiveServerToolCall toolCall) { + LlmResponse.Builder builder = LlmResponse.builder(); + toolCall + .functionCalls() + .ifPresent( + calls -> { + for (FunctionCall call : calls) { + builder.content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionCall(call).build())) + .build()); + } + }); + return builder.partial(false).turnComplete(false).build(); + } + + private static LlmResponse createUsageMetadataResponse(UsageMetadata usageMetadata) { + return LlmResponse.builder() + .usageMetadata(GeminiUtil.toGenerateContentResponseUsageMetadata(usageMetadata)) + .build(); + } - return Optional.of(builder.build()); + private static LlmResponse createUnknownMessageResponse() { + return LlmResponse.builder() + .errorCode(new FinishReason("Unknown server message.")) + .errorMessage("Received unknown server message.") + .build(); } /** Handles errors that occur *during* the initial connection attempt. */ @@ -293,6 +348,8 @@ private void closeInternal(Throwable throwable) { } else { sessionFuture.cancel(false); } + + disposables.dispose(); } } diff --git a/core/src/main/java/com/google/adk/models/GeminiUtil.java b/core/src/main/java/com/google/adk/models/GeminiUtil.java index 2b95c0ab2..319226d69 100644 --- a/core/src/main/java/com/google/adk/models/GeminiUtil.java +++ b/core/src/main/java/com/google/adk/models/GeminiUtil.java @@ -24,7 +24,9 @@ import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FileData; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import com.google.genai.types.UsageMetadata; import java.util.List; import java.util.Optional; import java.util.stream.Stream; @@ -224,4 +226,22 @@ public static List stripThoughts(List originalContents) { }) .collect(toImmutableList()); } + + public static GenerateContentResponseUsageMetadata toGenerateContentResponseUsageMetadata( + UsageMetadata usageMetadata) { + GenerateContentResponseUsageMetadata.Builder builder = + GenerateContentResponseUsageMetadata.builder(); + usageMetadata.promptTokenCount().ifPresent(builder::promptTokenCount); + usageMetadata.cachedContentTokenCount().ifPresent(builder::cachedContentTokenCount); + usageMetadata.responseTokenCount().ifPresent(builder::candidatesTokenCount); + usageMetadata.toolUsePromptTokenCount().ifPresent(builder::toolUsePromptTokenCount); + usageMetadata.thoughtsTokenCount().ifPresent(builder::thoughtsTokenCount); + usageMetadata.totalTokenCount().ifPresent(builder::totalTokenCount); + usageMetadata.promptTokensDetails().ifPresent(builder::promptTokensDetails); + usageMetadata.cacheTokensDetails().ifPresent(builder::cacheTokensDetails); + usageMetadata.responseTokensDetails().ifPresent(builder::candidatesTokensDetails); + usageMetadata.toolUsePromptTokensDetails().ifPresent(builder::toolUsePromptTokensDetails); + usageMetadata.trafficType().ifPresent(builder::trafficType); + return builder.build(); + } } diff --git a/core/src/main/java/com/google/adk/models/GptOssLlm.java b/core/src/main/java/com/google/adk/models/GptOssLlm.java index 331203ac6..895aba540 100644 --- a/core/src/main/java/com/google/adk/models/GptOssLlm.java +++ b/core/src/main/java/com/google/adk/models/GptOssLlm.java @@ -100,16 +100,16 @@ public GptOssLlm(String modelName) { * @param modelName The name of the GPT OSS model to use (e.g., "gpt-oss-4"). * @param vertexCredentials The Vertex AI credentials to access the model. */ -// public GptOssLlm(String modelName, VertexCredentials vertexCredentials) { -// super(modelName); -// Objects.requireNonNull(vertexCredentials, "vertexCredentials cannot be null"); -// Client.Builder apiClientBuilder = -// Client.builder().httpOptions(HttpOptions.builder().headers(TRACKING_HEADERS).build()); -// vertexCredentials.project().ifPresent(apiClientBuilder::project); -// vertexCredentials.location().ifPresent(apiClientBuilder::location); -// vertexCredentials.credentials().ifPresent(apiClientBuilder::credentials); -// this.apiClient = apiClientBuilder.build(); -// } + // public GptOssLlm(String modelName, VertexCredentials vertexCredentials) { + // super(modelName); + // Objects.requireNonNull(vertexCredentials, "vertexCredentials cannot be null"); + // Client.Builder apiClientBuilder = + // Client.builder().httpOptions(HttpOptions.builder().headers(TRACKING_HEADERS).build()); + // vertexCredentials.project().ifPresent(apiClientBuilder::project); + // vertexCredentials.location().ifPresent(apiClientBuilder::location); + // vertexCredentials.credentials().ifPresent(apiClientBuilder::credentials); + // this.apiClient = apiClientBuilder.build(); + // } /** * Returns a new Builder instance for constructing GptOssLlm objects. Note that when building a @@ -165,8 +165,7 @@ public GptOssLlm build() { if (apiClient != null) { return new GptOssLlm(modelName, apiClient); - } - else { + } else { return new GptOssLlm( modelName, Client.builder() @@ -354,4 +353,4 @@ public BaseLlmConnection connect(LlmRequest llmRequest) { return new GeminiLlmConnection(apiClient, effectiveModelName, liveConnectConfig); } -} \ No newline at end of file +} diff --git a/core/src/main/java/com/google/adk/plugins/GlobalInstructionPlugin.java b/core/src/main/java/com/google/adk/plugins/GlobalInstructionPlugin.java new file mode 100644 index 000000000..1773bf701 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/GlobalInstructionPlugin.java @@ -0,0 +1,118 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.plugins; + +import com.google.adk.agents.CallbackContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.utils.InstructionUtils; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +/** + * Plugin that provides global instructions functionality at the App level. + * + *

Global instructions are applied to all agents in the application, providing a consistent way + * to set application-wide instructions, identity, or personality. Global instructions can be + * provided as a static string, or as a function that resolves the instruction based on the {@link + * CallbackContext}. + * + *

The plugin operates through the before_model_callback, allowing it to modify LLM requests + * before they are sent to the model by prepending the global instruction to any existing system + * instructions provided by the agent. + */ +public class GlobalInstructionPlugin extends BasePlugin { + + private final Function> instructionProvider; + + private static Function> createInstructionProvider( + String globalInstruction) { + return callbackContext -> { + if (globalInstruction == null) { + return Maybe.empty(); + } + return InstructionUtils.injectSessionState( + callbackContext.invocationContext(), globalInstruction) + .toMaybe(); + }; + } + + public GlobalInstructionPlugin(String globalInstruction) { + this(globalInstruction, "global_instruction"); + } + + public GlobalInstructionPlugin(String globalInstruction, String name) { + this(createInstructionProvider(globalInstruction), name); + } + + public GlobalInstructionPlugin(Function> instructionProvider) { + this(instructionProvider, "global_instruction"); + } + + public GlobalInstructionPlugin( + Function> instructionProvider, String name) { + super(name); + this.instructionProvider = instructionProvider; + } + + @Override + public Maybe beforeModelCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { + return instructionProvider + .apply(callbackContext) + .filter(instruction -> !instruction.isEmpty()) + .flatMap( + instruction -> { + // Get mutable config, or create one if it doesn't exist. + GenerateContentConfig config = + llmRequest.config().orElseGet(GenerateContentConfig.builder()::build); + + // Get existing system instruction parts, if any. + Optional systemInstruction = config.systemInstruction(); + List existingParts = + systemInstruction.flatMap(Content::parts).orElse(ImmutableList.of()); + + // Prepend the global instruction to the existing system instruction parts. + // If there are existing instructions, add two newlines between the global + // instruction and the existing instructions. + ImmutableList.Builder newPartsBuilder = ImmutableList.builder(); + if (existingParts.isEmpty()) { + newPartsBuilder.add(Part.fromText(instruction)); + } else { + newPartsBuilder.add(Part.fromText(instruction + "\n\n")); + newPartsBuilder.addAll(existingParts); + } + + // Build the new system instruction content. + Content.Builder newSystemInstructionBuilder = Content.builder(); + systemInstruction.flatMap(Content::role).ifPresent(newSystemInstructionBuilder::role); + newSystemInstructionBuilder.parts(newPartsBuilder.build()); + + // Update llmRequest with new config. + llmRequest.config( + config.toBuilder() + .systemInstruction(newSystemInstructionBuilder.build()) + .build()); + return Maybe.empty(); + }); + } +} diff --git a/core/src/main/java/com/google/adk/plugins/Plugin.java b/core/src/main/java/com/google/adk/plugins/Plugin.java index 97a9038d4..c9cda5680 100644 --- a/core/src/main/java/com/google/adk/plugins/Plugin.java +++ b/core/src/main/java/com/google/adk/plugins/Plugin.java @@ -87,6 +87,16 @@ default Completable afterRunCallback(InvocationContext invocationContext) { return Completable.complete(); } + /** + * Method executed when the runner is closed. + * + *

This method is used for cleanup tasks such as closing network connections or releasing + * resources. + */ + default Completable close() { + return Completable.complete(); + } + /** * Callback executed before an agent's primary logic is invoked. * diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index 76452af64..a63d9a402 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -41,16 +41,14 @@ *

The PluginManager is an internal class that orchestrates the invocation of plugin callbacks at * key points in the SDK's execution lifecycle. */ -public class PluginManager implements Plugin { +public class PluginManager extends BasePlugin { private static final Logger logger = LoggerFactory.getLogger(PluginManager.class); - private final List plugins; + private final List plugins = new ArrayList<>(); public PluginManager(List plugins) { - this.plugins = new ArrayList<>(); + super("PluginManager"); if (plugins != null) { - for (var plugin : plugins) { - this.registerPlugin(plugin); - } + plugins.forEach(this::registerPlugin); } } @@ -58,11 +56,6 @@ public PluginManager() { this(null); } - @Override - public String getName() { - return "PluginManager"; - } - /** * Registers a new plugin. * @@ -132,6 +125,19 @@ public Completable afterRunCallback(InvocationContext invocationContext) { e))); } + @Override + public Completable close() { + return Flowable.fromIterable(plugins) + .concatMapCompletableDelayError( + plugin -> + plugin + .close() + .doOnError( + e -> + logger.error( + "[{}] Error during callback 'close'", plugin.getName(), e))); + } + public Maybe runOnEventCallback(InvocationContext invocationContext, Event event) { return onEventCallback(invocationContext, event); } @@ -259,7 +265,7 @@ private Maybe runMaybeCallbacks( callbackExecutor .apply(plugin) .doOnSuccess( - r -> + unused -> logger.debug( "Plugin '{}' returned a value for callback '{}', exiting " + "early.", diff --git a/core/src/main/java/com/google/adk/runner/InMemoryRunner.java b/core/src/main/java/com/google/adk/runner/InMemoryRunner.java index 793f24c38..58741003c 100644 --- a/core/src/main/java/com/google/adk/runner/InMemoryRunner.java +++ b/core/src/main/java/com/google/adk/runner/InMemoryRunner.java @@ -19,7 +19,7 @@ import com.google.adk.agents.BaseAgent; import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.memory.InMemoryMemoryService; -import com.google.adk.plugins.BasePlugin; +import com.google.adk.plugins.Plugin; import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; import java.util.List; @@ -37,7 +37,7 @@ public InMemoryRunner(BaseAgent agent, String appName) { this(agent, appName, ImmutableList.of()); } - public InMemoryRunner(BaseAgent agent, String appName, List plugins) { + public InMemoryRunner(BaseAgent agent, String appName, List plugins) { super( agent, appName, diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 574c3dcf0..0ddfdaea1 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -18,19 +18,19 @@ import com.google.adk.agents.ActiveStreamingTool; import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.ContextCacheConfig; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; -import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; import com.google.adk.events.EventActions; import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.Model; -import com.google.adk.plugins.BasePlugin; +import com.google.adk.plugins.Plugin; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.InMemorySessionService; @@ -51,7 +51,6 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; -import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -73,8 +72,8 @@ public class Runner { private final BaseSessionService sessionService; @Nullable private final BaseMemoryService memoryService; private final PluginManager pluginManager; - private final ResumabilityConfig resumabilityConfig; @Nullable private final EventsCompactionConfig eventsCompactionConfig; + @Nullable private final ContextCacheConfig contextCacheConfig; /** Builder for {@link Runner}. */ public static class Builder { @@ -84,7 +83,7 @@ public static class Builder { private BaseArtifactService artifactService = new InMemoryArtifactService(); private BaseSessionService sessionService = new InMemorySessionService(); @Nullable private BaseMemoryService memoryService = null; - private List plugins = ImmutableList.of(); + private List plugins = ImmutableList.of(); @CanIgnoreReturnValue public Builder app(App app) { @@ -126,7 +125,7 @@ public Builder memoryService(BaseMemoryService memoryService) { } @CanIgnoreReturnValue - public Builder plugins(List plugins) { + public Builder plugins(List plugins) { Preconditions.checkState(this.app == null, "plugins() cannot be called when app is set."); this.plugins = plugins; return this; @@ -135,9 +134,9 @@ public Builder plugins(List plugins) { public Runner build() { BaseAgent buildAgent; String buildAppName; - List buildPlugins; - ResumabilityConfig buildResumabilityConfig; + List buildPlugins; EventsCompactionConfig buildEventsCompactionConfig; + ContextCacheConfig buildContextCacheConfig; if (this.app != null) { if (this.agent != null) { @@ -149,17 +148,14 @@ public Runner build() { buildAgent = this.app.rootAgent(); buildPlugins = this.app.plugins(); buildAppName = this.appName == null ? this.app.name() : this.appName; - buildResumabilityConfig = - this.app.resumabilityConfig() != null - ? this.app.resumabilityConfig() - : new ResumabilityConfig(); buildEventsCompactionConfig = this.app.eventsCompactionConfig(); + buildContextCacheConfig = this.app.contextCacheConfig(); } else { buildAgent = this.agent; buildAppName = this.appName; buildPlugins = this.plugins; - buildResumabilityConfig = new ResumabilityConfig(); buildEventsCompactionConfig = null; + buildContextCacheConfig = null; } if (buildAgent == null) { @@ -181,8 +177,8 @@ public Runner build() { sessionService, memoryService, buildPlugins, - buildResumabilityConfig, - buildEventsCompactionConfig); + buildEventsCompactionConfig, + buildContextCacheConfig); } } @@ -202,14 +198,7 @@ public Runner( BaseArtifactService artifactService, BaseSessionService sessionService, @Nullable BaseMemoryService memoryService) { - this( - agent, - appName, - artifactService, - sessionService, - memoryService, - ImmutableList.of(), - new ResumabilityConfig()); + this(agent, appName, artifactService, sessionService, memoryService, ImmutableList.of()); } /** @@ -224,44 +213,12 @@ public Runner( BaseArtifactService artifactService, BaseSessionService sessionService, @Nullable BaseMemoryService memoryService, - List plugins) { - this( - agent, - appName, - artifactService, - sessionService, - memoryService, - plugins, - new ResumabilityConfig()); + List plugins) { + this(agent, appName, artifactService, sessionService, memoryService, plugins, null, null); } /** - * Creates a new {@code Runner} with a list of plugins and resumability config. - * - * @deprecated Use {@link Runner.Builder} instead. - */ - @Deprecated - public Runner( - BaseAgent agent, - String appName, - BaseArtifactService artifactService, - BaseSessionService sessionService, - @Nullable BaseMemoryService memoryService, - List plugins, - ResumabilityConfig resumabilityConfig) { - this( - agent, - appName, - artifactService, - sessionService, - memoryService, - plugins, - resumabilityConfig, - null); - } - - /** - * Creates a new {@code Runner} with a list of plugins and resumability config. + * Creates a new {@code Runner} with a list of plugins. * * @deprecated Use {@link Runner.Builder} instead. */ @@ -272,17 +229,17 @@ protected Runner( BaseArtifactService artifactService, BaseSessionService sessionService, @Nullable BaseMemoryService memoryService, - List plugins, - ResumabilityConfig resumabilityConfig, - @Nullable EventsCompactionConfig eventsCompactionConfig) { + List plugins, + @Nullable EventsCompactionConfig eventsCompactionConfig, + @Nullable ContextCacheConfig contextCacheConfig) { this.agent = agent; this.appName = appName; this.artifactService = artifactService; this.sessionService = sessionService; this.memoryService = memoryService; this.pluginManager = new PluginManager(plugins); - this.resumabilityConfig = resumabilityConfig; this.eventsCompactionConfig = createEventsCompactionConfig(agent, eventsCompactionConfig); + this.contextCacheConfig = contextCacheConfig; } /** @@ -324,6 +281,14 @@ public PluginManager pluginManager() { return this.pluginManager; } + /** Closes all plugins, code executors, and releases any resources. */ + public Completable close() { + List completables = new ArrayList<>(); + completables.add(agent.close()); + completables.add(this.pluginManager.close()); + return Completable.mergeDelayError(completables); + } + /** * Appends a new user message to the session history with optional state delta. * @@ -415,7 +380,7 @@ public Flowable runAsync( new IllegalArgumentException( String.format("Session not found: %s for user %s", sessionId, userId))); })) - .flatMapPublisher(session -> this.runAsync(session, newMessage, runConfig, stateDelta)); + .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -449,75 +414,77 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - Span span = - Tracing.getTracer().spanBuilder("invocation").setParent(Context.current()).startSpan(); - Context spanContext = Context.current().with(span); - - try { - BaseAgent rootAgent = this.agent; - String invocationId = InvocationContext.newInvocationContextId(); - - // Create initial context - InvocationContext initialContext = - newInvocationContextBuilder(session) - .invocationId(invocationId) - .runConfig(runConfig) - .userContent(newMessage) - .build(); - - return Tracing.traceFlowable( - spanContext, - span, - () -> - Flowable.defer( - () -> - this.pluginManager - .onUserMessageCallback(initialContext, newMessage) - .defaultIfEmpty(newMessage) - .flatMap( - content -> - (content != null) - ? appendNewMessageToSession( - session, - content, - initialContext, - runConfig.saveInputBlobsAsArtifacts(), - stateDelta) - : Single.just(null)) - .flatMapPublisher( - event -> { - if (event == null) { - return Flowable.empty(); - } - // Get the updated session after the message and state delta are - // applied - return this.sessionService - .getSession( - session.appName(), - session.userId(), - session.id(), - Optional.empty()) - .flatMapPublisher( - updatedSession -> - runAgentWithFreshSession( - session, - updatedSession, - event, - invocationId, - runConfig, - rootAgent)); - })) - .doOnError( - throwable -> { - span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); - span.recordException(throwable); - })); - } catch (Throwable t) { - span.setStatus(StatusCode.ERROR, "Error during runAsync synchronous setup"); - span.recordException(t); - span.end(); - return Flowable.error(t); - } + return runAsyncImpl(session, newMessage, runConfig, stateDelta); + } + + /** + * Runs the agent asynchronously using a provided Session object. + * + * @param session The session to run the agent in. + * @param newMessage The new message from the user to process. + * @param runConfig Configuration for the agent run. + * @param stateDelta Optional map of state updates to merge into the session for this run. + * @return A Flowable stream of {@link Event} objects generated by the agent during execution. + */ + protected Flowable runAsyncImpl( + Session session, + Content newMessage, + RunConfig runConfig, + @Nullable Map stateDelta) { + return Flowable.defer( + () -> { + BaseAgent rootAgent = this.agent; + String invocationId = InvocationContext.newInvocationContextId(); + + // Create initial context + InvocationContext initialContext = + newInvocationContextBuilder(session) + .invocationId(invocationId) + .runConfig(runConfig) + .userContent(newMessage) + .build(); + + return this.pluginManager + .onUserMessageCallback(initialContext, newMessage) + .defaultIfEmpty(newMessage) + .flatMap( + content -> + (content != null) + ? appendNewMessageToSession( + session, + content, + initialContext, + runConfig.saveInputBlobsAsArtifacts(), + stateDelta) + : Single.just(null)) + .flatMapPublisher( + event -> { + if (event == null) { + return Flowable.empty(); + } + // Get the updated session after the message and state delta are + // applied + return this.sessionService + .getSession( + session.appName(), session.userId(), session.id(), Optional.empty()) + .flatMapPublisher( + updatedSession -> + runAgentWithFreshSession( + session, + updatedSession, + event, + invocationId, + runConfig, + rootAgent)); + }); + }) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); + span.recordException(throwable); + }) + .compose(Tracing.trace("invocation")); } private Flowable runAgentWithFreshSession( @@ -580,6 +547,7 @@ private Flowable runAgentWithFreshSession( private Completable compactEvents(Session session) { return Optional.ofNullable(eventsCompactionConfig) + .filter(EventsCompactionConfig::hasSlidingWindowCompactionConfig) .map(SlidingWindowEventCompactor::new) .map(c -> c.compact(session, sessionService)) .orElseGet(Completable::complete); @@ -633,7 +601,8 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { .pluginManager(this.pluginManager) .agent(rootAgent) .session(session) - .resumabilityConfig(this.resumabilityConfig) + .eventsCompactionConfig(this.eventsCompactionConfig) + .contextCacheConfig(this.contextCacheConfig) .agent(this.findAgentToRun(session, rootAgent)); } @@ -644,52 +613,39 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - Span span = - Tracing.getTracer().spanBuilder("invocation").setParent(Context.current()).startSpan(); - Context spanContext = Context.current().with(span); - - try { - InvocationContext invocationContext = - newInvocationContextForLive(session, Optional.of(liveRequestQueue), runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - - return invocationContextSingle.flatMapPublisher( - updatedInvocationContext -> - Tracing.traceFlowable( - spanContext, - span, - () -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .doOnNext(event -> this.sessionService.appendEvent(session, event)) - .onErrorResumeNext( - throwable -> { - span.setStatus( - StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - span.end(); - return Flowable.error(throwable); - }))); - } catch (Throwable t) { - span.setStatus(StatusCode.ERROR, "Error during runLive synchronous setup"); - span.recordException(t); - span.end(); - return Flowable.error(t); - } + return Flowable.defer( + () -> { + InvocationContext invocationContext = + newInvocationContextForLive(session, Optional.of(liveRequestQueue), runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }); + }) + .compose(Tracing.trace("invocation")); } /** @@ -720,6 +676,7 @@ public Flowable runLive( * * @return stream of generated events. */ + @Deprecated(since = "0.5.0", forRemoval = true) public Flowable runWithSessionId( String sessionId, Content newMessage, RunConfig runConfig) { // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. @@ -767,14 +724,14 @@ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) { return rootAgent; } - BaseAgent agent = rootAgent.findSubAgent(author); + Optional agent = rootAgent.findSubAgent(author); - if (agent == null) { + if (agent.isEmpty()) { continue; } - if (this.isTransferableAcrossAgentTree(agent)) { - return agent; + if (this.isTransferableAcrossAgentTree(agent.get())) { + return agent.get(); } } @@ -816,7 +773,11 @@ private static EventsCompactionConfig createEventsCompactionConfig( new IllegalArgumentException( "No BaseLlm model available for event compaction")); return new EventsCompactionConfig( - config.compactionInterval(), config.overlapSize(), summarizer); + config.compactionInterval(), + config.overlapSize(), + summarizer, + config.tokenThreshold(), + config.eventRetentionSize()); } // TODO: run statelessly diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index 80c277fce..060fcaf60 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -96,8 +96,8 @@ public Single createSession( .build(); sessions - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, k -> new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .put(resolvedSessionId, newSession); // Create a mutable copy for the return value @@ -116,8 +116,8 @@ public Maybe getSession( Session storedSession = sessions - .getOrDefault(appName, new ConcurrentHashMap<>()) - .getOrDefault(userId, new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .get(sessionId); if (storedSession == null) { @@ -166,7 +166,7 @@ public Single listSessions(String appName, String userId) Objects.requireNonNull(userId, "userId cannot be null"); Map userSessionsMap = - sessions.getOrDefault(appName, new ConcurrentHashMap<>()).get(userId); + sessions.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()).get(userId); if (userSessionsMap == null || userSessionsMap.isEmpty()) { return Single.just(ListSessionsResponse.builder().build()); @@ -185,11 +185,12 @@ public Completable deleteSession(String appName, String userId, String sessionId Objects.requireNonNull(userId, "userId cannot be null"); Objects.requireNonNull(sessionId, "sessionId cannot be null"); - ConcurrentMap userSessionsMap = - sessions.getOrDefault(appName, new ConcurrentHashMap<>()).get(userId); - - if (userSessionsMap != null) { - userSessionsMap.remove(sessionId); + ConcurrentMap> appSessionsMap = sessions.get(appName); + if (appSessionsMap != null) { + ConcurrentMap userSessionsMap = appSessionsMap.get(userId); + if (userSessionsMap != null) { + userSessionsMap.remove(sessionId); + } } return Completable.complete(); } @@ -202,8 +203,8 @@ public Single listEvents(String appName, String userId, Stri Session storedSession = sessions - .getOrDefault(appName, new ConcurrentHashMap<>()) - .getOrDefault(userId, new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .get(sessionId); if (storedSession == null) { @@ -236,17 +237,34 @@ public Single appendEvent(Session session, Event event) { (key, value) -> { if (key.startsWith(State.APP_PREFIX)) { String appStateKey = key.substring(State.APP_PREFIX.length()); - appState - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .put(appStateKey, value); + if (value == State.REMOVED) { + appState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .remove(appStateKey); + } else { + appState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .put(appStateKey, value); + } } else if (key.startsWith(State.USER_PREFIX)) { String userStateKey = key.substring(State.USER_PREFIX.length()); - userState - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, k -> new ConcurrentHashMap<>()) - .put(userStateKey, value); + if (value == State.REMOVED) { + userState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) + .remove(userStateKey); + } else { + userState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) + .put(userStateKey, value); + } } else { - session.state().put(key, value); + if (value == State.REMOVED) { + session.state().remove(key); + } else { + session.state().put(key, value); + } } }); } @@ -257,8 +275,8 @@ public Single appendEvent(Session session, Event event) { // --- Update the session stored in this service --- sessions - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, k -> new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .put(sessionId, session); mergeWithGlobalState(appName, userId, session); @@ -307,12 +325,12 @@ private Session mergeWithGlobalState(String appName, String userId, Session sess // Merge App State directly into the session's state map appState - .getOrDefault(appName, new ConcurrentHashMap()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) .forEach((key, value) -> sessionState.put(State.APP_PREFIX + key, value)); userState - .getOrDefault(appName, new ConcurrentHashMap<>()) - .getOrDefault(userId, new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .forEach((key, value) -> sessionState.put(State.USER_PREFIX + key, value)); return session; diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index 5dbbe76c7..71b072695 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -19,7 +19,6 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.JsonBaseModel; -import com.google.adk.agents.BaseAgentState; import com.google.adk.events.Event; import com.google.adk.events.EventActions; import com.google.adk.events.ToolConfirmation; @@ -28,10 +27,11 @@ import com.google.common.collect.Iterables; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.GroundingMetadata; -import com.google.genai.types.Part; import java.io.UncheckedIOException; import java.time.Instant; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -57,62 +57,66 @@ private SessionJsonConverter() {} * @throws UncheckedIOException if serialization fails. */ static String convertEventToJson(Event event) { - Map metadataJson = new HashMap<>(); - metadataJson.put("partial", event.partial()); - metadataJson.put("turnComplete", event.turnComplete()); - metadataJson.put("interrupted", event.interrupted()); - metadataJson.put("branch", event.branch().orElse(null)); - metadataJson.put( - "long_running_tool_ids", - event.longRunningToolIds() != null ? event.longRunningToolIds().orElse(null) : null); - if (event.groundingMetadata() != null) { - metadataJson.put("grounding_metadata", event.groundingMetadata()); - } + return convertEventToJson(event, false); + } + /** + * Converts an {@link Event} to its JSON string representation for API transmission. + * + * @param useIsoString if true, use ISO-8601 string for timestamp; otherwise use object format. + * @return JSON string of the event. + * @throws UncheckedIOException if serialization fails. + */ + static String convertEventToJson(Event event, boolean useIsoString) { + Map metadataJson = new HashMap<>(); + event.partial().ifPresent(v -> metadataJson.put("partial", v)); + event.turnComplete().ifPresent(v -> metadataJson.put("turnComplete", v)); + event.interrupted().ifPresent(v -> metadataJson.put("interrupted", v)); + event.branch().ifPresent(v -> metadataJson.put("branch", v)); + putIfNotEmpty(metadataJson, "longRunningToolIds", event.longRunningToolIds()); + event.groundingMetadata().ifPresent(v -> metadataJson.put("groundingMetadata", v)); + event.usageMetadata().ifPresent(v -> metadataJson.put("usageMetadata", v)); Map eventJson = new HashMap<>(); eventJson.put("author", event.author()); eventJson.put("invocationId", event.invocationId()); - eventJson.put( - "timestamp", - new HashMap<>( - ImmutableMap.of( - "seconds", - event.timestamp() / 1000, - "nanos", - (event.timestamp() % 1000) * 1000000))); - if (event.errorCode().isPresent()) { - eventJson.put("errorCode", event.errorCode()); - } - if (event.errorMessage().isPresent()) { - eventJson.put("errorMessage", event.errorMessage()); + if (useIsoString) { + eventJson.put("timestamp", Instant.ofEpochMilli(event.timestamp()).toString()); + } else { + eventJson.put( + "timestamp", + new HashMap<>( + ImmutableMap.of( + "seconds", + event.timestamp() / 1000, + "nanos", + (event.timestamp() % 1000) * 1000000))); } + event.errorCode().ifPresent(errorCode -> eventJson.put("errorCode", errorCode)); + event.errorMessage().ifPresent(errorMessage -> eventJson.put("errorMessage", errorMessage)); eventJson.put("eventMetadata", metadataJson); if (event.actions() != null) { Map actionsJson = new HashMap<>(); - actionsJson.put("skipSummarization", event.actions().skipSummarization()); - actionsJson.put("stateDelta", event.actions().stateDelta()); - actionsJson.put("artifactDelta", event.actions().artifactDelta()); - actionsJson.put("transferAgent", event.actions().transferToAgent()); - actionsJson.put("escalate", event.actions().escalate()); - actionsJson.put("requestedAuthConfigs", event.actions().requestedAuthConfigs()); - actionsJson.put("requestedToolConfirmations", event.actions().requestedToolConfirmations()); - actionsJson.put("compaction", event.actions().compaction()); - if (!event.actions().agentState().isEmpty()) { - actionsJson.put("agentState", event.actions().agentState()); + EventActions actions = event.actions(); + actions.skipSummarization().ifPresent(v -> actionsJson.put("skipSummarization", v)); + actionsJson.put("stateDelta", stateDeltaToJson(actions.stateDelta())); + putIfNotEmpty(actionsJson, "artifactDelta", actions.artifactDelta()); + actions + .transferToAgent() + .ifPresent( + v -> { + actionsJson.put("transferAgent", v); + }); + actions.escalate().ifPresent(v -> actionsJson.put("escalate", v)); + if (actions.endOfAgent()) { + actionsJson.put("endOfAgent", actions.endOfAgent()); } - actionsJson.put("rewindBeforeInvocationId", event.actions().rewindBeforeInvocationId()); + putIfNotEmpty(actionsJson, "requestedAuthConfigs", actions.requestedAuthConfigs()); + putIfNotEmpty( + actionsJson, "requestedToolConfirmations", actions.requestedToolConfirmations()); eventJson.put("actions", actionsJson); } - if (event.content().isPresent()) { - eventJson.put("content", SessionUtils.encodeContent(event.content().get())); - } - if (event.errorCode().isPresent()) { - eventJson.put("errorCode", event.errorCode().get()); - } - if (event.errorMessage().isPresent()) { - eventJson.put("errorMessage", event.errorMessage().get()); - } + event.content().ifPresent(c -> eventJson.put("content", SessionUtils.encodeContent(c))); try { return objectMapper.writeValueAsString(eventJson); } catch (JsonProcessingException e) { @@ -126,8 +130,7 @@ static String convertEventToJson(Event event) { * @return parsed {@link Content}, or {@code null} if conversion fails. */ @Nullable - // Safe because we check instanceof Map before casting. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Safe because we check instanceof Map before casting. private static Content convertMapToContent(Object rawContentValue) { if (rawContentValue == null) { return null; @@ -153,26 +156,33 @@ private static Content convertMapToContent(Object rawContentValue) { * * @return parsed {@link Event}. */ - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. static Event fromApiEvent(Map apiEvent) { EventActions.Builder eventActionsBuilder = EventActions.builder(); - if (apiEvent.get("actions") != null) { - Map actionsMap = (Map) apiEvent.get("actions"); - if (actionsMap.get("skipSummarization") != null) { - eventActionsBuilder.skipSummarization((Boolean) actionsMap.get("skipSummarization")); + Map actionsMap = (Map) apiEvent.get("actions"); + if (actionsMap != null) { + Boolean skipSummarization = (Boolean) actionsMap.get("skipSummarization"); + if (skipSummarization != null) { + eventActionsBuilder.skipSummarization(skipSummarization); } - eventActionsBuilder.stateDelta( - actionsMap.get("stateDelta") != null - ? new ConcurrentHashMap<>((Map) actionsMap.get("stateDelta")) - : new ConcurrentHashMap<>()); + eventActionsBuilder.stateDelta(stateDeltaFromJson(actionsMap.get("stateDelta"))); + Object artifactDelta = actionsMap.get("artifactDelta"); eventActionsBuilder.artifactDelta( - actionsMap.get("artifactDelta") != null - ? convertToArtifactDeltaMap(actionsMap.get("artifactDelta")) + artifactDelta != null + ? convertToArtifactDeltaMap(artifactDelta) : new ConcurrentHashMap<>()); - eventActionsBuilder.transferToAgent((String) actionsMap.get("transferAgent")); - if (actionsMap.get("escalate") != null) { - eventActionsBuilder.escalate((Boolean) actionsMap.get("escalate")); + String transferAgent = (String) actionsMap.get("transferAgent"); + if (transferAgent == null) { + transferAgent = (String) actionsMap.get("transferToAgent"); + } + eventActionsBuilder.transferToAgent(transferAgent); + Boolean escalate = (Boolean) actionsMap.get("escalate"); + if (escalate != null) { + eventActionsBuilder.escalate(escalate); + } + Boolean endOfAgent = (Boolean) actionsMap.get("endOfAgent"); + if (endOfAgent != null) { + eventActionsBuilder.endOfAgent(endOfAgent); } eventActionsBuilder.requestedAuthConfigs( Optional.ofNullable(actionsMap.get("requestedAuthConfigs")) @@ -182,13 +192,6 @@ static Event fromApiEvent(Map apiEvent) { Optional.ofNullable(actionsMap.get("requestedToolConfirmations")) .map(SessionJsonConverter::asConcurrentMapOfToolConfirmations) .orElse(new ConcurrentHashMap<>())); - if (actionsMap.get("agentState") != null) { - eventActionsBuilder.agentState(asConcurrentMapOfAgentState(actionsMap.get("agentState"))); - } - if (actionsMap.get("rewindBeforeInvocationId") != null) { - eventActionsBuilder.rewindBeforeInvocationId( - (String) actionsMap.get("rewindBeforeInvocationId")); - } } Event event = @@ -208,11 +211,9 @@ static Event fromApiEvent(Map apiEvent) { .map(value -> new FinishReason((String) value))) .errorMessage( Optional.ofNullable(apiEvent.get("errorMessage")).map(value -> (String) value)) - .branch(Optional.ofNullable(apiEvent.get("branch")).map(value -> (String) value)) .build(); - // TODO(b/414263934): Add Event branch and grounding metadata for python parity. - if (apiEvent.get("eventMetadata") != null) { - Map eventMetadata = (Map) apiEvent.get("eventMetadata"); + Map eventMetadata = (Map) apiEvent.get("eventMetadata"); + if (eventMetadata != null) { List longRunningToolIdsList = (List) eventMetadata.get("longRunningToolIds"); GroundingMetadata groundingMetadata = null; @@ -221,6 +222,12 @@ static Event fromApiEvent(Map apiEvent) { groundingMetadata = objectMapper.convertValue(rawGroundingMetadata, GroundingMetadata.class); } + GenerateContentResponseUsageMetadata usageMetadata = null; + Object rawUsageMetadata = eventMetadata.get("usageMetadata"); + if (rawUsageMetadata != null) { + usageMetadata = + objectMapper.convertValue(rawUsageMetadata, GenerateContentResponseUsageMetadata.class); + } event = event.toBuilder() @@ -231,6 +238,7 @@ static Event fromApiEvent(Map apiEvent) { Optional.ofNullable((Boolean) eventMetadata.get("interrupted")).orElse(false)) .branch(Optional.ofNullable((String) eventMetadata.get("branch"))) .groundingMetadata(groundingMetadata) + .usageMetadata(usageMetadata) .longRunningToolIds( longRunningToolIdsList != null ? new HashSet<>(longRunningToolIdsList) : null) .build(); @@ -238,6 +246,32 @@ static Event fromApiEvent(Map apiEvent) { return event; } + @SuppressWarnings("unchecked") // stateDeltaFromMap is a Map from JSON. + private static ConcurrentMap stateDeltaFromJson(Object stateDeltaFromMap) { + if (stateDeltaFromMap == null) { + return new ConcurrentHashMap<>(); + } + return ((Map) stateDeltaFromMap) + .entrySet().stream() + .collect( + ConcurrentHashMap::new, + (map, entry) -> + map.put( + entry.getKey(), + entry.getValue() == null ? State.REMOVED : entry.getValue()), + ConcurrentHashMap::putAll); + } + + private static Map stateDeltaToJson(Map stateDelta) { + return stateDelta.entrySet().stream() + .collect( + HashMap::new, + (map, entry) -> + map.put( + entry.getKey(), entry.getValue() == State.REMOVED ? null : entry.getValue()), + HashMap::putAll); + } + /** * Converts a timestamp from a Map or String into an {@link Instant}. * @@ -263,20 +297,20 @@ private static Instant convertToInstant(Object timestampObj) { * @param artifactDeltaObj The raw object from which to parse the artifact delta. * @return A {@link ConcurrentMap} representing the artifact delta. */ - // Safe because we check instanceof Map before casting. @SuppressWarnings("unchecked") - private static ConcurrentMap convertToArtifactDeltaMap(Object artifactDeltaObj) { + private static ConcurrentMap convertToArtifactDeltaMap(Object artifactDeltaObj) { if (!(artifactDeltaObj instanceof Map)) { return new ConcurrentHashMap<>(); } - ConcurrentMap artifactDeltaMap = new ConcurrentHashMap<>(); - Map> rawMap = (Map>) artifactDeltaObj; - for (Map.Entry> entry : rawMap.entrySet()) { + ConcurrentMap artifactDeltaMap = new ConcurrentHashMap<>(); + Map rawMap = (Map) artifactDeltaObj; + for (Map.Entry entry : rawMap.entrySet()) { try { - Part part = objectMapper.convertValue(entry.getValue(), Part.class); - artifactDeltaMap.put(entry.getKey(), part); + Integer value = objectMapper.convertValue(entry.getValue(), Integer.class); + artifactDeltaMap.put(entry.getKey(), value); } catch (IllegalArgumentException e) { - logger.warn("Error converting artifactDelta value to Part for key: {}", entry.getKey(), e); + logger.warn( + "Error converting artifactDelta value to Integer for key: {}", entry.getKey(), e); } } return artifactDeltaMap; @@ -287,8 +321,7 @@ private static ConcurrentMap convertToArtifactDeltaMap(Object arti * * @return thread-safe nested map. */ - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. private static ConcurrentMap> asConcurrentMapOfConcurrentMaps(Object value) { return ((Map>) value) @@ -299,22 +332,7 @@ private static ConcurrentMap convertToArtifactDeltaMap(Object arti ConcurrentHashMap::putAll); } - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") - private static ConcurrentMap asConcurrentMapOfAgentState(Object value) { - return ((Map) value) - .entrySet().stream() - .collect( - ConcurrentHashMap::new, - (map, entry) -> - map.put( - entry.getKey(), - objectMapper.convertValue(entry.getValue(), BaseAgentState.class)), - ConcurrentHashMap::putAll); - } - - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. private static ConcurrentMap asConcurrentMapOfToolConfirmations( Object value) { return ((Map) value) @@ -327,4 +345,22 @@ private static ConcurrentMap asConcurrentMapOfToolConf objectMapper.convertValue(entry.getValue(), ToolConfirmation.class)), ConcurrentHashMap::putAll); } + + private static void putIfNotEmpty(Map map, String key, Map values) { + if (values != null && !values.isEmpty()) { + map.put(key, values); + } + } + + private static void putIfNotEmpty( + Map map, String key, Optional> values) { + values.ifPresent(v -> putIfNotEmpty(map, key, v)); + } + + private static void putIfNotEmpty( + Map map, String key, @Nullable Collection values) { + if (values != null && !values.isEmpty()) { + map.put(key, values); + } + } } diff --git a/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java b/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java index 766041304..b61cd2008 100644 --- a/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java +++ b/core/src/main/java/com/google/adk/summarizer/EventsCompactionConfig.java @@ -27,11 +27,28 @@ * compacted range. This creates an overlap between consecutive compacted summaries, maintaining * context. * @param summarizer An event summarizer to use for compaction. + * @param tokenThreshold The number of tokens above which compaction will be triggered. If null, no + * token limit will be enforced. It will trigger compaction within the invocation. + * @param eventRetentionSize The maximum number of events to retain and preserve from compaction. If + * null, no event retention limit will be enforced. */ public record EventsCompactionConfig( - int compactionInterval, int overlapSize, @Nullable BaseEventSummarizer summarizer) { + @Nullable Integer compactionInterval, + @Nullable Integer overlapSize, + @Nullable BaseEventSummarizer summarizer, + @Nullable Integer tokenThreshold, + @Nullable Integer eventRetentionSize) { public EventsCompactionConfig(int compactionInterval, int overlapSize) { - this(compactionInterval, overlapSize, null); + this(compactionInterval, overlapSize, null, null, null); + } + + public EventsCompactionConfig( + int compactionInterval, int overlapSize, @Nullable BaseEventSummarizer summarizer) { + this(compactionInterval, overlapSize, summarizer, null, null); + } + + public boolean hasSlidingWindowCompactionConfig() { + return compactionInterval != null && compactionInterval > 0 && overlapSize != null; } } diff --git a/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java new file mode 100644 index 000000000..e193e7686 --- /dev/null +++ b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java @@ -0,0 +1,241 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.summarizer; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.adk.events.Event; +import com.google.adk.events.EventCompaction; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.Session; +import com.google.common.collect.Lists; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.ListIterator; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class performs event compaction by retaining the tail of the event stream. + * + *

    + *
  • Keeps the {@code retentionSize} most recent events raw. + *
  • Compacts all events that never compacted and older than the retained tail, including the + * most recent compaction event, into a new summary event. + *
  • Triggers compaction only if the prompt token count exceeds the {@code tokenThreshold}. + *
  • The new summary event is generated by the {@link BaseEventSummarizer}. + *
  • Appends this new summary event to the end of the event stream. + *
+ * + *

This compactor produces a rolling summary. Each new compaction event includes the content of + * the previous compaction event (if any) along with new events, effectively superseding all prior + * compactions. + */ +public final class TailRetentionEventCompactor implements EventCompactor { + + private static final Logger logger = LoggerFactory.getLogger(TailRetentionEventCompactor.class); + + private final BaseEventSummarizer summarizer; + private final int retentionSize; + private final int tokenThreshold; + + public TailRetentionEventCompactor( + BaseEventSummarizer summarizer, int retentionSize, int tokenThreshold) { + checkArgument(tokenThreshold >= 0, "tokenThreshold must be non-negative"); + checkArgument(retentionSize >= 0, "retentionSize must be non-negative"); + this.summarizer = summarizer; + this.retentionSize = retentionSize; + this.tokenThreshold = tokenThreshold; + } + + @Override + public Completable compact(Session session, BaseSessionService sessionService) { + checkArgument(summarizer != null, "Missing BaseEventSummarizer for event compaction"); + logger.debug("Running tail retention event compaction for session {}", session.id()); + + return Maybe.just(session.events()) + .flatMap(this::getCompactionEvents) + .flatMap(summarizer::summarizeEvents) + .flatMapSingle(e -> sessionService.appendEvent(session, e)) + .ignoreElement(); + } + + /** + * Identifies events to be compacted based on the tail retention strategy. + * + *

This method iterates backwards through the event list to find the most recent compaction + * event (if any) and collects all uncompacted events that occurred after the range covered by + * that compaction. It then applies the retention policy, excluding the most recent {@code + * retentionSize} events from being compacted. + * + *

Basic Scenario: + * + *

    + *
  • Events: E1, E2, E3, E4, E5 (Chronological order) + *
  • Retention Size: 2 + *
  • Action: Compaction is triggered. The compactor identifies E1, E2, and E3 as eligible + * since E4, E5 need to be retained. + *
  • Result: E1, E2, E3 are compacted into C1. + *
  • Event stream after compaction: E1, E2, E3, E4, E5, C1. (Compaction event is appended in + * the end.) + *
+ * + *

Advanced Scenario (Handling Gaps): + * + *

Consider an edge case where retention size is 3. Event E4 appears before the last compaction + * event (C2) and even the one prior (C1), but remains uncompacted and must be included in the + * third compaction (C3). + * + *

    + *
  • T=1: E1 + *
  • T=2: E2 + *
  • T=3: E3 + *
  • T=4: E4 + *
  • T=5: C1 (Covers T=1). Generated when getCompactionEvents returned List: E1. E2, + * E3, E4 were preserved. + *
  • T=6: E6 + *
  • T=7: E7 + *
  • T=8: C2 (Covers T=1 to T=3; starts at T=1 because it includes C1). Generated when + * getCompactionEvents returned List: C1, E2, E3. E4, E6, E7 were preserved. + *
  • T=9: E9. + *
+ * + *

Execution with Retention = 3: + * + *

    + *
  1. The method scans backward: E9, C2, E7, E6, C1, E4... + *
  2. C2 is identified as the most recent compaction event (end timestamp T=3). + *
  3. E9, E7, E6 are collected as they are newer than T=3. + *
  4. C1 is ignored as we only care about the boundary set by the latest compaction. + *
  5. E4 (T=4) is collected because it is newer than T=3. + *
  6. Scanning stops at E3 as it is covered by C2 (timestamp <= T=3). + *
  7. The initial list of events to summarize: [E9, E7, E6, E4]. + *
  8. After appending the compaction event C2, the list becomes: [E9, E7, E6, E4, C2] + *
  9. Reversing the list: [C2, E4, E6, E7, E9]. + *
  10. Applying retention (keep last 3): E6, E7, E9 are removed from the summary list. + *
  11. Final Output: {@code [C2, E4]}. E4 and the previous summary C2 will be compacted + * together. The new compaction event will cover the range from the start of the included + * compaction event (C2, T=1) to the end of the new events (E4, T=4). + *
+ * + * @param events The list of events to process. + */ + private Maybe> getCompactionEvents(List events) { + Optional count = getLatestPromptTokenCount(events); + if (count.isPresent() && count.get() <= tokenThreshold) { + logger.debug( + "Skipping compaction. Prompt token count {} is within threshold {}", + count.get(), + tokenThreshold); + return Maybe.empty(); + } + + long compactionEndTimestamp = Long.MIN_VALUE; + Event lastCompactionEvent = null; + List eventsToSummarize = new ArrayList<>(); + + // Iterate backwards from the end of the window to summarize. + // We use a single loop to: + // 1. Collect all raw events that happened after the latest compaction. + // 2. Identify the latest compaction event to establish the stop condition (boundary). + ListIterator iter = events.listIterator(events.size()); + while (iter.hasPrevious()) { + Event event = iter.previous(); + + if (!isCompactEvent(event)) { + // Only include events that are strictly after the last compaction range. + if (event.timestamp() > compactionEndTimestamp) { + eventsToSummarize.add(event); + continue; + } else { + // Exit early if we have reached the last event of last compaction range. + break; + } + } + + EventCompaction compaction = event.actions().compaction().orElse(null); + // We use the most recent compaction event to define the time boundary. Any subsequent (older) + // compaction events are ignored. + if (lastCompactionEvent == null) { + compactionEndTimestamp = compaction.endTimestamp(); + lastCompactionEvent = event; + } + } + + // Add the last compaction event to the list of events to summarize. + // This is to ensure that the last compaction event is included in the summary. + if (lastCompactionEvent != null) { + EventCompaction compaction = lastCompactionEvent.actions().compaction().get(); + eventsToSummarize.add( + lastCompactionEvent.toBuilder() + .content(compaction.compactedContent()) + // Use the start timestamp so that the new summary covers the entire range. + .timestamp(compaction.startTimestamp()) + .build()); + } + + Collections.reverse(eventsToSummarize); + + if (count.isEmpty()) { + int estimatedCount = estimateTokenCount(eventsToSummarize); + if (estimatedCount <= tokenThreshold) { + logger.debug( + "Skipping compaction. Estimated prompt token count {} is within threshold {}", + estimatedCount, + tokenThreshold); + return Maybe.empty(); + } + } + + // If there are not enough events to summarize, we can return early. + if (eventsToSummarize.size() <= retentionSize) { + return Maybe.empty(); + } + + // Apply retention: keep the most recent 'retentionSize' events out of the summary. + // We do this by removing them from the list of events to be summarized. + eventsToSummarize + .subList(eventsToSummarize.size() - retentionSize, eventsToSummarize.size()) + .clear(); + return Maybe.just(eventsToSummarize); + } + + private int estimateTokenCount(List events) { + // A common rule of thumb is that one token roughly corresponds to 4 characters of text for + // common English text. + // See https://platform.openai.com/tokenizer + return events.stream().mapToInt(event -> event.stringifyContent().length()).sum() / 4; + } + + private Optional getLatestPromptTokenCount(List events) { + return Lists.reverse(events).stream() + .map(Event::usageMetadata) + .flatMap(Optional::stream) + .map(GenerateContentResponseUsageMetadata::promptTokenCount) + .flatMap(Optional::stream) + .findFirst(); + } + + private static boolean isCompactEvent(Event event) { + return event.actions() != null && event.actions().compaction().isPresent(); + } +} diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 23054b674..07a640c37 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -19,26 +19,44 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; import com.google.adk.JsonBaseModel; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; +import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.CompletableSource; +import io.reactivex.rxjava3.core.CompletableTransformer; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.FlowableTransformer; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.MaybeSource; +import io.reactivex.rxjava3.core.MaybeTransformer; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.core.SingleSource; +import io.reactivex.rxjava3.core.SingleTransformer; import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.function.Consumer; import java.util.function.Supplier; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,8 +70,56 @@ public class Tracing { private static final Logger log = LoggerFactory.getLogger(Tracing.class); + private static final AttributeKey> GEN_AI_RESPONSE_FINISH_REASONS = + AttributeKey.stringArrayKey("gen_ai.response.finish_reasons"); + + private static final AttributeKey GEN_AI_OPERATION_NAME = + AttributeKey.stringKey("gen_ai.operation.name"); + private static final AttributeKey GEN_AI_AGENT_DESCRIPTION = + AttributeKey.stringKey("gen_ai.agent.description"); + private static final AttributeKey GEN_AI_AGENT_NAME = + AttributeKey.stringKey("gen_ai.agent.name"); + private static final AttributeKey GEN_AI_CONVERSATION_ID = + AttributeKey.stringKey("gen_ai.conversation.id"); + private static final AttributeKey GEN_AI_SYSTEM = AttributeKey.stringKey("gen_ai.system"); + private static final AttributeKey GEN_AI_TOOL_CALL_ID = + AttributeKey.stringKey("gen_ai.tool_call.id"); + private static final AttributeKey GEN_AI_TOOL_DESCRIPTION = + AttributeKey.stringKey("gen_ai.tool.description"); + private static final AttributeKey GEN_AI_TOOL_NAME = + AttributeKey.stringKey("gen_ai.tool.name"); + private static final AttributeKey GEN_AI_TOOL_TYPE = + AttributeKey.stringKey("gen_ai.tool.type"); + private static final AttributeKey GEN_AI_REQUEST_MODEL = + AttributeKey.stringKey("gen_ai.request.model"); + private static final AttributeKey GEN_AI_REQUEST_TOP_P = + AttributeKey.doubleKey("gen_ai.request.top_p"); + private static final AttributeKey GEN_AI_REQUEST_MAX_TOKENS = + AttributeKey.longKey("gen_ai.request.max_tokens"); + private static final AttributeKey GEN_AI_USAGE_INPUT_TOKENS = + AttributeKey.longKey("gen_ai.usage.input_tokens"); + private static final AttributeKey GEN_AI_USAGE_OUTPUT_TOKENS = + AttributeKey.longKey("gen_ai.usage.output_tokens"); + + private static final AttributeKey ADK_TOOL_CALL_ARGS = + AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"); + private static final AttributeKey ADK_LLM_REQUEST = + AttributeKey.stringKey("gcp.vertex.agent.llm_request"); + private static final AttributeKey ADK_LLM_RESPONSE = + AttributeKey.stringKey("gcp.vertex.agent.llm_response"); + private static final AttributeKey ADK_INVOCATION_ID = + AttributeKey.stringKey("gcp.vertex.agent.invocation_id"); + private static final AttributeKey ADK_EVENT_ID = + AttributeKey.stringKey("gcp.vertex.agent.event_id"); + private static final AttributeKey ADK_TOOL_RESPONSE = + AttributeKey.stringKey("gcp.vertex.agent.tool_response"); + private static final AttributeKey ADK_SESSION_ID = + AttributeKey.stringKey("gcp.vertex.agent.session_id"); + private static final AttributeKey ADK_DATA = + AttributeKey.stringKey("gcp.vertex.agent.data"); + @SuppressWarnings("NonFinalStaticField") - private static Tracer tracer = GlobalOpenTelemetry.getTracer("com.google.adk"); + private static Tracer tracer = GlobalOpenTelemetry.getTracer("gcp.vertex.agent"); private static final boolean CAPTURE_MESSAGE_CONTENT_IN_SPANS = Boolean.parseBoolean( @@ -61,57 +127,131 @@ public class Tracing { private Tracing() {} + private static Optional getValidCurrentSpan(String methodName) { + Span span = Span.current(); + if (!span.getSpanContext().isValid()) { + log.trace("{}: No valid span in current context.", methodName); + return Optional.empty(); + } + return Optional.of(span); + } + + private static void setInvocationAttributes( + Span span, InvocationContext invocationContext, String eventId) { + span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); + if (eventId != null && !eventId.isEmpty()) { + span.setAttribute(ADK_EVENT_ID, eventId); + } + + if (invocationContext.session() != null && invocationContext.session().id() != null) { + span.setAttribute(ADK_SESSION_ID, invocationContext.session().id()); + } else { + log.trace( + "InvocationContext session or session ID is null, cannot set {}", + ADK_SESSION_ID.getKey()); + } + } + + private static void setToolExecutionAttributes(Span span) { + span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); + } + + private static void setJsonAttribute(Span span, AttributeKey key, Object value) { + if (!CAPTURE_MESSAGE_CONTENT_IN_SPANS) { + span.setAttribute(key, "{}"); + return; + } + try { + String json = + (value instanceof String stringValue) + ? stringValue + : JsonBaseModel.getMapper().writeValueAsString(value); + span.setAttribute(key, json); + } catch (JsonProcessingException | RuntimeException e) { + log.warn("Failed to serialize {} to JSON", key.getKey(), e); + span.setAttribute(key, "{\"error\": \"serialization failed\"}"); + } + } + /** Sets the OpenTelemetry instance to be used for tracing. This is for testing purposes only. */ public static void setTracerForTesting(Tracer tracer) { Tracing.tracer = tracer; } + /** + * Sets span attributes immediately available on agent invocation according to OTEL semconv + * version 1.37. + * + * @param span Span on which attributes are set. + * @param agentName Agent name from which attributes are gathered. + * @param agentDescription Agent description from which attributes are gathered. + * @param invocationContext InvocationContext from which attributes are gathered. + */ + public static void traceAgentInvocation( + Span span, String agentName, String agentDescription, InvocationContext invocationContext) { + span.setAttribute(GEN_AI_OPERATION_NAME, "invoke_agent"); + span.setAttribute(GEN_AI_AGENT_DESCRIPTION, agentDescription); + span.setAttribute(GEN_AI_AGENT_NAME, agentName); + if (invocationContext.session() != null && invocationContext.session().id() != null) { + span.setAttribute(GEN_AI_CONVERSATION_ID, invocationContext.session().id()); + } + } + /** * Traces tool call arguments. * * @param args The arguments to the tool call. */ - public static void traceToolCall(Map args) { - Span span = Span.current(); - if (span == null || !span.getSpanContext().isValid()) { - log.trace("traceToolCall: No valid span in current context."); - return; - } + public static void traceToolCall( + String toolName, String toolDescription, String toolType, Map args) { + getValidCurrentSpan("traceToolCall") + .ifPresent( + span -> { + setToolExecutionAttributes(span); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - span.setAttribute("gen_ai.system", "com.google.adk"); - try { - span.setAttribute("adk.tool_call_args", JsonBaseModel.getMapper().writeValueAsString(args)); - } catch (JsonProcessingException e) { - log.warn("traceToolCall: Failed to serialize tool call args to JSON", e); - } + setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); + }); } /** * Traces tool response event. * - * @param invocationContext The invocation context for the current agent run. * @param eventId The ID of the event. * @param functionResponseEvent The function response event. */ - public static void traceToolResponse( - InvocationContext invocationContext, String eventId, Event functionResponseEvent) { - Span span = Span.current(); - if (span == null || !span.getSpanContext().isValid()) { - log.trace("traceToolResponse: No valid span in current context."); - return; - } + public static void traceToolResponse(String eventId, Event functionResponseEvent) { + getValidCurrentSpan("traceToolResponse") + .ifPresent( + span -> { + setToolExecutionAttributes(span); + span.setAttribute(ADK_EVENT_ID, eventId); - span.setAttribute("gen_ai.system", "com.google.adk"); - span.setAttribute("adk.invocation_id", invocationContext.invocationId()); - span.setAttribute("adk.event_id", eventId); - span.setAttribute("adk.tool_response", functionResponseEvent.toJson()); + FunctionResponse functionResponse = + functionResponseEvent.functionResponses().stream().findFirst().orElse(null); - // Setting empty llm request and response (as the AdkDevServer UI expects these) - span.setAttribute("adk.llm_request", "{}"); - span.setAttribute("adk.llm_response", "{}"); - if (invocationContext.session() != null && invocationContext.session().id() != null) { - span.setAttribute("adk.session_id", invocationContext.session().id()); - } + String toolCallId = ""; + Object toolResponse = ""; + if (functionResponse != null) { + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + + Object finalToolResponse = + (toolResponse instanceof Map) + ? toolResponse + : ImmutableMap.of("result", toolResponse); + + setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); + }); } /** @@ -156,38 +296,58 @@ public static void traceCallLlm( String eventId, LlmRequest llmRequest, LlmResponse llmResponse) { - Span span = Span.current(); - if (span == null || !span.getSpanContext().isValid()) { - log.trace("traceCallLlm: No valid span in current context."); - return; - } + getValidCurrentSpan("traceCallLlm") + .ifPresent( + span -> { + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest + .model() + .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - span.setAttribute("gen_ai.system", "com.google.adk"); - llmRequest.model().ifPresent(modelName -> span.setAttribute("gen_ai.request.model", modelName)); - span.setAttribute("adk.invocation_id", invocationContext.invocationId()); - span.setAttribute("adk.event_id", eventId); + setInvocationAttributes(span, invocationContext, eventId); - if (invocationContext.session() != null && invocationContext.session().id() != null) { - span.setAttribute("adk.session_id", invocationContext.session().id()); - } else { - log.trace( - "traceCallLlm: InvocationContext session or session ID is null, cannot set" - + " adk.session_id"); - } + setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); + setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); - if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { - try { - span.setAttribute( - "adk.llm_request", - JsonBaseModel.getMapper().writeValueAsString(buildLlmRequestForTrace(llmRequest))); - span.setAttribute("adk.llm_response", llmResponse.toJson()); - } catch (JsonProcessingException e) { - log.warn("traceCallLlm: Failed to serialize LlmRequest or LlmResponse to JSON", e); - } - } else { - span.setAttribute("adk.llm_request", "{}"); - span.setAttribute("adk.llm_response", "{}"); - } + llmRequest + .config() + .ifPresent( + config -> { + config + .topP() + .ifPresent( + topP -> + span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() + .ifPresent( + maxTokens -> + span.setAttribute( + GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent( + tokens -> + span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() + .ifPresent( + tokens -> + span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + .ifPresent( + reason -> + span.setAttribute( + GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); + }); } /** @@ -199,36 +359,17 @@ public static void traceCallLlm( */ public static void traceSendData( InvocationContext invocationContext, String eventId, List data) { - Span span = Span.current(); - if (span == null || !span.getSpanContext().isValid()) { - log.trace("traceSendData: No valid span in current context."); - return; - } - - span.setAttribute("adk.invocation_id", invocationContext.invocationId()); - if (eventId != null && !eventId.isEmpty()) { - span.setAttribute("adk.event_id", eventId); - } - - if (invocationContext.session() != null && invocationContext.session().id() != null) { - span.setAttribute("adk.session_id", invocationContext.session().id()); - } + getValidCurrentSpan("traceSendData") + .ifPresent( + span -> { + setInvocationAttributes(span, invocationContext, eventId); - try { - List> dataList = new ArrayList<>(); - if (data != null) { - for (Content content : data) { - if (content != null) { - dataList.add( - JsonBaseModel.getMapper() - .convertValue(content, new TypeReference>() {})); - } - } - } - span.setAttribute("adk.data", JsonBaseModel.toJsonString(dataList)); - } catch (IllegalStateException e) { - log.warn("traceSendData: Failed to serialize data to JSON", e); - } + ImmutableList safeData = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + setJsonAttribute(span, ADK_DATA, safeData); + }); } /** @@ -273,4 +414,142 @@ public static Flowable traceFlowable( span.end(); }); } + + /** + * Returns a transformer that traces the execution of an RxJava stream. + * + * @param spanName The name of the span to create. + * @param The type of the stream. + * @return A TracerProvider that can be used with .compose(). + */ + public static TracerProvider trace(String spanName) { + return new TracerProvider<>(spanName); + } + + /** + * Returns a transformer that traces the execution of an RxJava stream with an explicit parent + * context. + * + * @param spanName The name of the span to create. + * @param parentContext The explicit parent context for the span. + * @param The type of the stream. + * @return A TracerProvider that can be used with .compose(). + */ + public static TracerProvider trace(String spanName, Context parentContext) { + return new TracerProvider(spanName).setParent(parentContext); + } + + /** + * Returns a transformer that traces an agent invocation. + * + * @param spanName The name of the span to create. + * @param agentName The name of the agent. + * @param agentDescription The description of the agent. + * @param invocationContext The invocation context. + * @param The type of the stream. + * @return A TracerProvider configured for agent invocation. + */ + public static TracerProvider traceAgent( + String spanName, + String agentName, + String agentDescription, + InvocationContext invocationContext) { + return new TracerProvider(spanName) + .configure( + span -> traceAgentInvocation(span, agentName, agentDescription, invocationContext)); + } + + /** + * A transformer that manages an OpenTelemetry span and scope for RxJava streams. + * + * @param The type of the stream. + */ + public static final class TracerProvider + implements FlowableTransformer, + SingleTransformer, + MaybeTransformer, + CompletableTransformer { + private final String spanName; + private Context explicitParentContext; + private final List> spanConfigurers = new ArrayList<>(); + + private TracerProvider(String spanName) { + this.spanName = spanName; + } + + /** Configures the span created by this transformer. */ + @CanIgnoreReturnValue + public TracerProvider configure(Consumer configurer) { + spanConfigurers.add(configurer); + return this; + } + + /** Sets an explicit parent context for the span created by this transformer. */ + @CanIgnoreReturnValue + public TracerProvider setParent(Context parentContext) { + this.explicitParentContext = parentContext; + return this; + } + + private Context getParentContext() { + return explicitParentContext != null ? explicitParentContext : Context.current(); + } + + private final class TracingLifecycle { + private Span span; + private Scope scope; + + @SuppressWarnings("MustBeClosedChecker") + void start() { + span = tracer.spanBuilder(spanName).setParent(getParentContext()).startSpan(); + spanConfigurers.forEach(c -> c.accept(span)); + scope = span.makeCurrent(); + } + + void end() { + if (scope != null) { + scope.close(); + } + if (span != null) { + span.end(); + } + } + } + + @Override + public Publisher apply(Flowable upstream) { + return Flowable.defer( + () -> { + TracingLifecycle lifecycle = new TracingLifecycle(); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + }); + } + + @Override + public SingleSource apply(Single upstream) { + return Single.defer( + () -> { + TracingLifecycle lifecycle = new TracingLifecycle(); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + }); + } + + @Override + public MaybeSource apply(Maybe upstream) { + return Maybe.defer( + () -> { + TracingLifecycle lifecycle = new TracingLifecycle(); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + }); + } + + @Override + public CompletableSource apply(Completable upstream) { + return Completable.defer( + () -> { + TracingLifecycle lifecycle = new TracingLifecycle(); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + }); + } + } } diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index a531361f2..7eabc48c4 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -28,6 +28,7 @@ import com.google.adk.events.Event; import com.google.adk.runner.InMemoryRunner; import com.google.adk.runner.Runner; +import com.google.adk.sessions.State; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -36,6 +37,7 @@ import com.google.genai.types.Part; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Single; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -82,15 +84,42 @@ BaseAgent getAgent() { return agent; } + private Optional getInputSchema(BaseAgent agent) { + BaseAgent currentAgent = agent; + while (true) { + if (currentAgent instanceof LlmAgent llmAgent) { + return llmAgent.inputSchema(); + } + List subAgents = currentAgent.subAgents(); + if (subAgents == null || subAgents.isEmpty()) { + return Optional.empty(); + } + // For composite agents, check the first sub-agent. + currentAgent = subAgents.get(0); + } + } + + private Optional getOutputSchema(BaseAgent agent) { + BaseAgent currentAgent = agent; + while (true) { + if (currentAgent instanceof LlmAgent llmAgent) { + return llmAgent.outputSchema(); + } + List subAgents = currentAgent.subAgents(); + if (subAgents == null || subAgents.isEmpty()) { + return Optional.empty(); + } + // For composite agents, check the last sub-agent. + currentAgent = subAgents.get(subAgents.size() - 1); + } + } + @Override public Optional declaration() { FunctionDeclaration.Builder builder = FunctionDeclaration.builder().description(this.description()).name(this.name()); - Optional agentInputSchema = Optional.empty(); - if (agent instanceof LlmAgent llmAgent) { - agentInputSchema = llmAgent.inputSchema(); - } + Optional agentInputSchema = getInputSchema(agent); if (agentInputSchema.isPresent()) { builder.parameters(agentInputSchema.get()); @@ -109,13 +138,11 @@ public Optional declaration() { public Single> runAsync(Map args, ToolContext toolContext) { if (this.skipSummarization) { - toolContext.setActions(toolContext.actions().toBuilder().skipSummarization(true).build()); + // Mutate EventActions in-place to ensure object references are maintained. + toolContext.actions().setSkipSummarization(true); } - Optional agentInputSchema = Optional.empty(); - if (agent instanceof LlmAgent llmAgent) { - agentInputSchema = llmAgent.inputSchema(); - } + Optional agentInputSchema = getInputSchema(agent); final Content content; if (agentInputSchema.isPresent()) { @@ -154,7 +181,7 @@ public Single> runAsync(Map args, ToolContex if (lastEvent.actions() != null && lastEvent.actions().stateDelta() != null && !lastEvent.actions().stateDelta().isEmpty()) { - toolContext.state().putAll(lastEvent.actions().stateDelta()); + updateState(lastEvent.actions().stateDelta(), toolContext.state()); } if (outputText.isEmpty()) { @@ -162,10 +189,7 @@ public Single> runAsync(Map args, ToolContex } String output = outputText.get(); - Optional agentOutputSchema = Optional.empty(); - if (agent instanceof LlmAgent llmAgent) { - agentOutputSchema = llmAgent.outputSchema(); - } + Optional agentOutputSchema = getOutputSchema(agent); if (agentOutputSchema.isPresent()) { return SchemaUtils.validateOutputSchema(output, agentOutputSchema.get()); @@ -174,4 +198,24 @@ public Single> runAsync(Map args, ToolContex } }); } + + /** + * Updates the given state map with the state delta. + * + *

If a value in the delta is {@link State#REMOVED}, the key is removed from the state map. + * Otherwise, the key-value pair is put into the state map. This method does not distinguish + * between session, app, and user state based on key prefixes. + * + * @param state The state map to update. + */ + private void updateState(Map stateDelta, Map state) { + stateDelta.forEach( + (key, value) -> { + if (value == State.REMOVED) { + state.remove(key); + } else { + state.put(key, value); + } + }); + } } diff --git a/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java b/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java index ffd9601a8..6f89754cf 100644 --- a/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java +++ b/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java @@ -28,8 +28,8 @@ import org.slf4j.LoggerFactory; /** - * A built-in tool that is automatically invoked by Gemini 2 models to retrieve search results from - * Google Search. + * A built-in tool that is automatically invoked by Gemini 2 and 3 models to retrieve search results + * from Google Search. * *

This tool operates internally within the model and does not require or perform local code * execution. @@ -76,7 +76,7 @@ public Completable processLlmRequest( updatedToolsBuilder.add( Tool.builder().googleSearchRetrieval(GoogleSearchRetrieval.builder().build()).build()); configBuilder.tools(updatedToolsBuilder.build()); - } else if (model != null && model.startsWith("gemini-2")) { + } else if (model != null && (model.startsWith("gemini-2") || model.startsWith("gemini-3"))) { updatedToolsBuilder.add(Tool.builder().googleSearch(GoogleSearch.builder().build()).build()); configBuilder.tools(updatedToolsBuilder.build()); diff --git a/core/src/main/java/com/google/adk/tools/ToolContext.java b/core/src/main/java/com/google/adk/tools/ToolContext.java index b421a8e58..5192d19ff 100644 --- a/core/src/main/java/com/google/adk/tools/ToolContext.java +++ b/core/src/main/java/com/google/adk/tools/ToolContext.java @@ -35,8 +35,9 @@ private ToolContext( InvocationContext invocationContext, EventActions eventActions, Optional functionCallId, - Optional toolConfirmation) { - super(invocationContext, eventActions); + Optional toolConfirmation, + @Nullable String eventId) { + super(invocationContext, eventActions, eventId); this.functionCallId = functionCallId; this.toolConfirmation = toolConfirmation; } @@ -125,7 +126,8 @@ public Builder toBuilder() { return new Builder(invocationContext) .actions(eventActions) .functionCallId(functionCallId.orElse(null)) - .toolConfirmation(toolConfirmation.orElse(null)); + .toolConfirmation(toolConfirmation.orElse(null)) + .eventId(eventId()); } @Override @@ -148,6 +150,7 @@ public static final class Builder { private EventActions eventActions = EventActions.builder().build(); // Default empty actions private Optional functionCallId = Optional.empty(); private Optional toolConfirmation = Optional.empty(); + private String eventId; private Builder(InvocationContext invocationContext) { this.invocationContext = invocationContext; @@ -171,8 +174,15 @@ public Builder toolConfirmation(ToolConfirmation toolConfirmation) { return this; } + @CanIgnoreReturnValue + public Builder eventId(String eventId) { + this.eventId = eventId; + return this; + } + public ToolContext build() { - return new ToolContext(invocationContext, eventActions, functionCallId, toolConfirmation); + return new ToolContext( + invocationContext, eventActions, functionCallId, toolConfirmation, eventId); } } } diff --git a/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationClient.java b/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationClient.java index b1958d56a..3b63429a9 100644 --- a/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationClient.java +++ b/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationClient.java @@ -148,7 +148,7 @@ private void validate() { } } - String generateOpenApiSpec() throws Exception { + String generateOpenApiSpec() throws IOException, InterruptedException { String url = String.format( "https://%s-integrations.googleapis.com/v1/projects/%s/locations/%s:generateOpenApiSpec", @@ -179,7 +179,7 @@ String generateOpenApiSpec() throws Exception { httpClient.send(requestBuilder.build(), HttpResponse.BodyHandlers.ofString()); if (response.statusCode() < 200 || response.statusCode() >= 300) { - throw new Exception("Error fetching OpenAPI spec. Status: " + response.statusCode()); + throw new IOException("Error fetching OpenAPI spec. Status: " + response.statusCode()); } return response.body(); } @@ -343,7 +343,7 @@ ObjectNode getOpenApiSpecForConnection(String toolName, String toolInstructions) return connectorSpec; } - String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) throws Exception { + String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) throws IOException { JsonNode topLevelNode = objectMapper.readTree(openApiSchemaString); JsonNode specNode = topLevelNode.path("openApiSpec"); if (specNode.isMissingNode() || !specNode.isTextual()) { @@ -372,7 +372,7 @@ String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) thr } } } - throw new Exception("Could not find operationId for pathUrl: " + pathUrl); + throw new IOException("Could not find operationId for pathUrl: " + pathUrl); } ConnectionsClient createConnectionsClient() { diff --git a/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationConnectorTool.java b/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationConnectorTool.java index bc1357106..be93582e7 100644 --- a/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationConnectorTool.java +++ b/core/src/main/java/com/google/adk/tools/applicationintegrationtoolset/IntegrationConnectorTool.java @@ -129,7 +129,7 @@ public class IntegrationConnectorTool extends BaseTool { this.credentialsHelper = Preconditions.checkNotNull(credentialsHelper); } - Schema toGeminiSchema(String openApiSchema, String operationId) throws Exception { + Schema toGeminiSchema(String openApiSchema, String operationId) throws IOException { String resolvedSchemaString = getResolvedRequestSchemaByOperationId(openApiSchema, operationId); return Schema.fromJson(resolvedSchemaString); } @@ -148,7 +148,7 @@ public Optional declaration() { .parameters(parametersSchema) .build(); return Optional.of(declaration); - } catch (Exception e) { + } catch (IOException e) { logger.error("Failed to get OpenAPI spec", e); return Optional.empty(); } @@ -175,20 +175,21 @@ public Single> runAsync(Map args, ToolContex try { String response = executeIntegration(args); return ImmutableMap.of("result", response); - } catch (Exception e) { + } catch (IOException | InterruptedException e) { logger.error("Failed to execute integration", e); return ImmutableMap.of("error", e.getMessage()); } }); } - private String executeIntegration(Map args) throws Exception { + private String executeIntegration(Map args) + throws IOException, InterruptedException { String url = String.format("https://integrations.googleapis.com%s", this.pathUrl); String jsonRequestBody; try { jsonRequestBody = objectMapper.writeValueAsString(args); } catch (IOException e) { - throw new Exception("Error converting args to JSON: " + e.getMessage(), e); + throw new IOException("Error converting args to JSON: " + e.getMessage(), e); } Credentials credentials = credentialsHelper.getGoogleCredentials(this.serviceAccountJson); HttpRequest.Builder requestBuilder = @@ -203,7 +204,7 @@ private String executeIntegration(Map args) throws Exception { httpClient.send(requestBuilder.build(), HttpResponse.BodyHandlers.ofString()); if (response.statusCode() < 200 || response.statusCode() >= 300) { - throw new Exception( + throw new IOException( "Error executing integration. Status: " + response.statusCode() + " , Response: " @@ -212,7 +213,7 @@ private String executeIntegration(Map args) throws Exception { return response.body(); } - String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) throws Exception { + String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) throws IOException { JsonNode topLevelNode = objectMapper.readTree(openApiSchemaString); JsonNode specNode = topLevelNode.path("openApiSpec"); if (specNode.isMissingNode() || !specNode.isTextual()) { @@ -254,11 +255,11 @@ String getOperationIdFromPathUrl(String openApiSchemaString, String pathUrl) thr } } } - throw new Exception("Could not find operationId for pathUrl: " + pathUrl); + throw new IOException("Could not find operationId for pathUrl: " + pathUrl); } private String getResolvedRequestSchemaByOperationId( - String openApiSchemaString, String operationId) throws Exception { + String openApiSchemaString, String operationId) throws IOException { JsonNode topLevelNode = objectMapper.readTree(openApiSchemaString); JsonNode specNode = topLevelNode.path("openApiSpec"); if (specNode.isMissingNode() || !specNode.isTextual()) { @@ -268,13 +269,13 @@ private String getResolvedRequestSchemaByOperationId( JsonNode rootNode = objectMapper.readTree(specNode.asText()); JsonNode operationNode = findOperationNodeById(rootNode, operationId); if (operationNode == null) { - throw new Exception("Could not find operation with operationId: " + operationId); + throw new IOException("Could not find operation with operationId: " + operationId); } JsonNode requestSchemaNode = operationNode.path("requestBody").path("content").path("application/json").path("schema"); if (requestSchemaNode.isMissingNode()) { - throw new Exception("Could not find request body schema for operationId: " + operationId); + throw new IOException("Could not find request body schema for operationId: " + operationId); } JsonNode resolvedSchema = resolveRefs(requestSchemaNode, rootNode); @@ -355,7 +356,7 @@ private JsonNode resolveRefs(JsonNode currentNode, JsonNode rootNode) { } private String getOperationDescription(String openApiSchemaString, String operationId) - throws Exception { + throws IOException { JsonNode topLevelNode = objectMapper.readTree(openApiSchemaString); JsonNode specNode = topLevelNode.path("openApiSpec"); if (specNode.isMissingNode() || !specNode.isTextual()) { diff --git a/core/src/main/java/com/google/adk/utils/AdditionalAdkComponentProvider.java b/core/src/main/java/com/google/adk/utils/AdditionalAdkComponentProvider.java new file mode 100644 index 000000000..c94a18f55 --- /dev/null +++ b/core/src/main/java/com/google/adk/utils/AdditionalAdkComponentProvider.java @@ -0,0 +1,53 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.utils; + +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.GoogleMapsTool; +import com.google.adk.tools.GoogleSearchTool; +import com.google.adk.tools.mcp.McpToolset; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Provides ADK components that are part of core. */ +public final class AdditionalAdkComponentProvider implements AdkComponentProvider { + + /** + * Returns tool instances for {@link GoogleSearchTool} and {@link GoogleMapsTool}. + * + * @return a map of tool instances. + */ + @Override + public Map getToolInstances() { + Map toolInstances = new HashMap<>(); + toolInstances.put("google_search", GoogleSearchTool.INSTANCE); + toolInstances.put("google_maps_grounding", GoogleMapsTool.INSTANCE); + return toolInstances; + } + + /** + * Returns toolset classes for {@link McpToolset}. + * + * @return a list of toolset classes. + */ + @Override + public List> getToolsetClasses() { + return Arrays.asList(McpToolset.class); + } +} diff --git a/core/src/main/java/com/google/adk/utils/AdkComponentProvider.java b/core/src/main/java/com/google/adk/utils/AdkComponentProvider.java new file mode 100644 index 000000000..173edbb26 --- /dev/null +++ b/core/src/main/java/com/google/adk/utils/AdkComponentProvider.java @@ -0,0 +1,64 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.utils; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.List; +import java.util.Map; + +/** Service provider interface for ADK components to be registered in {@link ComponentRegistry}. */ +public interface AdkComponentProvider { + + /** + * Returns a list of agent classes to register. + * + * @return a list of agent classes. + */ + default List> getAgentClasses() { + return ImmutableList.of(); + } + + /** + * Returns a list of tool classes to register. + * + * @return a list of tool classes. + */ + default List> getToolClasses() { + return ImmutableList.of(); + } + + /** + * Returns a list of toolset classes to register. + * + * @return a list of toolset classes. + */ + default List> getToolsetClasses() { + return ImmutableList.of(); + } + + /** + * Returns a map of tool instances to register, with tool name as key. + * + * @return a map of tool instances. + */ + default Map getToolInstances() { + return ImmutableMap.of(); + } +} diff --git a/core/src/main/java/com/google/adk/utils/ComponentRegistry.java b/core/src/main/java/com/google/adk/utils/ComponentRegistry.java index 0a9f55b16..3b2d0d14a 100644 --- a/core/src/main/java/com/google/adk/utils/ComponentRegistry.java +++ b/core/src/main/java/com/google/adk/utils/ComponentRegistry.java @@ -23,25 +23,13 @@ import com.google.adk.agents.BaseAgent; import com.google.adk.agents.Callbacks; import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.LoopAgent; -import com.google.adk.agents.ParallelAgent; -import com.google.adk.agents.SequentialAgent; -import com.google.adk.tools.AgentTool; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; -import com.google.adk.tools.ExampleTool; -import com.google.adk.tools.ExitLoopTool; -import com.google.adk.tools.GoogleMapsTool; -import com.google.adk.tools.GoogleSearchTool; -import com.google.adk.tools.LoadArtifactsTool; -import com.google.adk.tools.LongRunningFunctionTool; -import com.google.adk.tools.UrlContextTool; -import com.google.adk.tools.mcp.McpToolset; import java.util.Map; import java.util.Optional; +import java.util.ServiceLoader; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import javax.annotation.Nonnull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -103,46 +91,45 @@ protected ComponentRegistry() { /** Initializes the registry with base pre-wired ADK instances. */ private void initializePreWiredEntries() { - registerAdkAgentClass(LlmAgent.class); - registerAdkAgentClass(LoopAgent.class); - registerAdkAgentClass(ParallelAgent.class); - registerAdkAgentClass(SequentialAgent.class); - - registerAdkToolInstance("google_search", GoogleSearchTool.INSTANCE); - registerAdkToolInstance("load_artifacts", LoadArtifactsTool.INSTANCE); - registerAdkToolInstance("exit_loop", ExitLoopTool.INSTANCE); - registerAdkToolInstance("url_context", UrlContextTool.INSTANCE); - registerAdkToolInstance("google_maps_grounding", GoogleMapsTool.INSTANCE); - - registerAdkToolClass(AgentTool.class); - registerAdkToolClass(LongRunningFunctionTool.class); - registerAdkToolClass(ExampleTool.class); - - registerAdkToolsetClass(McpToolset.class); - // TODO: add all python tools that also exist in Java. + // Core components are registered first. + AdkComponentProvider coreProvider = new CoreAdkComponentProvider(); + registerProvider(coreProvider); + ServiceLoader loader = ServiceLoader.load(AdkComponentProvider.class); + for (AdkComponentProvider provider : loader) { + registerProvider(provider); + } logger.debug("Initialized base pre-wired entries in ComponentRegistry"); } + private void registerProvider(AdkComponentProvider provider) { + provider.getAgentClasses().forEach(this::registerAdkAgentClass); + provider.getToolClasses().forEach(this::registerAdkToolClass); + provider.getToolsetClasses().forEach(this::registerAdkToolsetClass); + provider.getToolInstances().forEach(this::registerAdkToolInstance); + logger.info("Registered components from " + provider.getClass().getName()); + } + private void registerAdkAgentClass(Class agentClass) { registry.put(agentClass.getName(), agentClass); // For python compatibility, also register the name used in ADK Python. registry.put("google.adk.agents." + agentClass.getSimpleName(), agentClass); } - private void registerAdkToolInstance(String name, @Nonnull Object toolInstance) { + private void registerAdkToolInstance(String name, BaseTool toolInstance) { registry.put(name, toolInstance); // For python compatibility, also register the name used in ADK Python. registry.put("google.adk.tools." + name, toolInstance); } - private void registerAdkToolClass(@Nonnull Class toolClass) { + private void registerAdkToolClass(Class toolClass) { registry.put(toolClass.getName(), toolClass); // For python compatibility, also register the name used in ADK Python. registry.put("google.adk.tools." + toolClass.getSimpleName(), toolClass); + registry.put(toolClass.getSimpleName(), toolClass); } - private void registerAdkToolsetClass(@Nonnull Class toolsetClass) { + private void registerAdkToolsetClass(Class toolsetClass) { registry.put(toolsetClass.getName(), toolsetClass); // For python compatibility, also register the name used in ADK Python. registry.put("google.adk.tools." + toolsetClass.getSimpleName(), toolsetClass); diff --git a/core/src/main/java/com/google/adk/utils/CoreAdkComponentProvider.java b/core/src/main/java/com/google/adk/utils/CoreAdkComponentProvider.java new file mode 100644 index 000000000..455b2cf95 --- /dev/null +++ b/core/src/main/java/com/google/adk/utils/CoreAdkComponentProvider.java @@ -0,0 +1,75 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.utils; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.LoopAgent; +import com.google.adk.agents.ParallelAgent; +import com.google.adk.agents.SequentialAgent; +import com.google.adk.tools.AgentTool; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ExampleTool; +import com.google.adk.tools.ExitLoopTool; +import com.google.adk.tools.LoadArtifactsTool; +import com.google.adk.tools.LongRunningFunctionTool; +import com.google.adk.tools.UrlContextTool; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Provides ADK components that are part of core. */ +public class CoreAdkComponentProvider implements AdkComponentProvider { + + /** + * Returns agent classes for {@link LlmAgent}, {@link LoopAgent}, {@link ParallelAgent} and {@link + * SequentialAgent}. + * + * @return a list of agent classes. + */ + @Override + public List> getAgentClasses() { + return Arrays.asList( + LlmAgent.class, LoopAgent.class, ParallelAgent.class, SequentialAgent.class); + } + + /** + * Returns tool classes for {@link AgentTool}, {@link LongRunningFunctionTool} and {@link + * ExampleTool}. + * + * @return a list of tool classes. + */ + @Override + public List> getToolClasses() { + return Arrays.asList(AgentTool.class, LongRunningFunctionTool.class, ExampleTool.class); + } + + /** + * Returns tool instances for {@link LoadArtifactsTool}, {@link ExitLoopTool} and {@link + * UrlContextTool}. + * + * @return a map of tool instances. + */ + @Override + public Map getToolInstances() { + Map toolInstances = new HashMap<>(); + toolInstances.put("load_artifacts", LoadArtifactsTool.INSTANCE); + toolInstances.put("exit_loop", ExitLoopTool.INSTANCE); + toolInstances.put("url_context", UrlContextTool.INSTANCE); + return toolInstances; + } +} diff --git a/core/src/main/resources/META-INF/services/com.google.adk.utils.AdkComponentProvider b/core/src/main/resources/META-INF/services/com.google.adk.utils.AdkComponentProvider new file mode 100644 index 000000000..795480cc8 --- /dev/null +++ b/core/src/main/resources/META-INF/services/com.google.adk.utils.AdkComponentProvider @@ -0,0 +1 @@ +com.google.adk.utils.AdditionalAdkComponentProvider diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 345436826..5e2fa5792 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -17,6 +17,7 @@ package com.google.adk.agents; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; @@ -27,8 +28,10 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -39,6 +42,20 @@ public final class BaseAgentTest { private static final String TEST_AGENT_NAME = "testAgent"; private static final String TEST_AGENT_DESCRIPTION = "A test agent"; + private static class ClosableTestAgent extends TestBaseAgent { + final AtomicBoolean closed = new AtomicBoolean(false); + + ClosableTestAgent(String name, String description, List subAgents) { + super(name, description, null, subAgents, null, null); + } + + @Override + public Completable close() { + closed.set(true); + return super.close(); + } + } + @Test public void constructor_setsNameAndDescription() { String name = "testName"; @@ -59,10 +76,10 @@ public void findAgent_returnsCorrectAgent() { TestBaseAgent agent = new TestBaseAgent( TEST_AGENT_NAME, TEST_AGENT_DESCRIPTION, null, ImmutableList.of(subAgent), null, null); - assertThat(agent.findAgent("subSubAgent")).isEqualTo(subSubAgent); - assertThat(agent.findAgent("subAgent")).isEqualTo(subAgent); - assertThat(agent.findAgent(TEST_AGENT_NAME)).isEqualTo(agent); - assertThat(agent.findAgent("nonExistent")).isNull(); + assertThat(agent.findAgent("subSubAgent")).hasValue(subSubAgent); + assertThat(agent.findAgent("subAgent")).hasValue(subAgent); + assertThat(agent.findAgent(TEST_AGENT_NAME)).hasValue(agent); + assertThat(agent.findAgent("nonExistent")).isEmpty(); } @Test @@ -77,6 +94,21 @@ public void rootAgent_returnsRootAgent() { assertThat(subSubAgent.rootAgent()).isEqualTo(agent); assertThat(subAgent.rootAgent()).isEqualTo(agent); assertThat(agent.rootAgent()).isEqualTo(agent); + assertThat(subSubAgent.parentAgent()).isEqualTo(subAgent); + assertThat(subAgent.parentAgent()).isEqualTo(agent); + assertThat(agent.parentAgent()).isNull(); + } + + @Test + public void subAgents_returnsSubAgents() { + TestBaseAgent subAgent1 = + new TestBaseAgent("subAgent1", "subAgent1", null, ImmutableList.of(), null, null); + TestBaseAgent subAgent2 = + new TestBaseAgent("subAgent2", "subAgent2", null, ImmutableList.of(), null, null); + TestBaseAgent agent = + new TestBaseAgent( + "agent", "description", null, ImmutableList.of(subAgent1, subAgent2), null, null); + assertThat(agent.subAgents()).containsExactly(subAgent1, subAgent2).inOrder(); } @Test @@ -316,4 +348,283 @@ public void canonicalCallbacks_returnsListWhenPresent() { assertThat(agent.canonicalBeforeAgentCallbacks()).containsExactly(bc); assertThat(agent.canonicalAfterAgentCallbacks()).containsExactly(ac); } + + @Test + public void runLive_invokesRunLiveImpl() { + var runLiveCallback = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("live_output")); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + /* beforeAgentCallbacks= */ ImmutableList.of(), + /* afterAgentCallbacks= */ ImmutableList.of(), + runLiveCallback.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + assertThat(runLiveCallback.wasCalled()).isTrue(); + } + + @Test + public void + runLive_beforeAgentCallbackReturnsContent_endsInvocationAndSkipsRunLiveImplAndAfterCallback() { + var runLiveImpl = TestCallback.returningEmpty(); + Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); + var beforeCallback = TestCallback.returning(callbackContent); + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier("main_output")); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(callbackContent); + assertThat(runLiveImpl.wasCalled()).isFalse(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isFalse(); + } + + @Test + public void runLive_firstBeforeCallbackReturnsContent_skipsSecondBeforeCallback() { + Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); + var beforeCallback1 = TestCallback.returning(callbackContent); + var beforeCallback2 = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of( + beforeCallback1.asBeforeAgentCallback(), beforeCallback2.asBeforeAgentCallback()), + ImmutableList.of(), + TestCallback.returningEmpty().asRunLiveImplSupplier("main_output")); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + var unused = agent.runLive(invocationContext).toList().blockingGet(); + assertThat(beforeCallback1.wasCalled()).isTrue(); + assertThat(beforeCallback2.wasCalled()).isFalse(); + } + + @Test + public void + runLive_beforeCallbackReturnsEmptyAndAfterCallbackReturnsEmpty_invokesRunLiveImplAndAfterCallbacks() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + var beforeCallback = TestCallback.returningEmpty(); + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runLive_afterCallbackReturnsContent_invokesRunLiveImplAndAfterCallbacksAndReturnsAllContent() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); + var beforeCallback = TestCallback.returningEmpty(); + var afterCallback = TestCallback.returning(afterCallbackContent); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + assertThat(results.get(1).content()).hasValue(afterCallbackContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runLive_beforeCallbackMutatesStateAndReturnsEmpty_invokesRunLiveImplAndReturnsStateEvent() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + BeforeAgentCallback beforeCallback = + new BeforeAgentCallback() { + @Override + public Maybe call(CallbackContext context) { + context.state().put("key", "value"); + return Maybe.empty(); + } + }; + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + // State event from before callback + assertThat(results.get(0).content()).isEmpty(); + assertThat(results.get(0).actions().stateDelta()).containsEntry("key", "value"); + // Content event from runLiveImpl + assertThat(results.get(1).content()).hasValue(runLiveImplContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runLive_afterCallbackMutatesStateAndReturnsEmpty_invokesRunLiveImplAndReturnsStateEvent() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + var beforeCallback = TestCallback.returningEmpty(); + AfterAgentCallback afterCallback = + new AfterAgentCallback() { + @Override + public Maybe call(CallbackContext context) { + context.state().put("key", "value"); + return Maybe.empty(); + } + }; + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + // Content event from runLiveImpl + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + // State event from after callback + assertThat(results.get(1).content()).isEmpty(); + assertThat(results.get(1).actions().stateDelta()).containsEntry("key", "value"); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + } + + @Test + public void runLive_firstAfterCallbackReturnsContent_skipsSecondAfterCallback() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); + var afterCallback1 = TestCallback.returning(afterCallbackContent); + var afterCallback2 = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(), + ImmutableList.of( + afterCallback1.asAfterAgentCallback(), afterCallback2.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + assertThat(results.get(1).content()).hasValue(afterCallbackContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(afterCallback1.wasCalled()).isTrue(); + assertThat(afterCallback2.wasCalled()).isFalse(); + } + + @Test + public void constructor_invalidName_throwsIllegalArgumentException() { + assertThrows( + IllegalArgumentException.class, + () -> new TestBaseAgent("invalid name?", "description", null, null, null)); + } + + @Test + public void constructor_userName_throwsIllegalArgumentException() { + assertThrows( + IllegalArgumentException.class, + () -> new TestBaseAgent("user", "description", null, null, null)); + } + + @Test + public void constructor_duplicateSubAgentNames_throwsIllegalArgumentException() { + TestBaseAgent subAgent1 = new TestBaseAgent("subAgent", "subAgent1", null, null, null); + TestBaseAgent subAgent2 = new TestBaseAgent("subAgent", "subAgent2", null, null, null); + assertThrows( + IllegalArgumentException.class, + () -> + new TestBaseAgent( + "agent", "description", null, ImmutableList.of(subAgent1, subAgent2), null, null)); + } + + @Test + @SuppressWarnings("DoNotCall") + public void fromConfig_throwsUnsupportedOperationException() { + assertThrows(UnsupportedOperationException.class, () -> BaseAgent.fromConfig(null, null)); + } + + @Test + public void close_noSubAgents_completesSuccessfully() { + ClosableTestAgent agent = new ClosableTestAgent("agent", "description", ImmutableList.of()); + agent.close().blockingAwait(); + assertThat(agent.closed.get()).isTrue(); + } + + @Test + public void close_oneLevelSubAgents_closesAllSubAgents() { + ClosableTestAgent subAgent1 = new ClosableTestAgent("sub1", "sub1", ImmutableList.of()); + ClosableTestAgent subAgent2 = new ClosableTestAgent("sub2", "sub2", ImmutableList.of()); + ClosableTestAgent agent = + new ClosableTestAgent("agent", "description", ImmutableList.of(subAgent1, subAgent2)); + + agent.close().blockingAwait(); + + assertThat(agent.closed.get()).isTrue(); + assertThat(subAgent1.closed.get()).isTrue(); + assertThat(subAgent2.closed.get()).isTrue(); + } + + @Test + public void close_twoLevelsSubAgents_closesAllSubAgents() { + ClosableTestAgent subSubAgent = new ClosableTestAgent("subSub", "subSub", ImmutableList.of()); + ClosableTestAgent subAgent = new ClosableTestAgent("sub", "sub", ImmutableList.of(subSubAgent)); + ClosableTestAgent agent = + new ClosableTestAgent("agent", "description", ImmutableList.of(subAgent)); + + agent.close().blockingAwait(); + + assertThat(agent.closed.get()).isTrue(); + assertThat(subAgent.closed.get()).isTrue(); + assertThat(subSubAgent.closed.get()).isTrue(); + } } diff --git a/core/src/test/java/com/google/adk/agents/CallbacksTest.java b/core/src/test/java/com/google/adk/agents/CallbacksTest.java index f5ba8aafd..11087e6d6 100644 --- a/core/src/test/java/com/google/adk/agents/CallbacksTest.java +++ b/core/src/test/java/com/google/adk/agents/CallbacksTest.java @@ -1057,6 +1057,125 @@ public void handleFunctionCalls_withChainedToolCallbacks_overridesResultAndPasse assertThat(invocationContext.session().state()).containsExactlyEntriesIn(stateAddedByBc2); } + @Test + public void + handleFunctionCalls_withChainedBeforeToolCallbacks_firstModifiesArgsSecondReturnsResponse() { + ImmutableMap originalArgs = ImmutableMap.of("arg1", "val1"); + ImmutableMap modifiedArgsByCb1 = + ImmutableMap.of("arg1", "val1", "arg2", "val2"); + ImmutableMap responseFromCb2 = ImmutableMap.of("result", "from cb2"); + + Callbacks.BeforeToolCallbackSync cb1 = + (invocationContext, tool, input, toolContext) -> { + input.put("arg2", "val2"); + return Optional.empty(); + }; + + Callbacks.BeforeToolCallbackSync cb2 = + (invocationContext, tool, input, toolContext) -> { + if (input.equals(modifiedArgsByCb1)) { + return Optional.of(responseFromCb2); + } + return Optional.empty(); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .beforeToolCallback(ImmutableList.of(cb1, cb2)) + .build()); + + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.fromText("..."), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("fc_id") + .name("echo_tool") + .args(originalArgs) + .build()) + .build())) + .build(); + + Event functionResponseEvent = + Functions.handleFunctionCalls( + invocationContext, + event, + ImmutableMap.of("echo_tool", new TestUtils.FailingEchoTool())) + .blockingGet(); + + assertThat(getFunctionResponse(functionResponseEvent)).isEqualTo(responseFromCb2); + } + + @Test + public void + handleFunctionCalls_withPluginAndAgentBeforeToolCallbacks_pluginModifiesArgsAgentSeesThem() { + ImmutableMap originalArgs = ImmutableMap.of("arg1", "val1"); + ImmutableMap modifiedArgsByPlugin = + ImmutableMap.of("arg1", "val1", "arg2", "val2"); + ImmutableMap responseFromAgentCb = ImmutableMap.of("result", "from agent cb"); + + Plugin testPlugin = + new Plugin() { + @Override + public String getName() { + return "test_plugin"; + } + + @Override + public Maybe> beforeToolCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext) { + toolArgs.put("arg2", "val2"); + return Maybe.empty(); + } + }; + + Callbacks.BeforeToolCallbackSync agentCb = + (invocationContext, tool, input, toolContext) -> { + if (input.equals(modifiedArgsByPlugin)) { + return Optional.of(responseFromAgentCb); + } + return Optional.empty(); + }; + + LlmAgent agent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .beforeToolCallbackSync(agentCb) + .build(); + + InvocationContext invocationContext = + createInvocationContext(agent).toBuilder() + .pluginManager(new PluginManager(ImmutableList.of(testPlugin))) + .build(); + + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.fromText("..."), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("fc_id") + .name("echo_tool") + .args(originalArgs) + .build()) + .build())) + .build(); + + Event functionResponseEvent = + Functions.handleFunctionCalls( + invocationContext, + event, + ImmutableMap.of("echo_tool", new TestUtils.FailingEchoTool())) + .blockingGet(); + + assertThat(getFunctionResponse(functionResponseEvent)).isEqualTo(responseFromAgentCb); + } + @Test public void agentRunAsync_withToolCallbacks_inspectsArgsAndReturnsResponse() { TestUtils.EchoTool echoTool = new TestUtils.EchoTool(); diff --git a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java index 4825efca1..4f6ea6104 100644 --- a/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java +++ b/core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java @@ -1161,20 +1161,25 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks() String pfx = "test.callbacks."; registry.register( - pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty()); + pfx + "before_agent_1", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty()); registry.register( - pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (ctx) -> Maybe.empty()); - registry.register(pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (ctx) -> Maybe.empty()); + pfx + "before_agent_2", (Callbacks.BeforeAgentCallback) (unusedCtx) -> Maybe.empty()); registry.register( - pfx + "before_model_1", (Callbacks.BeforeModelCallback) (ctx, req) -> Maybe.empty()); + pfx + "after_agent_1", (Callbacks.AfterAgentCallback) (unusedCtx) -> Maybe.empty()); registry.register( - pfx + "after_model_1", (Callbacks.AfterModelCallback) (ctx, resp) -> Maybe.empty()); + pfx + "before_model_1", + (Callbacks.BeforeModelCallback) (unusedCtx, unusedReq) -> Maybe.empty()); + registry.register( + pfx + "after_model_1", + (Callbacks.AfterModelCallback) (unusedCtx, unusedResp) -> Maybe.empty()); registry.register( pfx + "before_tool_1", - (Callbacks.BeforeToolCallback) (inv, tool, args, toolCtx) -> Maybe.empty()); + (Callbacks.BeforeToolCallback) + (unusedInv, unusedTool, unusedArgs, unusedToolCtx) -> Maybe.empty()); registry.register( pfx + "after_tool_1", - (Callbacks.AfterToolCallback) (inv, tool, args, toolCtx, resp) -> Maybe.empty()); + (Callbacks.AfterToolCallback) + (unusedInv, unusedTool, unusedArgs, unusedToolCtx, unusedResp) -> Maybe.empty()); File configFile = tempFolder.newFile("with_callbacks.yaml"); Files.writeString( @@ -1204,20 +1209,14 @@ public void fromConfig_withConfiguredCallbacks_resolvesCallbacks() assertThat(agent).isInstanceOf(LlmAgent.class); LlmAgent llm = (LlmAgent) agent; - assertThat(agent.beforeAgentCallback()).isPresent(); - assertThat(agent.beforeAgentCallback().get()).hasSize(2); - assertThat(agent.afterAgentCallback()).isPresent(); - assertThat(agent.afterAgentCallback().get()).hasSize(1); + assertThat(agent.beforeAgentCallback()).hasSize(2); + assertThat(agent.afterAgentCallback()).hasSize(1); - assertThat(llm.beforeModelCallback()).isPresent(); - assertThat(llm.beforeModelCallback().get()).hasSize(1); - assertThat(llm.afterModelCallback()).isPresent(); - assertThat(llm.afterModelCallback().get()).hasSize(1); + assertThat(llm.beforeModelCallback()).hasSize(1); + assertThat(llm.afterModelCallback()).hasSize(1); - assertThat(llm.beforeToolCallback()).isPresent(); - assertThat(llm.beforeToolCallback().get()).hasSize(1); - assertThat(llm.afterToolCallback()).isPresent(); - assertThat(llm.afterToolCallback().get()).hasSize(1); + assertThat(llm.beforeToolCallback()).hasSize(1); + assertThat(llm.afterToolCallback()).hasSize(1); } @Test diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index 64d2f5bf6..0237261c5 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -17,23 +17,23 @@ package com.google.adk.agents; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; -import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; -import com.google.adk.events.Event; import com.google.adk.memory.BaseMemoryService; +import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; +import com.google.adk.summarizer.EventsCompactionConfig; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; -import com.google.genai.types.FunctionCall; -import com.google.genai.types.Part; import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -145,7 +145,7 @@ public void testBuildWithLiveRequestQueue() { } @Test - public void testCopyOf() { + public void testToBuilder() { InvocationContext originalContext = InvocationContext.builder() .sessionService(mockSessionService) @@ -178,6 +178,24 @@ public void testCopyOf() { assertThat(copiedContext.endInvocation()).isEqualTo(originalContext.endInvocation()); assertThat(copiedContext.activeStreamingTools()) .isEqualTo(originalContext.activeStreamingTools()); + assertThat(copiedContext.callbackContextData()) + .isEqualTo(originalContext.callbackContextData()); + } + + @Test + public void testBuildWithCallbackContextData() { + ConcurrentHashMap data = new ConcurrentHashMap<>(); + data.put("key", "value"); + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .callbackContextData(data) + .build(); + + assertThat(context.callbackContextData()).isEqualTo(data); } @Test @@ -404,6 +422,22 @@ public void testEquals_differentValues() { assertThat(context.equals(contextWithDiffAgent)).isFalse(); assertThat(context.equals(contextWithUserContentEmpty)).isFalse(); assertThat(context.equals(contextWithLiveQueuePresent)).isFalse(); + + InvocationContext contextWithDiffCallbackContextData = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(userContent) + .runConfig(runConfig) + .endInvocation(false) + .callbackContextData(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) + .build(); + assertThat(context.equals(contextWithDiffCallbackContextData)).isFalse(); } @Test @@ -453,155 +487,343 @@ public void testHashCode_differentValues() { assertThat(context).isNotEqualTo(contextWithDiffSessionService); assertThat(context).isNotEqualTo(contextWithDiffInvocationId); + + InvocationContext contextWithDiffCallbackContextData = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(userContent) + .runConfig(runConfig) + .endInvocation(false) + .callbackContextData(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) + .build(); + assertThat(context.hashCode()).isNotEqualTo(contextWithDiffCallbackContextData.hashCode()); } @Test - public void isResumable_whenResumabilityConfigIsNotResumable_isFalse() { + public void incrementLlmCallsCount_whenLimitNotExceeded_doesNotThrow() throws Exception { InvocationContext context = InvocationContext.builder() .sessionService(mockSessionService) .artifactService(mockArtifactService) - .memoryService(mockMemoryService) .agent(mockAgent) .session(session) - .resumabilityConfig(new ResumabilityConfig(false)) + .runConfig(RunConfig.builder().setMaxLlmCalls(2).build()) .build(); - assertThat(context.isResumable()).isFalse(); + + context.incrementLlmCallsCount(); + context.incrementLlmCallsCount(); + // No exception thrown } @Test - public void isResumable_whenResumabilityConfigIsResumable_isTrue() { + public void incrementLlmCallsCount_whenLimitExceeded_throwsException() throws Exception { InvocationContext context = InvocationContext.builder() .sessionService(mockSessionService) .artifactService(mockArtifactService) - .memoryService(mockMemoryService) .agent(mockAgent) .session(session) - .resumabilityConfig(new ResumabilityConfig(true)) + .runConfig(RunConfig.builder().setMaxLlmCalls(1).build()) .build(); - assertThat(context.isResumable()).isTrue(); + + context.incrementLlmCallsCount(); + LlmCallsLimitExceededException thrown = + Assert.assertThrows( + LlmCallsLimitExceededException.class, () -> context.incrementLlmCallsCount()); + assertThat(thrown).hasMessageThat().contains("limit of 1 exceeded"); } @Test - public void shouldPauseInvocation_whenNotResumable_isFalse() { + public void incrementLlmCallsCount_whenNoLimit_doesNotThrow() throws Exception { InvocationContext context = InvocationContext.builder() .sessionService(mockSessionService) .artifactService(mockArtifactService) - .memoryService(mockMemoryService) .agent(mockAgent) .session(session) - .resumabilityConfig(new ResumabilityConfig(false)) - .build(); - Event event = - Event.builder() - .longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))) - .content( - Content.builder() - .parts( - ImmutableList.of( - Part.builder() - .functionCall( - FunctionCall.builder().name("tool1").id("fc1").build()) - .build())) - .build()) - .build(); - assertThat(context.shouldPauseInvocation(event)).isFalse(); + .runConfig(RunConfig.builder().setMaxLlmCalls(0).build()) + .build(); + + for (int i = 0; i < 100; i++) { + context.incrementLlmCallsCount(); + } } @Test - public void shouldPauseInvocation_whenResumableAndNoLongRunningToolIds_isFalse() { + public void testSessionGetters() { + Session sessionWithDetails = + Session.builder("test-id").appName("test-app").userId("test-user").build(); + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(sessionWithDetails) + .build(); + + assertThat(context.appName()).isEqualTo("test-app"); + assertThat(context.userId()).isEqualTo("test-user"); + } + + @Test + public void testSetEndInvocation() { InvocationContext context = InvocationContext.builder() .sessionService(mockSessionService) .artifactService(mockArtifactService) - .memoryService(mockMemoryService) .agent(mockAgent) .session(session) - .resumabilityConfig(new ResumabilityConfig(true)) - .build(); - Event event = - Event.builder() - .content( - Content.builder() - .parts( - ImmutableList.of( - Part.builder() - .functionCall( - FunctionCall.builder().name("tool1").id("fc1").build()) - .build())) - .build()) - .build(); - assertThat(context.shouldPauseInvocation(event)).isFalse(); + .build(); + + assertThat(context.endInvocation()).isFalse(); + context.setEndInvocation(true); + assertThat(context.endInvocation()).isTrue(); } @Test - public void shouldPauseInvocation_whenResumableAndNoFunctionCalls_isFalse() { + // Testing deprecated methods. + public void testBranch() { InvocationContext context = InvocationContext.builder() .sessionService(mockSessionService) .artifactService(mockArtifactService) - .memoryService(mockMemoryService) .agent(mockAgent) .session(session) - .resumabilityConfig(new ResumabilityConfig(true)) + .branch("test-branch") .build(); - Event event = Event.builder().longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))).build(); - assertThat(context.shouldPauseInvocation(event)).isFalse(); + + assertThat(context.branch()).hasValue("test-branch"); + + context.branch("new-branch"); + assertThat(context.branch()).hasValue("new-branch"); + + context.branch(null); + assertThat(context.branch()).isEmpty(); } @Test - public void shouldPauseInvocation_whenResumableAndNoMatchingFunctionCallId_isFalse() { + // Testing deprecated methods. + public void testDeprecatedCreateMethods() { + InvocationContext context1 = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.ofNullable(userContent)) + .runConfig(runConfig) + .build(); + + assertThat(context1.sessionService()).isEqualTo(mockSessionService); + assertThat(context1.artifactService()).isEqualTo(mockArtifactService); + assertThat(context1.invocationId()).isEqualTo(testInvocationId); + assertThat(context1.agent()).isEqualTo(mockAgent); + assertThat(context1.session()).isEqualTo(session); + assertThat(context1.userContent()).hasValue(userContent); + assertThat(context1.runConfig()).isEqualTo(runConfig); + + InvocationContext context2 = + InvocationContext.create( + mockSessionService, + mockArtifactService, + mockAgent, + session, + liveRequestQueue, + runConfig); + + assertThat(context2.sessionService()).isEqualTo(mockSessionService); + assertThat(context2.artifactService()).isEqualTo(mockArtifactService); + assertThat(context2.agent()).isEqualTo(mockAgent); + assertThat(context2.session()).isEqualTo(session); + assertThat(context2.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context2.runConfig()).isEqualTo(runConfig); + } + + @Test + public void testActiveStreamingTools() { InvocationContext context = InvocationContext.builder() .sessionService(mockSessionService) .artifactService(mockArtifactService) - .memoryService(mockMemoryService) .agent(mockAgent) .session(session) - .resumabilityConfig(new ResumabilityConfig(true)) - .build(); - Event event = - Event.builder() - .longRunningToolIds(Optional.of(ImmutableSet.of("fc2"))) - .content( - Content.builder() - .parts( - ImmutableList.of( - Part.builder() - .functionCall( - FunctionCall.builder().name("tool1").id("fc1").build()) - .build())) - .build()) - .build(); - assertThat(context.shouldPauseInvocation(event)).isFalse(); + .build(); + + assertThat(context.activeStreamingTools()).isEmpty(); + ActiveStreamingTool tool = new ActiveStreamingTool(new LiveRequestQueue()); + context.activeStreamingTools().put("tool1", tool); + assertThat(context.activeStreamingTools()).containsEntry("tool1", tool); } @Test - public void shouldPauseInvocation_whenResumableAndMatchingFunctionCallId_isTrue() { + public void testEventsCompactionConfig() { + EventsCompactionConfig config = new EventsCompactionConfig(5, 2); + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .eventsCompactionConfig(config) + .build(); + + assertThat(context.eventsCompactionConfig()).hasValue(config); + } + + @Test + // Testing deprecated methods. + public void testBuilderOptionalParameters() { InvocationContext context = InvocationContext.builder() .sessionService(mockSessionService) .artifactService(mockArtifactService) - .memoryService(mockMemoryService) .agent(mockAgent) .session(session) - .resumabilityConfig(new ResumabilityConfig(true)) - .build(); - Event event = - Event.builder() - .longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))) - .content( - Content.builder() - .parts( - ImmutableList.of( - Part.builder() - .functionCall( - FunctionCall.builder().name("tool1").id("fc1").build()) - .build())) - .build()) - .build(); - assertThat(context.shouldPauseInvocation(event)).isTrue(); + .liveRequestQueue(Optional.of(liveRequestQueue)) + .branch(Optional.of("test-branch")) + .userContent(Optional.of(userContent)) + .build(); + + assertThat(context.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context.branch()).hasValue("test-branch"); + assertThat(context.userContent()).hasValue(userContent); + } + + @Test + // Testing deprecated methods. + public void testDeprecatedConstructor() { + InvocationContext context = + new InvocationContext( + mockSessionService, + mockArtifactService, + mockMemoryService, + pluginManager, + Optional.of(liveRequestQueue), + Optional.of("test-branch"), + testInvocationId, + mockAgent, + session, + Optional.of(userContent), + runConfig, + true); + + assertThat(context.sessionService()).isEqualTo(mockSessionService); + assertThat(context.artifactService()).isEqualTo(mockArtifactService); + assertThat(context.memoryService()).isEqualTo(mockMemoryService); + assertThat(context.pluginManager()).isEqualTo(pluginManager); + assertThat(context.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context.branch()).hasValue("test-branch"); + assertThat(context.invocationId()).isEqualTo(testInvocationId); + assertThat(context.agent()).isEqualTo(mockAgent); + assertThat(context.session()).isEqualTo(session); + assertThat(context.userContent()).hasValue(userContent); + assertThat(context.runConfig()).isEqualTo(runConfig); + assertThat(context.endInvocation()).isTrue(); + } + + @Test + // Testing deprecated methods. + public void testDeprecatedConstructor_11params() { + InvocationContext context = + new InvocationContext( + mockSessionService, + mockArtifactService, + mockMemoryService, + Optional.of(liveRequestQueue), + Optional.of("test-branch"), + testInvocationId, + mockAgent, + session, + Optional.of(userContent), + runConfig, + true); + + assertThat(context.sessionService()).isEqualTo(mockSessionService); + assertThat(context.artifactService()).isEqualTo(mockArtifactService); + assertThat(context.memoryService()).isEqualTo(mockMemoryService); + assertThat(context.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context.branch()).hasValue("test-branch"); + assertThat(context.invocationId()).isEqualTo(testInvocationId); + assertThat(context.agent()).isEqualTo(mockAgent); + assertThat(context.session()).isEqualTo(session); + assertThat(context.userContent()).hasValue(userContent); + assertThat(context.runConfig()).isEqualTo(runConfig); + assertThat(context.endInvocation()).isTrue(); + } + + @Test + public void build_missingInvocationId_null_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .invocationId(null) + .session(session); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Invocation ID must be non-empty."); + } + + @Test + public void build_missingInvocationId_empty_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .invocationId("") + .session(session); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Invocation ID must be non-empty."); + } + + @Test + public void build_missingAgent_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .session(session); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Agent must be set."); + } + + @Test + public void build_missingSession_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Session must be set."); + } + + @Test + public void build_missingSessionService_throwsException() { + InvocationContext.Builder builder = + InvocationContext.builder() + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session); + + IllegalStateException exception = assertThrows(IllegalStateException.class, builder::build); + assertThat(exception).hasMessageThat().isEqualTo("Session service must be set."); } } diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 519c90558..ce8be8dfb 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -22,8 +22,10 @@ import static com.google.adk.testing.TestUtils.createTestAgent; import static com.google.adk.testing.TestUtils.createTestAgentBuilder; import static com.google.adk.testing.TestUtils.createTestLlm; +import static com.google.adk.testing.TestUtils.createTextLlmResponse; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import com.google.adk.agents.Callbacks.AfterModelCallback; @@ -34,21 +36,37 @@ import com.google.adk.agents.Callbacks.OnToolErrorCallback; import com.google.adk.events.Event; import com.google.adk.models.LlmRegistry; +import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.models.Model; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; +import com.google.adk.telemetry.Tracing; import com.google.adk.testing.TestLlm; import com.google.adk.testing.TestUtils.EchoTool; import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Part; import com.google.genai.types.Schema; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import io.opentelemetry.sdk.trace.data.SpanData; +import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -56,6 +74,34 @@ /** Unit tests for {@link LlmAgent}. */ @RunWith(JUnit4.class) public final class LlmAgentTest { + @Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); + + private Tracer originalTracer; + + @Before + public void setup() { + this.originalTracer = Tracing.getTracer(); + Tracing.setTracerForTesting(openTelemetryRule.getOpenTelemetry().getTracer("gcp.vertex.agent")); + } + + @After + public void tearDown() { + Tracing.setTracerForTesting(originalTracer); + } + + private static class ClosableToolset implements BaseToolset { + final AtomicBoolean closed = new AtomicBoolean(false); + + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.empty(); + } + + @Override + public void close() { + closed.set(true); + } + } @Test public void testRun_withNoCallbacks() { @@ -341,6 +387,13 @@ public void canonicalCallbacks_returnsEmptyListWhenNull() { assertThat(agent.canonicalBeforeToolCallbacks()).isEmpty(); assertThat(agent.canonicalAfterToolCallbacks()).isEmpty(); assertThat(agent.canonicalOnToolErrorCallbacks()).isEmpty(); + + assertThat(agent.beforeModelCallback()).isEmpty(); + assertThat(agent.afterModelCallback()).isEmpty(); + assertThat(agent.onModelErrorCallback()).isEmpty(); + assertThat(agent.beforeToolCallback()).isEmpty(); + assertThat(agent.afterToolCallback()).isEmpty(); + assertThat(agent.onToolErrorCallback()).isEmpty(); } @Test @@ -371,5 +424,218 @@ public void canonicalCallbacks_returnsListWhenPresent() { assertThat(agent.canonicalBeforeToolCallbacks()).containsExactly(btc); assertThat(agent.canonicalAfterToolCallbacks()).containsExactly(atc); assertThat(agent.canonicalOnToolErrorCallbacks()).containsExactly(otec); + + assertThat(agent.beforeModelCallback()).containsExactly(bmc); + assertThat(agent.afterModelCallback()).containsExactly(amc); + assertThat(agent.onModelErrorCallback()).containsExactly(omec); + assertThat(agent.beforeToolCallback()).containsExactly(btc); + assertThat(agent.afterToolCallback()).containsExactly(atc); + assertThat(agent.onToolErrorCallback()).containsExactly(otec); + } + + @Test + public void run_sequentialAgents_shareTempStateViaSession() { + // 1. Setup Session Service and Session + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService + .createSession("app", "user", new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + // 2. Agent 1: runs and produces output "value1" to state "temp:key1" + Content model1Content = Content.fromParts(Part.fromText("value1")); + TestLlm testLlm1 = createTestLlm(createLlmResponse(model1Content)); + LlmAgent agent1 = + createTestAgentBuilder(testLlm1).name("agent1").outputKey("temp:key1").build(); + InvocationContext invocationContext1 = createInvocationContext(agent1, sessionService, session); + + List events1 = agent1.runAsync(invocationContext1).toList().blockingGet(); + assertThat(events1).hasSize(1); + Event event1 = events1.get(0); + assertThat(event1.actions()).isNotNull(); + assertThat(event1.actions().stateDelta()).containsEntry("temp:key1", "value1"); + + // 3. Simulate orchestrator: append event1 to session, updating its state + var unused = sessionService.appendEvent(session, event1).blockingGet(); + assertThat(session.state()).containsEntry("temp:key1", "value1"); + + // 4. Agent 2: uses Instruction.Provider to read "temp:key1" from session state + // and generates an instruction based on it. + TestLlm testLlm2 = + createTestLlm(createLlmResponse(Content.fromParts(Part.fromText("response2")))); + LlmAgent agent2 = + createTestAgentBuilder(testLlm2) + .name("agent2") + .instruction( + new Instruction.Provider( + ctx -> + Single.just( + "Instruction for Agent2 based on Agent1 output: " + + ctx.state().get("temp:key1")))) + .build(); + InvocationContext invocationContext2 = createInvocationContext(agent2, sessionService, session); + List events2 = agent2.runAsync(invocationContext2).toList().blockingGet(); + assertThat(events2).hasSize(1); + + // 5. Verify that agent2's LLM received an instruction containing agent1's output + assertThat(testLlm2.getRequests()).hasSize(1); + LlmRequest request2 = testLlm2.getRequests().get(0); + assertThat(request2.getFirstSystemInstruction().get()) + .contains("Instruction for Agent2 based on Agent1 output: value1"); + } + + @Test + public void close_closesToolsets() throws Exception { + ClosableToolset toolset1 = new ClosableToolset(); + ClosableToolset toolset2 = new ClosableToolset(); + LlmAgent agent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .tools(toolset1, toolset2) + .build(); + agent.close().blockingAwait(); + assertThat(toolset1.closed.get()).isTrue(); + assertThat(toolset2.closed.get()).isTrue(); + } + + @Test + public void close_closesToolsetsOnException() throws Exception { + ClosableToolset toolset1 = + new ClosableToolset() { + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.empty(); + } + + @Override + public void close() { + super.close(); + throw new RuntimeException("toolset1 failed to close"); + } + }; + ClosableToolset toolset2 = new ClosableToolset(); + LlmAgent agent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .tools(toolset1, toolset2) + .build(); + agent.close().test().assertError(RuntimeException.class); + assertThat(toolset1.closed.get()).isTrue(); + assertThat(toolset2.closed.get()).isTrue(); + } + + @Test + public void runAsync_createsInvokeAgentSpan() throws InterruptedException { + Content modelContent = Content.fromParts(Part.fromText("response")); + TestLlm testLlm = createTestLlm(createLlmResponse(modelContent)); + LlmAgent agent = createTestAgent(testLlm); + InvocationContext invocationContext = createInvocationContext(agent); + + agent.runAsync(invocationContext).test().await().assertComplete(); + + List spans = openTelemetryRule.getSpans(); + assertThat(spans.stream().anyMatch(s -> s.getName().equals("invoke_agent test agent"))) + .isTrue(); + } + + @Test + public void runAsync_withTools_createsToolSpans() throws InterruptedException { + ImmutableMap echoArgs = ImmutableMap.of("arg", "value"); + Content contentWithFunctionCall = + Content.fromParts(Part.fromText("text"), Part.fromFunctionCall("echo_tool", echoArgs)); + Content finalResponse = Content.fromParts(Part.fromText("finished")); + TestLlm testLlm = + createTestLlm(createLlmResponse(contentWithFunctionCall), createLlmResponse(finalResponse)); + LlmAgent agent = createTestAgentBuilder(testLlm).tools(new EchoTool()).build(); + InvocationContext invocationContext = createInvocationContext(agent); + + agent.runAsync(invocationContext).test().await().assertComplete(); + + List spans = openTelemetryRule.getSpans(); + SpanData agentSpan = findSpanByName(spans, "invoke_agent test agent"); + List llmSpans = findSpansByName(spans, "call_llm"); + List toolCallSpans = findSpansByName(spans, "tool_call [echo_tool]"); + List toolResponseSpans = findSpansByName(spans, "tool_response [echo_tool]"); + + assertThat(llmSpans).hasSize(2); + assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + String agentSpanId = agentSpan.getSpanContext().getSpanId(); + llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); + toolCallSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); + toolResponseSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); + } + + @Test + public void runAsync_afterToolCallback_propagatesContext() throws InterruptedException { + ImmutableMap echoArgs = ImmutableMap.of("arg", "value"); + Content contentWithFunctionCall = + Content.fromParts(Part.fromText("text"), Part.fromFunctionCall("echo_tool", echoArgs)); + Content finalResponse = Content.fromParts(Part.fromText("finished")); + TestLlm testLlm = + createTestLlm(createLlmResponse(contentWithFunctionCall), createLlmResponse(finalResponse)); + + AfterToolCallback afterToolCallback = + (invCtx, tool, input, toolCtx, response) -> { + // Verify that the OpenTelemetry context is correctly propagated to the callback. + assertThat(Span.current().getSpanContext().isValid()).isTrue(); + return Maybe.empty(); + }; + + LlmAgent agent = + createTestAgentBuilder(testLlm) + .tools(new EchoTool()) + .afterToolCallback(ImmutableList.of(afterToolCallback)) + .build(); + InvocationContext invocationContext = createInvocationContext(agent); + + agent.runAsync(invocationContext).test().await().assertComplete(); + + List spans = openTelemetryRule.getSpans(); + findSpanByName(spans, "invoke_agent test agent"); + } + + @Test + public void runAsync_withSubAgents_createsSpans() throws InterruptedException { + LlmAgent subAgent = + createTestAgentBuilder(createTestLlm(createTextLlmResponse("sub response"))) + .name("sub-agent") + .build(); + + // Force a transfer to sub-agent using a callback + AfterModelCallback transferCallback = + (ctx, response) -> { + ctx.eventActions().setTransferToAgent(subAgent.name()); + return Maybe.empty(); + }; + + TestLlm testLlm = createTestLlm(createTextLlmResponse("initial")); + LlmAgent agent = + createTestAgentBuilder(testLlm) + .subAgents(subAgent) + .afterModelCallback(ImmutableList.of(transferCallback)) + .build(); + InvocationContext invocationContext = createInvocationContext(agent); + + agent.runAsync(invocationContext).test().await().assertComplete(); + + List spans = openTelemetryRule.getSpans(); + assertThat(spans.stream().anyMatch(s -> s.getName().equals("invoke_agent test agent"))) + .isTrue(); + assertThat(spans.stream().anyMatch(s -> s.getName().equals("invoke_agent sub-agent"))).isTrue(); + + List llmSpans = findSpansByName(spans, "call_llm"); + assertThat(llmSpans).hasSize(2); // One for main agent, one for sub agent + } + + private List findSpansByName(List spans, String name) { + return spans.stream().filter(s -> s.getName().equals(name)).toList(); + } + + @CanIgnoreReturnValue + private SpanData findSpanByName(List spans, String name) { + return spans.stream() + .filter(s -> s.getName().equals(name)) + .findFirst() + .orElseThrow(() -> new AssertionError("Span not found: " + name)); } } diff --git a/core/src/test/java/com/google/adk/agents/RunConfigTest.java b/core/src/test/java/com/google/adk/agents/RunConfigTest.java index 7b6e7558f..1c416aa72 100644 --- a/core/src/test/java/com/google/adk/agents/RunConfigTest.java +++ b/core/src/test/java/com/google/adk/agents/RunConfigTest.java @@ -17,6 +17,7 @@ package com.google.adk.agents; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableList; import com.google.genai.types.AudioTranscriptionConfig; @@ -114,4 +115,11 @@ public void testInputAudioTranscriptionOnly() { assertThat(runConfig.streamingMode()).isEqualTo(RunConfig.StreamingMode.BIDI); assertThat(runConfig.responseModalities()).containsExactly(new Modality(Modality.Known.AUDIO)); } + + @Test + public void testMaxLlmCalls_integerMaxValue_throwsIllegalArgumentException() { + assertThrows( + IllegalArgumentException.class, + () -> RunConfig.builder().setMaxLlmCalls(Integer.MAX_VALUE).build()); + } } diff --git a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java index 1df66c36d..40493bf3a 100644 --- a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java @@ -31,6 +31,7 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -76,6 +77,7 @@ private Blob mockBlob(String name, String contentType, byte[] content) { when(blob.exists()).thenReturn(true); BlobId blobId = BlobId.of(BUCKET_NAME, name); when(blob.getBlobId()).thenReturn(blobId); + when(blob.getBucket()).thenReturn(BUCKET_NAME); return blob; } @@ -89,6 +91,8 @@ public void save_firstVersion_savesCorrectly() { BlobInfo.newBuilder(expectedBlobId).setContentType("application/octet-stream").build(); when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/octet-stream", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); @@ -109,6 +113,8 @@ public void save_subsequentVersion_savesCorrectly() { Blob blobV0 = mockBlob(blobNameV0, "text/plain", new byte[] {1}); when(mockBlobPage.iterateAll()).thenReturn(Collections.singletonList(blobV0)); + Blob savedBlob = mockBlob(expectedBlobNameV1, "image/png", new byte[] {4, 5}); + when(mockStorage.create(eq(expectedBlobInfoV1), eq(new byte[] {4, 5}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); @@ -126,6 +132,8 @@ public void save_userNamespace_savesCorrectly() { BlobInfo.newBuilder(expectedBlobId).setContentType("application/json").build(); when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/json", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, USER_FILENAME, artifact).blockingGet(); @@ -330,7 +338,36 @@ public void listVersions_noVersions_returnsEmptyList() { assertThat(versions).isEmpty(); } + @Test + public void saveAndReloadArtifact_savesAndReturnsFileData() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "application/octet-stream"); + String expectedBlobName = + String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + BlobId expectedBlobId = BlobId.of(BUCKET_NAME, expectedBlobName); + BlobInfo expectedBlobInfo = + BlobInfo.newBuilder(expectedBlobId).setContentType("application/octet-stream").build(); + + when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/octet-stream", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); + + Optional result = + asOptional( + service.saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact)); + + assertThat(result).isPresent(); + assertThat(result.get().fileData()).isPresent(); + assertThat(result.get().fileData().get().fileUri()) + .hasValue("gs://" + BUCKET_NAME + "/" + expectedBlobName); + assertThat(result.get().fileData().get().mimeType()).hasValue("application/octet-stream"); + verify(mockStorage).create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3})); + } + private static Optional asOptional(Maybe maybe) { return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet(); } + + private static Optional asOptional(Single single) { + return Optional.of(single.blockingGet()); + } } diff --git a/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java new file mode 100644 index 000000000..4cb493277 --- /dev/null +++ b/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.artifacts; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link InMemoryArtifactService}. */ +@RunWith(JUnit4.class) +public class InMemoryArtifactServiceTest { + + private static final String APP_NAME = "test-app"; + private static final String USER_ID = "test-user"; + private static final String SESSION_ID = "test-session"; + private static final String FILENAME = "test-file.txt"; + + private InMemoryArtifactService service; + + @Before + public void setUp() { + service = new InMemoryArtifactService(); + } + + @Test + public void saveArtifact_savesAndReturnsVersion() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "text/plain"); + int version = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); + assertThat(version).isEqualTo(0); + } + + @Test + public void loadArtifact_loadsLatest() { + Part artifact1 = Part.fromBytes(new byte[] {1}, "text/plain"); + Part artifact2 = Part.fromBytes(new byte[] {1, 2}, "text/plain"); + var unused1 = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact1).blockingGet(); + var unused2 = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact2).blockingGet(); + Optional result = + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.empty())); + assertThat(result).hasValue(artifact2); + } + + @Test + public void saveAndReloadArtifact_reloadsArtifact() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "text/plain"); + Optional result = + asOptional( + service.saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact)); + assertThat(result).hasValue(artifact); + } + + private static Optional asOptional(Maybe maybe) { + return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet(); + } + + private static Optional asOptional(Single single) { + return Optional.of(single.blockingGet()); + } +} diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index a9e3693d5..94cd399df 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -18,7 +18,9 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.adk.sessions.State; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.Part; import java.util.concurrent.ConcurrentHashMap; @@ -43,7 +45,11 @@ public final class EventActionsTest { @Test public void toBuilder_createsBuilderWithSameValues() { EventActions eventActionsWithSkipSummarization = - EventActions.builder().skipSummarization(true).compaction(COMPACTION).build(); + EventActions.builder() + .skipSummarization(true) + .compaction(COMPACTION) + .deletedArtifactIds(ImmutableSet.of("d1")) + .build(); EventActions eventActionsAfterRebuild = eventActionsWithSkipSummarization.toBuilder().build(); @@ -57,7 +63,8 @@ public void merge_mergesAllFields() { EventActions.builder() .skipSummarization(true) .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1"))) - .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", PART))) + .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", 1))) + .deletedArtifactIds(ImmutableSet.of("deleted1")) .requestedAuthConfigs( new ConcurrentHashMap<>( ImmutableMap.of("config1", new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))))) @@ -68,7 +75,8 @@ public void merge_mergesAllFields() { EventActions eventActions2 = EventActions.builder() .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key2", "value2"))) - .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", PART))) + .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", 2))) + .deletedArtifactIds(ImmutableSet.of("deleted2")) .transferToAgent("agentId") .escalate(true) .requestedAuthConfigs( @@ -76,14 +84,15 @@ public void merge_mergesAllFields() { ImmutableMap.of("config2", new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))))) .requestedToolConfirmations( new ConcurrentHashMap<>(ImmutableMap.of("tool2", TOOL_CONFIRMATION))) - .endInvocation(true) + .endOfAgent(true) .build(); EventActions merged = eventActions1.toBuilder().merge(eventActions2).build(); assertThat(merged.skipSummarization()).hasValue(true); assertThat(merged.stateDelta()).containsExactly("key1", "value1", "key2", "value2"); - assertThat(merged.artifactDelta()).containsExactly("artifact1", PART, "artifact2", PART); + assertThat(merged.artifactDelta()).containsExactly("artifact1", 1, "artifact2", 2); + assertThat(merged.deletedArtifactIds()).containsExactly("deleted1", "deleted2"); assertThat(merged.transferToAgent()).hasValue("agentId"); assertThat(merged.escalate()).hasValue(true); assertThat(merged.requestedAuthConfigs()) @@ -94,7 +103,31 @@ public void merge_mergesAllFields() { new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))); assertThat(merged.requestedToolConfirmations()) .containsExactly("tool1", TOOL_CONFIRMATION, "tool2", TOOL_CONFIRMATION); - assertThat(merged.endInvocation()).hasValue(true); + assertThat(merged.endOfAgent()).isTrue(); assertThat(merged.compaction()).hasValue(COMPACTION); } + + @Test + public void removeStateByKey_marksKeyAsRemoved() { + EventActions eventActions = new EventActions(); + eventActions.stateDelta().put("key1", "value1"); + eventActions.removeStateByKey("key1"); + + assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED); + } + + @Test + public void jsonSerialization_works() throws Exception { + EventActions eventActions = + EventActions.builder() + .deletedArtifactIds(ImmutableSet.of("d1", "d2")) + .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))) + .build(); + + String json = eventActions.toJson(); + EventActions deserialized = EventActions.fromJsonString(json, EventActions.class); + + assertThat(deserialized).isEqualTo(eventActions); + assertThat(deserialized.deletedArtifactIds()).containsExactly("d1", "d2"); + } } diff --git a/core/src/test/java/com/google/adk/events/EventTest.java b/core/src/test/java/com/google/adk/events/EventTest.java index f443abee5..cbfb6ef0b 100644 --- a/core/src/test/java/com/google/adk/events/EventTest.java +++ b/core/src/test/java/com/google/adk/events/EventTest.java @@ -45,9 +45,7 @@ public final class EventTest { EventActions.builder() .skipSummarization(true) .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) - .artifactDelta( - new ConcurrentHashMap<>( - ImmutableMap.of("artifact_key", Part.builder().text("artifact_value").build()))) + .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact_key", 1))) .transferToAgent("agent_id") .escalate(true) .requestedAuthConfigs( @@ -191,4 +189,67 @@ public void event_json_serialization_works() throws Exception { Event deserializedEvent = Event.fromJson(json); assertThat(deserializedEvent).isEqualTo(EVENT); } + + @Test + public void finalResponse_returnsTrueIfNoToolCalls() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromText("hello"))) + .build(); + assertThat(event.finalResponse()).isTrue(); + } + + @Test + public void finalResponse_returnsFalseIfToolCalls() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v")))) + .build(); + assertThat(event.finalResponse()).isFalse(); + } + + @Test + public void finalResponse_isTrueForEventWithTextContent() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromText("hello"))) + .longRunningToolIds(ImmutableSet.of("tool1")) + .build(); + assertThat(event.finalResponse()).isTrue(); + } + + @Test + public void finalResponse_isFalseForEventWithToolCallAndLongRunningToolId() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v")))) + .longRunningToolIds(ImmutableSet.of("tool1")) + .build(); + assertThat(event.finalResponse()).isFalse(); + } + + @Test + public void finalResponse_returnsTrueIfSkipSummarization() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v")))) + .actions(EventActions.builder().skipSummarization(true).build()) + .build(); + assertThat(event.finalResponse()).isTrue(); + } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 5f4932a89..657d1c670 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -25,6 +25,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.truth.Truth.assertThat; +import com.google.adk.agents.Callbacks; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; @@ -42,6 +43,7 @@ import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; import java.util.Map; @@ -414,6 +416,82 @@ public void run_requestProcessorsAreCalledExactlyOnce() { assertThat(processor2CallCount.get()).isEqualTo(1); } + @Test + public void run_sharingcallbackContextDataBetweenCallbacks() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + Callbacks.BeforeModelCallback beforeCallback = + (ctx, req) -> { + ctx.invocationContext().callbackContextData().put("key", "value_from_before"); + return Maybe.empty(); + }; + + Callbacks.AfterModelCallback afterCallback = + (ctx, resp) -> { + String value = (String) ctx.invocationContext().callbackContextData().get("key"); + LlmResponse modifiedResp = + resp.toBuilder().content(Content.fromParts(Part.fromText("Saw: " + value))).build(); + return Maybe.just(modifiedResp); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .beforeModelCallback(beforeCallback) + .afterModelCallback(afterCallback) + .build()); + + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertThat(events.get(0).stringifyContent()).isEqualTo("Saw: value_from_before"); + } + + @Test + public void run_sharingcallbackContextDataAcrossContextCopies() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + Callbacks.BeforeModelCallback beforeCallback = + (ctx, req) -> { + ctx.invocationContext().callbackContextData().put("key", "value_from_before"); + return Maybe.empty(); + }; + + Callbacks.AfterModelCallback afterCallback = + (ctx, resp) -> { + String value = (String) ctx.invocationContext().callbackContextData().get("key"); + LlmResponse modifiedResp = + resp.toBuilder().content(Content.fromParts(Part.fromText("Saw: " + value))).build(); + return Maybe.just(modifiedResp); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .beforeModelCallback(beforeCallback) + .afterModelCallback(afterCallback) + .build()); + + BaseLlmFlow baseLlmFlow = + new BaseLlmFlow(ImmutableList.of(), ImmutableList.of()) { + @Override + public Flowable run(InvocationContext context) { + // Force a context copy + InvocationContext copiedContext = context.toBuilder().build(); + return super.run(copiedContext); + } + }; + + List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertThat(events.get(0).stringifyContent()).isEqualTo("Saw: value_from_before"); + } + private static BaseLlmFlow createBaseLlmFlowWithoutProcessors() { return createBaseLlmFlow(ImmutableList.of(), ImmutableList.of()); } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/CompactionTest.java b/core/src/test/java/com/google/adk/flows/llmflows/CompactionTest.java new file mode 100644 index 000000000..3ceba5641 --- /dev/null +++ b/core/src/test/java/com/google/adk/flows/llmflows/CompactionTest.java @@ -0,0 +1,157 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.flows.llmflows; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.Session; +import com.google.adk.summarizer.BaseEventSummarizer; +import com.google.adk.summarizer.EventsCompactionConfig; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CompactionTest { + + private InvocationContext context; + private LlmRequest request; + private Session session; + private BaseSessionService sessionService; + private BaseEventSummarizer summarizer; + + @Before + public void setUp() { + context = mock(InvocationContext.class); + request = LlmRequest.builder().build(); + session = Session.builder("test-session").build(); + sessionService = mock(BaseSessionService.class); + summarizer = mock(BaseEventSummarizer.class); + + when(context.session()).thenReturn(session); + when(context.sessionService()).thenReturn(sessionService); + } + + @Test + public void processRequest_noConfig_doesNothing() { + when(context.eventsCompactionConfig()).thenReturn(Optional.empty()); + + Compaction compaction = new Compaction(); + compaction + .processRequest(context, request) + .test() + .assertNoErrors() + .assertValue(r -> r.updatedRequest() == request); + + verify(sessionService, never()).appendEvent(any(), any()); + } + + @Test + public void processRequest_withConfig_triggersCompaction() { + // Setup config with threshold 100 + EventsCompactionConfig config = new EventsCompactionConfig(5, 1, summarizer, 100, 2); + when(context.eventsCompactionConfig()).thenReturn(Optional.of(config)); + + // Setup events with usage > 100 to trigger compaction + Event event1 = mock(Event.class); + Event event2 = mock(Event.class); + Event event3 = mock(Event.class); + when(event3.usageMetadata()) + .thenReturn( + Optional.of( + GenerateContentResponseUsageMetadata.builder().promptTokenCount(200).build())); + + session = + Session.builder("test-session").events(ImmutableList.of(event1, event2, event3)).build(); + when(context.session()).thenReturn(session); + + // Summarizer mock + Event summaryEvent = mock(Event.class); + when(summarizer.summarizeEvents(any())).thenReturn(Maybe.just(summaryEvent)); + when(sessionService.appendEvent(any(), any())).thenReturn(Single.just(summaryEvent)); + + Compaction compaction = new Compaction(); + compaction + .processRequest(context, request) + .test() + .assertNoErrors() + .assertValue(r -> r.updatedRequest() == request); + + // Verify compaction happened and result was appended + verify(sessionService).appendEvent(eq(session), eq(summaryEvent)); + } + + @Test + public void processRequest_withConfig_skipsCompactionIfBelowThreshold() { + // Setup config with threshold 500 + EventsCompactionConfig config = new EventsCompactionConfig(5, 1, summarizer, 500, 2); + when(context.eventsCompactionConfig()).thenReturn(Optional.of(config)); + + // Setup events with usage 200 (below 500) + Event event3 = mock(Event.class); + when(event3.usageMetadata()) + .thenReturn( + Optional.of( + GenerateContentResponseUsageMetadata.builder().promptTokenCount(200).build())); + + session = Session.builder("test-session").events(ImmutableList.of(event3)).build(); + when(context.session()).thenReturn(session); + + Compaction compaction = new Compaction(); + compaction + .processRequest(context, request) + .test() + .assertNoErrors() + .assertValue(r -> r.updatedRequest() == request); + + // Verify NO compaction + verify(sessionService, never()).appendEvent(any(), any()); + } + + @Test + public void processRequest_withConfig_nullRetentionSize_doesNothing() { + // Setup config with retentionSize = null + EventsCompactionConfig config = new EventsCompactionConfig(5, 1, summarizer, 100, null); + when(context.eventsCompactionConfig()).thenReturn(Optional.of(config)); + + Compaction compaction = new Compaction(); + compaction + .processRequest(context, request) + .test() + .assertNoErrors() + .assertValue(r -> r.updatedRequest() == request); + + // Verify NO compaction and session.events() is not called + verify(sessionService, never()).appendEvent(any(), any()); + verify(context, never()).session(); + } +} diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 85895088f..b5df658ba 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -28,6 +28,7 @@ import com.google.adk.events.EventCompaction; import com.google.adk.models.LlmRequest; import com.google.adk.models.Model; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -35,7 +36,6 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -53,7 +53,8 @@ public final class ContentsTest { private static final String AGENT = "agent"; private static final String OTHER_AGENT = "other_agent"; - private final Contents contentsProcessor = new Contents(); + private static final Contents contentsProcessor = new Contents(); + private static final InMemorySessionService sessionService = new InMemorySessionService(); @Test public void rearrangeLatest_emptyList_returnsEmptyList() { @@ -203,9 +204,11 @@ public void rearrangeHistory_asyncFR_returnsRearrangedList() { public void rearrangeHistory_multipleFRsForSameFC_returnsMergedFR() { Event fcEvent = createFunctionCallEvent("fc1", "tool1", "call1"); Event frEvent1 = - createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "running")); + createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "pending")); Event frEvent2 = - createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "done")); + createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "running")); + Event frEvent3 = + createFunctionResponseEvent("fr3", "tool1", "call1", ImmutableMap.of("status", "done")); ImmutableList inputEvents = ImmutableList.of( createUserEvent("u1", "Query"), @@ -213,17 +216,75 @@ public void rearrangeHistory_multipleFRsForSameFC_returnsMergedFR() { createUserEvent("u2", "Wait"), frEvent1, createUserEvent("u3", "Done?"), - frEvent2); + frEvent2, + frEvent3, + createUserEvent("u4", "Follow up query")); List result = runContentsProcessor(inputEvents); - assertThat(result).hasSize(3); // u1, fc1, merged_fr + assertThat(result).hasSize(6); // u1, fc1, merged_fr, u2, u3, u4 assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); - assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // Check merged event + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // Check fcEvent Content mergedContent = result.get(2); assertThat(mergedContent.parts().get()).hasSize(1); assertThat(mergedContent.parts().get().get(0).functionResponse().get().response().get()) - .containsExactly("status", "done"); // Last FR wins + .containsExactly("status", "done"); // Last FR wins (frEvent3) + assertThat(result.get(3)).isEqualTo(inputEvents.get(2).content().get()); // u2 + assertThat(result.get(4)).isEqualTo(inputEvents.get(4).content().get()); // u3 + assertThat(result.get(5)).isEqualTo(inputEvents.get(7).content().get()); // u4 + } + + @Test + public void rearrangeHistory_multipleFRsForMultipleFC_returnsMergedFR() { + Event fcEvent1 = createFunctionCallEvent("fc1", "tool1", "call1"); + Event fcEvent2 = createFunctionCallEvent("fc2", "tool1", "call2"); + + Event frEvent1 = + createFunctionResponseEvent("fr1", "tool1", "call1", ImmutableMap.of("status", "pending")); + Event frEvent2 = + createFunctionResponseEvent("fr2", "tool1", "call1", ImmutableMap.of("status", "done")); + + Event frEvent3 = + createFunctionResponseEvent("fr3", "tool1", "call2", ImmutableMap.of("status", "pending")); + Event frEvent4 = + createFunctionResponseEvent("fr4", "tool1", "call2", ImmutableMap.of("status", "done")); + + ImmutableList inputEvents = + ImmutableList.of( + createUserEvent("u1", "I"), + fcEvent1, + createUserEvent("u2", "am"), + frEvent1, + createUserEvent("u3", "waiting"), + frEvent2, + createUserEvent("u4", "for"), + fcEvent2, + createUserEvent("u5", "you"), + frEvent3, + createUserEvent("u6", "to"), + frEvent4, + createUserEvent("u7", "Follow up query")); + + List result = runContentsProcessor(inputEvents); + + assertThat(result).hasSize(11); // u1, fc1, frEvent2, u2, u3, u4, fc2, frEvent4, u5, u6, u7 + assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); // u1 + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); // fc1 + Content mergedContent = result.get(2); + assertThat(mergedContent.parts().get()).hasSize(1); + assertThat(mergedContent.parts().get().get(0).functionResponse().get().response().get()) + .containsExactly("status", "done"); // Last FR wins (frEvent2) + assertThat(result.get(3)).isEqualTo(inputEvents.get(2).content().get()); // u2 + assertThat(result.get(4)).isEqualTo(inputEvents.get(4).content().get()); // u3 + assertThat(result.get(5)).isEqualTo(inputEvents.get(6).content().get()); // u4 + assertThat(result.get(6)).isEqualTo(inputEvents.get(7).content().get()); // fc2 + Content mergedContent2 = result.get(7); + assertThat(mergedContent2.parts().get()).hasSize(1); + assertThat(mergedContent2.parts().get().get(0).functionResponse().get().response().get()) + .containsExactly("status", "done"); // Last FR wins (frEvent4) + assertThat(result.get(8)).isEqualTo(inputEvents.get(8).content().get()); // u5 + assertThat(result.get(9)).isEqualTo(inputEvents.get(10).content().get()); // u6 + assertThat(result.get(10)).isEqualTo(inputEvents.get(12).content().get()); // u7 } @Test @@ -577,6 +638,105 @@ public void processRequest_compactionWithUncompactedEventsBetween() { .containsExactly("content 3", "Summary 1-2"); } + @Test + public void processRequest_rollingSummary_removesRedundancy() { + // Scenario: Rolling summary where a later summary covers a superset of the time range. + // Input: [E1(1), C1(Cover 1-1), E3(3), C2(Cover 1-3)] + // Expected: [C2] + // Explanation: C2 covers the range [1, 3], which includes the range covered by C1 [1, 1]. + // Therefore, C1 is redundant. E1 and E3 are also covered by C2. + ImmutableList events = + ImmutableList.of( + createUserEvent("e1", "E1", "inv1", 1), + createCompactedEvent(1, 1, "C1"), + createUserEvent("e3", "E3", "inv3", 3), + createCompactedEvent(1, 3, "C2")); + + List contents = runContentsProcessor(events); + assertThat(contents) + .comparingElementsUsing( + transforming((Content c) -> c.parts().get().get(0).text().get(), "content text")) + .containsExactly("C2"); + } + + @Test + public void processRequest_rollingSummaryWithRetention() { + // Input: with retention size 3: [E1, E2, E3, E4, C1(Cover 1-1), E6, E7, C2(Cover 1-3), E9] + // Expected: [C2, E4, E6, E7, E9] + ImmutableList events = + ImmutableList.of( + createUserEvent("e1", "E1", "inv1", 1), + createUserEvent("e2", "E2", "inv2", 2), + createUserEvent("e3", "E3", "inv3", 3), + createUserEvent("e4", "E4", "inv4", 4), + createCompactedEvent(1, 1, "C1"), + createUserEvent("e6", "E6", "inv6", 6), + createUserEvent("e7", "E7", "inv7", 7), + createCompactedEvent(1, 3, "C2"), + createUserEvent("e9", "E9", "inv9", 9)); + + List contents = runContentsProcessor(events); + assertThat(contents) + .comparingElementsUsing( + transforming((Content c) -> c.parts().get().get(0).text().get(), "content text")) + .containsExactly("C2", "E4", "E6", "E7", "E9"); + } + + @Test + public void processRequest_rollingSummary_preservesUncoveredHistory() { + // Input: [E1(1), E2(2), E3(3), E4(4), C1(2-2), E6(6), E7(7), C2(2-3), E9(9)] + // Expected: [E1, C2, E4, E6, E7, E9] + // E1 is before C1/C2 range, so it is preserved. + // C1 (2-2) is covered by C2 (2-3), so C1 is removed. + // E2, E3 are covered by C2. + // E4, E6, E7, E9 are retained. + ImmutableList events = + ImmutableList.of( + createUserEvent("e1", "E1", "inv1", 1), + createUserEvent("e2", "E2", "inv2", 2), + createUserEvent("e3", "E3", "inv3", 3), + createUserEvent("e4", "E4", "inv4", 4), + createCompactedEvent(2, 2, "C1"), + createUserEvent("e6", "E6", "inv6", 6), + createUserEvent("e7", "E7", "inv7", 7), + createCompactedEvent(2, 3, "C2"), + createUserEvent("e9", "E9", "inv9", 9)); + + List contents = runContentsProcessor(events); + assertThat(contents) + .comparingElementsUsing( + transforming((Content c) -> c.parts().get().get(0).text().get(), "content text")) + .containsExactly("E1", "C2", "E4", "E6", "E7", "E9"); + } + + @Test + public void processRequest_slidingWindow_preservesOverlappingCompactions() { + // Case 1: Sliding Window + Retention + // Input: [E1(1), E2(2), E3(3), C1(1-2), E4(5), C2(2-3), E5(7)] + // Overlap: C1 and C2 overlap at 2. C1 is NOT redundant (start 1 < start 2). + // Expected: [C1, C2, E4, E5] + // E1(1) covered by C1. + // E2(2) covered by C1 (and C2). + // E3(3) covered by C2. + // E4(5) retained. + // E5(7) retained. + ImmutableList events = + ImmutableList.of( + createUserEvent("e1", "E1", "inv1", 1), + createUserEvent("e2", "E2", "inv2", 2), + createUserEvent("e3", "E3", "inv3", 3), + createCompactedEvent(1, 2, "C1"), + createUserEvent("e4", "E4", "inv4", 5), + createCompactedEvent(2, 3, "C2"), + createUserEvent("e5", "E5", "inv5", 7)); + + List contents = runContentsProcessor(events); + assertThat(contents) + .comparingElementsUsing( + transforming((Content c) -> c.parts().get().get(0).text().get(), "content text")) + .containsExactly("C1", "C2", "E4", "E5"); + } + private static Event createUserEvent(String id, String text) { return Event.builder() .id(id) @@ -731,16 +891,14 @@ private List runContentsProcessorWithIncludeContents( List events, LlmAgent.IncludeContents includeContents) { LlmAgent agent = LlmAgent.builder().name(AGENT).includeContents(includeContents).build(); Session session = - Session.builder("test-session") - .appName("test-app") - .userId("test-user") - .events(new ArrayList<>(events)) - .build(); + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + session.events().addAll(events); InvocationContext context = InvocationContext.builder() .invocationId("test-invocation") .agent(agent) .session(session) + .sessionService(sessionService) .build(); LlmRequest initialRequest = LlmRequest.builder().build(); @@ -760,16 +918,14 @@ private List runContentsProcessorWithModelName(List events, Stri Mockito.doReturn(model).when(agent).resolvedModel(); Session session = - Session.builder("test-session") - .appName("test-app") - .userId("test-user") - .events(new ArrayList<>(events)) - .build(); + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + session.events().addAll(events); InvocationContext context = InvocationContext.builder() .invocationId("test-invocation") .agent(agent) .session(session) + .sessionService(sessionService) .build(); LlmRequest initialRequest = LlmRequest.builder().build(); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java index 4a6216029..7d1615dc2 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java @@ -24,6 +24,8 @@ import com.google.adk.examples.BaseExampleProvider; import com.google.adk.examples.Example; import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; @@ -35,6 +37,8 @@ @RunWith(JUnit4.class) public final class ExamplesTest { + private static final InMemorySessionService sessionService = new InMemorySessionService(); + private static class TestExampleProvider implements BaseExampleProvider { @Override public List getExamples(String query) { @@ -55,6 +59,8 @@ public void processRequest_withExampleProvider_addsExamplesToInstructions() { InvocationContext context = InvocationContext.builder() .invocationId("invocation1") + .session(Session.builder("session1").build()) + .sessionService(sessionService) .agent(agent) .userContent(Content.fromParts(Part.fromText("what is up?"))) .runConfig(RunConfig.builder().build()) @@ -76,6 +82,8 @@ public void processRequest_withoutExampleProvider_doesNotAddExamplesToInstructio InvocationContext context = InvocationContext.builder() .invocationId("invocation1") + .session(Session.builder("session1").build()) + .sessionService(sessionService) .agent(agent) .userContent(Content.fromParts(Part.fromText("what is up?"))) .runConfig(RunConfig.builder().build()) diff --git a/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java index 26d456b70..55adeb39e 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java @@ -27,6 +27,7 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.plugins.PluginManager; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; import com.google.adk.testing.TestUtils.EchoTool; @@ -60,6 +61,7 @@ public class RequestConfirmationLlmRequestProcessorTest { Optional.of(ORIGINAL_FUNCTION_CALL_ARGS))); private static final FunctionCall FUNCTION_CALL = FunctionCall.builder().id(FUNCTION_CALL_ID).name(ECHO_TOOL_NAME).args(ARGS).build(); + private static final InMemorySessionService sessionService = new InMemorySessionService(); private static final Event REQUEST_CONFIRMATION_EVENT = Event.builder() @@ -93,7 +95,7 @@ public void runAsync_withConfirmation_callsOriginalFunction() { .events(ImmutableList.of(REQUEST_CONFIRMATION_EVENT, USER_CONFIRMATION_EVENT)) .build(); - InvocationContext context = createInvocationContext(agent, session); + InvocationContext context = buildInvocationContext(agent, session); RequestProcessor.RequestProcessingResult result = processor.processRequest(context, LlmRequest.builder().build()).blockingGet(); @@ -132,7 +134,7 @@ public void runAsync_withConfirmationAndToolAlreadyCalled_doesNotCallOriginalFun REQUEST_CONFIRMATION_EVENT, USER_CONFIRMATION_EVENT, toolResponseEvent)) .build(); - InvocationContext context = createInvocationContext(agent, session); + InvocationContext context = buildInvocationContext(agent, session); RequestProcessor.RequestProcessingResult result = processor.processRequest(context, LlmRequest.builder().build()).blockingGet(); @@ -149,7 +151,7 @@ public void runAsync_noEvents_empty() { assertThat( processor .processRequest( - createInvocationContext(agent, session), LlmRequest.builder().build()) + buildInvocationContext(agent, session), LlmRequest.builder().build()) .blockingGet() .events()) .isEmpty(); @@ -164,18 +166,19 @@ public void runAsync_noUserConfirmationEvent_empty() { assertThat( processor .processRequest( - createInvocationContext(agent, session), LlmRequest.builder().build()) + buildInvocationContext(agent, session), LlmRequest.builder().build()) .blockingGet() .events()) .isEmpty(); } - private static InvocationContext createInvocationContext(LlmAgent agent, Session session) { + private static InvocationContext buildInvocationContext(LlmAgent agent, Session session) { return InvocationContext.builder() .pluginManager(new PluginManager()) .invocationId(InvocationContext.newInvocationContextId()) .agent(agent) .session(session) + .sessionService(sessionService) .build(); } diff --git a/core/src/main/java/com/google/adk/agents/BaseAgentState.java b/core/src/test/java/com/google/adk/flows/llmflows/SingleFlowTest.java similarity index 53% rename from core/src/main/java/com/google/adk/agents/BaseAgentState.java rename to core/src/test/java/com/google/adk/flows/llmflows/SingleFlowTest.java index dedcb93ab..ccb10a3a7 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgentState.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/SingleFlowTest.java @@ -14,26 +14,22 @@ * limitations under the License. */ -package com.google.adk.agents; +package com.google.adk.flows.llmflows; -import com.google.adk.JsonBaseModel; +import static com.google.common.truth.Truth.assertThat; -/** Base class for all agent states. */ -public class BaseAgentState extends JsonBaseModel { +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; - protected BaseAgentState() {} +@RunWith(JUnit4.class) +public final class SingleFlowTest { - /** Returns a new {@link Builder} for creating {@link BaseAgentState} instances. */ - public static Builder builder() { - return new Builder(); - } - - /** Builder for {@link BaseAgentState}. */ - public static class Builder { - private Builder() {} - - public BaseAgentState build() { - return new BaseAgentState(); - } + @Test + public void requestProcessors_containsCompaction() { + boolean hasCompaction = + SingleFlow.REQUEST_PROCESSORS.stream() + .anyMatch(processor -> processor instanceof Compaction); + assertThat(hasCompaction).isTrue(); } } diff --git a/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java b/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java index 5d70cc449..a3ac09fe5 100644 --- a/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.LiveServerContent; import com.google.genai.types.LiveServerMessage; import com.google.genai.types.LiveServerSetupComplete; @@ -28,6 +29,8 @@ import com.google.genai.types.LiveServerToolCallCancellation; import com.google.genai.types.Part; import com.google.genai.types.UsageMetadata; +import io.reactivex.rxjava3.observers.TestObserver; +import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -45,8 +48,13 @@ public void convertToServerResponse_withInterruptedTrue_mapsInterruptedField() { .build(); LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); + TestObserver testObserver = new TestObserver<>(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.content()).isPresent(); assertThat(response.content().get().text()).isEqualTo("Model response"); @@ -66,8 +74,13 @@ public void convertToServerResponse_withInterruptedFalse_mapsInterruptedField() LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.interrupted()).hasValue(false); assertThat(response.turnComplete()).hasValue(false); } @@ -82,8 +95,13 @@ public void convertToServerResponse_withoutInterruptedField_mapsEmptyOptional() LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.interrupted()).isEmpty(); assertThat(response.turnComplete()).hasValue(true); } @@ -98,8 +116,13 @@ public void convertToServerResponse_withTurnCompleteTrue_mapsPartialFalse() { LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.partial()).hasValue(false); assertThat(response.turnComplete()).hasValue(true); } @@ -114,8 +137,13 @@ public void convertToServerResponse_withTurnCompleteFalse_mapsPartialTrue() { LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.partial()).hasValue(true); assertThat(response.turnComplete()).hasValue(false); } @@ -128,8 +156,13 @@ public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() { LiveServerMessage message = LiveServerMessage.builder().toolCall(toolCall).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.content()).isPresent(); assertThat(response.content().get().parts()).isPresent(); assertThat(response.content().get().parts().get()).hasSize(1); @@ -139,40 +172,123 @@ public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() { } @Test - public void convertToServerResponse_withUsageMetadata_returnsEmptyOptional() { + public void convertToServerResponse_withUsageMetadata_mapsGenerateResponseUsageMetadata() { LiveServerMessage message = - LiveServerMessage.builder().usageMetadata(UsageMetadata.builder().build()).build(); + LiveServerMessage.builder() + .usageMetadata( + UsageMetadata.builder() + .promptTokenCount(10) + .responseTokenCount(20) + .totalTokenCount(30) + .build()) + .build(); - assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); + assertThat(response.usageMetadata()).isPresent(); + GenerateContentResponseUsageMetadata expectedUsageMetadata = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + assertThat(response.usageMetadata()).hasValue(expectedUsageMetadata); } @Test - public void convertToServerResponse_withToolCallCancellation_returnsEmptyOptional() { + public void convertToServerResponse_withToolCallCancellation_returnsNoValues() { LiveServerMessage message = LiveServerMessage.builder() .toolCallCancellation(LiveServerToolCallCancellation.builder().build()) .build(); - assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertNoValues(); + testObserver.assertComplete(); } @Test - public void convertToServerResponse_withSetupComplete_returnsEmptyOptional() { + public void convertToServerResponse_withSetupComplete_returnsNoValues() { LiveServerMessage message = LiveServerMessage.builder() .setupComplete(LiveServerSetupComplete.builder().build()) .build(); - assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertNoValues(); + testObserver.assertComplete(); } @Test public void convertToServerResponse_withUnknownMessage_returnsErrorResponse() { LiveServerMessage message = LiveServerMessage.builder().build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.errorCode()).isPresent(); assertThat(response.errorMessage()).hasValue("Received unknown server message."); } + + @Test + public void convertToServerResponse_withContentAndUsageMetadata_emitsMultiple() { + LiveServerContent serverContent = + LiveServerContent.builder() + .modelTurn(Content.fromParts(Part.fromText("Model response"))) + .turnComplete(true) + .build(); + + UsageMetadata usageMetadata = + UsageMetadata.builder() + .promptTokenCount(10) + .responseTokenCount(20) + .totalTokenCount(30) + .build(); + + LiveServerMessage message = + LiveServerMessage.builder() + .serverContent(serverContent) + .usageMetadata(usageMetadata) + .build(); + + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(2); + testObserver.assertComplete(); + + List responses = testObserver.values(); + + // Check for ServerContent response + LlmResponse contentResponse = responses.get(0); + assertThat(contentResponse.content()).isPresent(); + assertThat(contentResponse.content().get().text()).isEqualTo("Model response"); + assertThat(contentResponse.usageMetadata()).isEmpty(); + + // Check for UsageMetadata response + LlmResponse usageResponse = responses.get(1); + assertThat(usageResponse.content()).isEmpty(); + assertThat(usageResponse.usageMetadata()).isPresent(); + GenerateContentResponseUsageMetadata expectedUsageMetadata = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + assertThat(usageResponse.usageMetadata()).hasValue(expectedUsageMetadata); + } } diff --git a/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java b/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java new file mode 100644 index 000000000..345314256 --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/GlobalInstructionPluginTest.java @@ -0,0 +1,155 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.plugins; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.artifacts.BaseArtifactService; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.Session; +import com.google.adk.sessions.State; +import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class GlobalInstructionPluginTest { + @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + @Mock private CallbackContext mockCallbackContext; + @Mock private InvocationContext mockInvocationContext; + private final State state = new State(new ConcurrentHashMap<>()); + private final Session session = Session.builder("session_id").state(state).build(); + @Mock private BaseArtifactService mockArtifactService; + + @Before + public void setUp() { + state.clear(); + when(mockCallbackContext.invocationId()).thenReturn("invocation_id"); + when(mockCallbackContext.agentName()).thenReturn("agent_name"); + when(mockCallbackContext.invocationContext()).thenReturn(mockInvocationContext); + when(mockInvocationContext.session()).thenReturn(session); + when(mockInvocationContext.artifactService()).thenReturn(mockArtifactService); + } + + @Test + public void beforeModelCallback_noExistingInstruction() { + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin("global instruction"); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + Content systemInstruction = llmRequestBuilder.build().config().get().systemInstruction().get(); + List parts = systemInstruction.parts().get(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue("global instruction"); + assertThat(systemInstruction.role()).isEmpty(); + } + + @Test + public void beforeModelCallback_withExistingInstruction() { + LlmRequest.Builder llmRequestBuilder = + LlmRequest.builder() + .config( + GenerateContentConfig.builder() + .systemInstruction( + Content.builder().parts(Part.fromText("existing instruction")).build()) + .build()); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin("global instruction"); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + Content systemInstruction = llmRequestBuilder.build().config().get().systemInstruction().get(); + List parts = systemInstruction.parts().get(); + assertThat(parts).hasSize(2); + assertThat(parts.get(0).text()).hasValue("global instruction\n\n"); + assertThat(parts.get(1).text()).hasValue("existing instruction"); + assertThat(systemInstruction.role()).isEmpty(); + } + + @Test + public void beforeModelCallback_withInstructionProvider() { + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + GlobalInstructionPlugin plugin = + new GlobalInstructionPlugin(unusedContext -> Maybe.just("instruction from provider")); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + Content systemInstruction = llmRequestBuilder.build().config().get().systemInstruction().get(); + List parts = systemInstruction.parts().get(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue("instruction from provider"); + assertThat(systemInstruction.role()).isEmpty(); + } + + @Test + public void beforeModelCallback_withStringInstruction_injectsState() { + state.put("name", "Alice"); + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin("Hello {name}"); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + Content systemInstruction = llmRequestBuilder.build().config().get().systemInstruction().get(); + List parts = systemInstruction.parts().get(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).text()).hasValue("Hello Alice"); + assertThat(systemInstruction.role()).isEmpty(); + } + + @Test + public void beforeModelCallback_nullInstruction() { + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin((String) null); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + assertThat(llmRequestBuilder.build().config()).isEmpty(); + } + + @Test + public void beforeModelCallback_emptyInstruction() { + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin(""); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + assertThat(llmRequestBuilder.build().config()).isEmpty(); + } + + @Test + public void beforeModelCallback_withExistingInstructionAndRole_preservesRole() { + LlmRequest.Builder llmRequestBuilder = + LlmRequest.builder() + .config( + GenerateContentConfig.builder() + .systemInstruction( + Content.builder() + .parts(Part.fromText("existing instruction")) + .role("system") + .build()) + .build()); + GlobalInstructionPlugin plugin = new GlobalInstructionPlugin("global instruction"); + plugin.beforeModelCallback(mockCallbackContext, llmRequestBuilder).test().assertComplete(); + Content systemInstruction = llmRequestBuilder.build().config().get().systemInstruction().get(); + List parts = systemInstruction.parts().get(); + assertThat(parts).hasSize(2); + assertThat(parts.get(0).text()).hasValue("global instruction\n\n"); + assertThat(parts.get(1).text()).hasValue("existing instruction"); + assertThat(systemInstruction.role()).hasValue("system"); + } +} diff --git a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java index ae42bb27e..4ae856fc7 100644 --- a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java +++ b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java @@ -333,4 +333,31 @@ public void onToolErrorCallback_singlePlugin() { verify(plugin1).onToolErrorCallback(mockTool, toolArgs, mockToolContext, mockThrowable); } + + @Test + public void close_allComplete() { + when(plugin1.close()).thenReturn(Completable.complete()); + when(plugin2.close()).thenReturn(Completable.complete()); + pluginManager.registerPlugin(plugin1); + pluginManager.registerPlugin(plugin2); + + pluginManager.close().test().assertResult(); + + verify(plugin1).close(); + verify(plugin2).close(); + } + + @Test + public void close_plugin1Fails() { + RuntimeException testException = new RuntimeException("Test"); + when(plugin1.close()).thenReturn(Completable.error(testException)); + when(plugin2.close()).thenReturn(Completable.complete()); + pluginManager.registerPlugin(plugin1); + pluginManager.registerPlugin(plugin2); + + pluginManager.close().test().assertError(testException); + + verify(plugin1).close(); + verify(plugin2).close(); + } } diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index b0dbedcd6..421b79abb 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -36,7 +36,6 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; -import com.google.adk.apps.ResumabilityConfig; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; import com.google.adk.models.LlmResponse; @@ -179,6 +178,70 @@ public void eventsCompaction_enabled() { "user: summary 2"); } + @Test + public void eventsCompaction_withNullOverlap_doesNotCompact() { + TestLlm testLlm = + createTestLlm( + createLlmResponse(createContent("llm 1")), createLlmResponse(createContent("llm 2"))); + LlmAgent agent = createTestAgent(testLlm); + + Runner runner = + Runner.builder() + .app( + App.builder() + .name(this.runner.appName()) + .rootAgent(agent) + .eventsCompactionConfig(new EventsCompactionConfig(1, null, null, null, null)) + .build()) + .sessionService(this.runner.sessionService()) + .build(); + + var unused1 = + runner.runAsync("user", session.id(), createContent("user 1")).toList().blockingGet(); + var unused2 = + runner.runAsync("user", session.id(), createContent("user 2")).toList().blockingGet(); + + Session updatedSession = + runner + .sessionService() + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(simplifyEvents(updatedSession.events())) + .containsExactly("user: user 1", "test agent: llm 1", "user: user 2", "test agent: llm 2"); + } + + @Test + public void eventsCompaction_withNullInterval_doesNotCompact() { + TestLlm testLlm = + createTestLlm( + createLlmResponse(createContent("llm 1")), createLlmResponse(createContent("llm 2"))); + LlmAgent agent = createTestAgent(testLlm); + + Runner runner = + Runner.builder() + .app( + App.builder() + .name(this.runner.appName()) + .rootAgent(agent) + .eventsCompactionConfig(new EventsCompactionConfig(null, 0, null, null, null)) + .build()) + .sessionService(this.runner.sessionService()) + .build(); + + var unused1 = + runner.runAsync("user", session.id(), createContent("user 1")).toList().blockingGet(); + var unused2 = + runner.runAsync("user", session.id(), createContent("user 2")).toList().blockingGet(); + + Session updatedSession = + runner + .sessionService() + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(simplifyEvents(updatedSession.events())) + .containsExactly("user: user 1", "test agent: llm 1", "user: user 2", "test agent: llm 2"); + } + @Test public void pluginDoesNothing() { var events = @@ -865,48 +928,6 @@ public void runLive_createsInvocationSpan() { assertThat(invocationSpan.get().hasEnded()).isTrue(); } - @Test - public void resumabilityConfig_isResumable_isTrueInInvocationContext() { - ArgumentCaptor contextCaptor = - ArgumentCaptor.forClass(InvocationContext.class); - when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty()); - Runner runner = - Runner.builder() - .app( - App.builder() - .name("test") - .rootAgent(agent) - .plugins(ImmutableList.of(plugin)) - .resumabilityConfig(new ResumabilityConfig(true)) - .build()) - .build(); - Session session = runner.sessionService().createSession("test", "user").blockingGet(); - var unused = - runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); - assertThat(contextCaptor.getValue().isResumable()).isTrue(); - } - - @Test - public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() { - ArgumentCaptor contextCaptor = - ArgumentCaptor.forClass(InvocationContext.class); - when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty()); - Runner runner = - Runner.builder() - .app( - App.builder() - .name("test") - .rootAgent(agent) - .plugins(ImmutableList.of(plugin)) - .resumabilityConfig(new ResumabilityConfig(false)) - .build()) - .build(); - Session session = runner.sessionService().createSession("test", "user").blockingGet(); - var unused = - runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); - assertThat(contextCaptor.getValue().isResumable()).isFalse(); - } - @Test public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() { RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); @@ -1044,6 +1065,26 @@ public void runAsync_withToolConfirmation() { .inOrder(); } + @Test + public void close_closesPluginsAndCodeExecutors() { + BasePlugin plugin = mockPlugin("close_test_plugin"); + when(plugin.close()).thenReturn(Completable.complete()); + LlmAgent agentWithCodeExecutor = createTestAgentBuilder(testLlm).build(); + Runner runner = + Runner.builder() + .app( + App.builder() + .name("test") + .rootAgent(agentWithCodeExecutor) + .plugins(ImmutableList.of(plugin)) + .build()) + .build(); + + runner.close().blockingAwait(); + + verify(plugin).close(); + } + public static class Tools { private Tools() {} diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 4c35f5b90..6223dd2f0 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -91,6 +91,7 @@ public void lifecycle_listSessions() { stateDelta.put("sessionKey", "sessionValue"); stateDelta.put("_app_appKey", "appValue"); stateDelta.put("_user_userKey", "userValue"); + stateDelta.put("temp:tempKey", "tempValue"); Event event = Event.builder().actions(EventActions.builder().stateDelta(stateDelta).build()).build(); @@ -107,6 +108,7 @@ public void lifecycle_listSessions() { assertThat(listedSession.state()).containsEntry("sessionKey", "sessionValue"); assertThat(listedSession.state()).containsEntry("_app_appKey", "appValue"); assertThat(listedSession.state()).containsEntry("_user_userKey", "userValue"); + assertThat(listedSession.state()).containsEntry("temp:tempKey", "tempValue"); } @Test @@ -136,6 +138,7 @@ public void appendEvent_updatesSessionState() { stateDelta.put("sessionKey", "sessionValue"); stateDelta.put("_app_appKey", "appValue"); stateDelta.put("_user_userKey", "userValue"); + stateDelta.put("temp:tempKey", "tempValue"); Event event = Event.builder().actions(EventActions.builder().stateDelta(stateDelta).build()).build(); @@ -147,6 +150,7 @@ public void appendEvent_updatesSessionState() { assertThat(session.state()).containsEntry("sessionKey", "sessionValue"); assertThat(session.state()).containsEntry("_app_appKey", "appValue"); assertThat(session.state()).containsEntry("_user_userKey", "userValue"); + assertThat(session.state()).containsEntry("temp:tempKey", "tempValue"); // getSession should return session with merged state. Session retrievedSession = @@ -156,5 +160,96 @@ public void appendEvent_updatesSessionState() { assertThat(retrievedSession.state()).containsEntry("sessionKey", "sessionValue"); assertThat(retrievedSession.state()).containsEntry("_app_appKey", "appValue"); assertThat(retrievedSession.state()).containsEntry("_user_userKey", "userValue"); + assertThat(retrievedSession.state()).containsEntry("temp:tempKey", "tempValue"); + } + + @Test + public void appendEvent_removesState() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService + .createSession("app", "user", new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + ConcurrentMap stateDeltaAdd = new ConcurrentHashMap<>(); + stateDeltaAdd.put("sessionKey", "sessionValue"); + stateDeltaAdd.put("_app_appKey", "appValue"); + stateDeltaAdd.put("_user_userKey", "userValue"); + stateDeltaAdd.put("temp:tempKey", "tempValue"); + + Event eventAdd = + Event.builder().actions(EventActions.builder().stateDelta(stateDeltaAdd).build()).build(); + + var unused = sessionService.appendEvent(session, eventAdd).blockingGet(); + + // Verify state is added + Session retrievedSessionAdd = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSessionAdd.state()).containsEntry("sessionKey", "sessionValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("_app_appKey", "appValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("_user_userKey", "userValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("temp:tempKey", "tempValue"); + + // Prepare and append event to remove state + ConcurrentMap stateDeltaRemove = new ConcurrentHashMap<>(); + stateDeltaRemove.put("sessionKey", State.REMOVED); + stateDeltaRemove.put("_app_appKey", State.REMOVED); + stateDeltaRemove.put("_user_userKey", State.REMOVED); + stateDeltaRemove.put("temp:tempKey", State.REMOVED); + + Event eventRemove = + Event.builder() + .actions(EventActions.builder().stateDelta(stateDeltaRemove).build()) + .build(); + + unused = sessionService.appendEvent(session, eventRemove).blockingGet(); + + // Verify state is removed + Session retrievedSessionRemove = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("sessionKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("_app_appKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("_user_userKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey"); + } + + @Test + public void sequentialAgents_shareTempState() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService + .createSession("app", "user", new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + // Agent 1 writes to temp state + ConcurrentMap stateDelta1 = new ConcurrentHashMap<>(); + stateDelta1.put("temp:agent1_output", "data"); + Event event1 = + Event.builder().actions(EventActions.builder().stateDelta(stateDelta1).build()).build(); + var unused = sessionService.appendEvent(session, event1).blockingGet(); + + // Verify agent 1 output is in session state + assertThat(session.state()).containsEntry("temp:agent1_output", "data"); + + // Agent 2 reads "agent1_output", processes it, writes "agent2_output", and removes + // "agent1_output" + ConcurrentMap stateDelta2 = new ConcurrentHashMap<>(); + stateDelta2.put("temp:agent2_output", "processed_data"); + stateDelta2.put("temp:agent1_output", State.REMOVED); + Event event2 = + Event.builder().actions(EventActions.builder().stateDelta(stateDelta2).build()).build(); + unused = sessionService.appendEvent(session, event2).blockingGet(); + + // Verify final state after agent 2 processing + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSession.state()).doesNotContainKey("temp:agent1_output"); + assertThat(retrievedSession.state()).containsEntry("temp:agent2_output", "processed_data"); } } diff --git a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java index 5e8f3d992..111b1dce3 100644 --- a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java +++ b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java @@ -8,6 +8,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ConcurrentMap; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -164,6 +165,18 @@ private ApiResponse handleAppendEvent(String path, InvocationOnMock invocation) eventsData.add(newEventData); eventMap.put(sessionId, mapper.writeValueAsString(eventsData)); + + // Apply stateDelta to session state + extractObjectMap(newEventData, "actions") + .flatMap(actions -> extractObjectMap(actions, "stateDelta")) + .ifPresent( + stateDelta -> { + try { + applyStateDelta(sessionId, stateDelta); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); } catch (Exception e) { throw new RuntimeException(e); } @@ -213,4 +226,30 @@ private ApiResponse handleDeleteSession(String path) { sessionMap.remove(sessionIdToDelete); return responseWithBody(""); } + + private void applyStateDelta(String sessionId, Map stateDelta) throws Exception { + String sessionDataString = sessionMap.get(sessionId); + if (sessionDataString == null) { + return; + } + Map sessionData = + mapper.readValue(sessionDataString, new TypeReference>() {}); + Map sessionState = + extractObjectMap(sessionData, "sessionState").map(HashMap::new).orElseGet(HashMap::new); + + for (Map.Entry entry : stateDelta.entrySet()) { + if (entry.getValue() == null) { + sessionState.remove(entry.getKey()); + } else { + sessionState.put(entry.getKey(), entry.getValue()); + } + } + sessionData.put("sessionState", sessionState); + sessionMap.put(sessionId, mapper.writeValueAsString(sessionData)); + } + + @SuppressWarnings("unchecked") // Safe because map values are Maps read from JSON. + private Optional> extractObjectMap(Map map, String key) { + return Optional.ofNullable((Map) map.get(key)); + } } diff --git a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java index b77d6f267..f6120cf08 100644 --- a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java +++ b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java @@ -1,6 +1,7 @@ package com.google.adk.sessions; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; @@ -8,9 +9,14 @@ import com.google.adk.JsonBaseModel; import com.google.adk.events.Event; import com.google.adk.events.EventActions; +import com.google.adk.events.ToolConfirmation; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.GroundingMetadata; import com.google.genai.types.Part; import java.time.Instant; import java.util.Collections; @@ -18,6 +24,7 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -32,9 +39,7 @@ public void convertEventToJson_fullEvent_success() throws JsonProcessingExceptio EventActions.builder() .skipSummarization(true) .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key", "value"))) - .artifactDelta( - new ConcurrentHashMap<>( - ImmutableMap.of("artifact", Part.fromText("artifact_text")))) + .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact", 1))) .transferToAgent("agent") .escalate(true) .build(); @@ -73,8 +78,7 @@ public void convertEventToJson_fullEvent_success() throws JsonProcessingExceptio JsonNode actionsNode = jsonNode.get("actions"); assertThat(actionsNode.get("skipSummarization").asBoolean()).isTrue(); assertThat(actionsNode.get("stateDelta").get("key").asText()).isEqualTo("value"); - assertThat(actionsNode.get("artifactDelta").get("artifact").get("text").asText()) - .isEqualTo("artifact_text"); + assertThat(actionsNode.get("artifactDelta").get("artifact").asInt()).isEqualTo(1); assertThat(actionsNode.get("transferAgent").asText()).isEqualTo("agent"); assertThat(actionsNode.get("escalate").asBoolean()).isTrue(); } @@ -124,8 +128,7 @@ public void fromApiEvent_fullEvent_success() { Map actions = new HashMap<>(); actions.put("skipSummarization", true); actions.put("stateDelta", ImmutableMap.of("key", "value")); - actions.put( - "artifactDelta", ImmutableMap.of("artifact", ImmutableMap.of("text", "artifact_text"))); + actions.put("artifactDelta", ImmutableMap.of("artifact", 1)); actions.put("transferAgent", "agent"); actions.put("escalate", true); apiEvent.put("actions", actions); @@ -147,11 +150,117 @@ public void fromApiEvent_fullEvent_success() { EventActions eventActions = event.actions(); assertThat(eventActions.skipSummarization()).hasValue(true); assertThat(eventActions.stateDelta()).containsEntry("key", "value"); - assertThat(eventActions.artifactDelta().get("artifact").text()).hasValue("artifact_text"); + assertThat(eventActions.artifactDelta()).containsEntry("artifact", 1); assertThat(eventActions.transferToAgent()).hasValue("agent"); assertThat(eventActions.escalate()).hasValue(true); } + @Test + public void fromApiEvent_withTransferToAgent_success() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + apiEvent.put("timestamp", "2023-01-01T00:00:00Z"); + + Map actions = new HashMap<>(); + actions.put("transferToAgent", "agent-id"); + apiEvent.put("actions", actions); + + Event event = SessionJsonConverter.fromApiEvent(apiEvent); + + assertThat(event.actions().transferToAgent()).hasValue("agent-id"); + } + + @Test + public void convertEventToJson_complexActions_success() throws JsonProcessingException { + ConcurrentMap> authConfigs = new ConcurrentHashMap<>(); + authConfigs.put("auth1", new ConcurrentHashMap<>(ImmutableMap.of("param1", "value1"))); + + ConcurrentMap toolConfirmations = new ConcurrentHashMap<>(); + toolConfirmations.put( + "tool1", ToolConfirmation.builder().hint("hint1").confirmed(true).build()); + + EventActions actions = + EventActions.builder() + .requestedAuthConfigs(authConfigs) + .requestedToolConfirmations(toolConfirmations) + .endOfAgent(true) + .build(); + + GenerateContentResponseUsageMetadata usageMetadata = + GenerateContentResponseUsageMetadata.builder().promptTokenCount(10).build(); + GroundingMetadata groundingMetadata = GroundingMetadata.builder().build(); + + Event event = + Event.builder() + .author("user") + .invocationId("inv-123") + .timestamp(Instant.parse("2023-01-01T00:00:00.123Z").toEpochMilli()) + .actions(actions) + .longRunningToolIds(ImmutableSet.of("tool-id-1")) + .usageMetadata(usageMetadata) + .groundingMetadata(groundingMetadata) + .build(); + + String json = SessionJsonConverter.convertEventToJson(event, true); + JsonNode jsonNode = objectMapper.readTree(json); + + assertThat(jsonNode.get("timestamp").asText()).isEqualTo("2023-01-01T00:00:00.123Z"); + + JsonNode eventMetadata = jsonNode.get("eventMetadata"); + assertThat(eventMetadata.get("longRunningToolIds").get(0).asText()).isEqualTo("tool-id-1"); + assertThat(eventMetadata.has("usageMetadata")).isTrue(); + assertThat(eventMetadata.has("groundingMetadata")).isTrue(); + + JsonNode actionsNode = jsonNode.get("actions"); + assertThat(actionsNode.get("requestedAuthConfigs").get("auth1").get("param1").asText()) + .isEqualTo("value1"); + assertThat(actionsNode.get("requestedToolConfirmations").get("tool1").get("hint").asText()) + .isEqualTo("hint1"); + assertThat( + actionsNode.get("requestedToolConfirmations").get("tool1").get("confirmed").asBoolean()) + .isTrue(); + assertThat(actionsNode.get("endOfAgent").asBoolean()).isTrue(); + } + + @Test + public void fromApiEvent_complexActions_success() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + apiEvent.put("timestamp", "2023-01-01T00:00:00.123Z"); + + Map actions = new HashMap<>(); + actions.put("requestedAuthConfigs", ImmutableMap.of("auth1", ImmutableMap.of("p1", "v1"))); + actions.put( + "requestedToolConfirmations", + ImmutableMap.of("tool1", ImmutableMap.of("hint", "h1", "confirmed", true))); + actions.put("endOfAgent", true); + apiEvent.put("actions", actions); + + Map eventMetadata = new HashMap<>(); + eventMetadata.put("longRunningToolIds", ImmutableList.of("tool-1")); + eventMetadata.put("usageMetadata", ImmutableMap.of("promptTokenCount", 10)); + eventMetadata.put("groundingMetadata", ImmutableMap.of()); + apiEvent.put("eventMetadata", eventMetadata); + + Event event = SessionJsonConverter.fromApiEvent(apiEvent); + + assertThat(event.timestamp()) + .isEqualTo(Instant.parse("2023-01-01T00:00:00.123Z").toEpochMilli()); + assertThat(event.longRunningToolIds().get()).containsExactly("tool-1"); + assertThat(event.usageMetadata().get().promptTokenCount()).hasValue(10); + assertThat(event.groundingMetadata()).isPresent(); + + EventActions eventActions = event.actions(); + assertThat(eventActions.requestedAuthConfigs().get("auth1")).containsEntry("p1", "v1"); + assertThat(eventActions.requestedToolConfirmations().get("tool1").hint()).isEqualTo("h1"); + assertThat(eventActions.requestedToolConfirmations().get("tool1").confirmed()).isTrue(); + assertThat(eventActions.endOfAgent()).isTrue(); + } + @Test public void fromApiEvent_minimalEvent_success() { Map apiEvent = new HashMap<>(); @@ -221,4 +330,98 @@ public void fromApiEvent_missingMetadataFields_success() { assertThat(event.turnComplete().get()).isFalse(); assertThat(event.interrupted().get()).isFalse(); } + + @Test + public void convertEventToJson_withStateRemoved_success() throws JsonProcessingException { + EventActions actions = + EventActions.builder() + .stateDelta( + new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", State.REMOVED))) + .build(); + + Event event = + Event.builder() + .author("user") + .invocationId("inv-123") + .timestamp(Instant.parse("2023-01-01T00:00:00Z").toEpochMilli()) + .actions(actions) + .build(); + + String json = SessionJsonConverter.convertEventToJson(event); + JsonNode jsonNode = objectMapper.readTree(json); + + JsonNode actionsNode = jsonNode.get("actions"); + assertThat(actionsNode.get("stateDelta").get("key1").asText()).isEqualTo("value1"); + assertThat(actionsNode.get("stateDelta").get("key2").isNull()).isTrue(); + } + + @Test + public void fromApiEvent_withInvalidContentMap_returnsNullContent() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + apiEvent.put("timestamp", "2023-01-01T00:00:00Z"); + // Parts should be a list, not a string + apiEvent.put("content", ImmutableMap.of("parts", "invalid")); + + Event event = SessionJsonConverter.fromApiEvent(apiEvent); + + assertThat(event.content()).isEmpty(); + } + + @Test + public void fromApiEvent_withInvalidArtifactDelta_skipsInvalidEntries() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + apiEvent.put("timestamp", "2023-01-01T00:00:00Z"); + + Map artifactDelta = new HashMap<>(); + artifactDelta.put("valid", 1); + artifactDelta.put("invalid", "not-a-map"); + + Map actions = new HashMap<>(); + actions.put("artifactDelta", artifactDelta); + apiEvent.put("actions", actions); + + Event event = SessionJsonConverter.fromApiEvent(apiEvent); + + assertThat(event.actions().artifactDelta()).containsKey("valid"); + assertThat(event.actions().artifactDelta()).doesNotContainKey("invalid"); + } + + @Test + public void fromApiEvent_missingTimestamp_throwsException() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + + assertThrows(IllegalArgumentException.class, () -> SessionJsonConverter.fromApiEvent(apiEvent)); + } + + @Test + public void fromApiEvent_withNullStateDeltaValue_success() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + apiEvent.put("timestamp", "2023-01-01T00:00:00Z"); + + Map stateDelta = new HashMap<>(); + stateDelta.put("key1", "value1"); + stateDelta.put("key2", null); + + Map actions = new HashMap<>(); + actions.put("stateDelta", stateDelta); + apiEvent.put("actions", actions); + + Event event = SessionJsonConverter.fromApiEvent(apiEvent); + + EventActions eventActions = event.actions(); + assertThat(eventActions.stateDelta()).containsEntry("key1", "value1"); + assertThat(eventActions.stateDelta()).containsEntry("key2", State.REMOVED); + } } diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index 775b465ff..36eab1d16 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -337,4 +337,32 @@ public void listEmptySession_success() { .events()) .isEmpty(); } + + @Test + public void appendEvent_withStateRemoved_updatesSessionState() { + String userId = "userB"; + ConcurrentMap initialState = + new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2")); + Session session = + vertexAiSessionService.createSession("987", userId, initialState, null).blockingGet(); + + ConcurrentMap stateDelta = + new ConcurrentHashMap<>(ImmutableMap.of("key2", State.REMOVED)); + Event event = + Event.builder() + .invocationId("456") + .author(userId) + .timestamp(Instant.parse("2024-12-12T12:12:12.123456Z").toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + var unused = vertexAiSessionService.appendEvent(session, event).blockingGet(); + + Session updatedSession = + vertexAiSessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + + assertThat(updatedSession.state()).containsExactly("key1", "value1"); + assertThat(updatedSession.state()).doesNotContainKey("key2"); + } } diff --git a/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java new file mode 100644 index 000000000..7a4a3ddb9 --- /dev/null +++ b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java @@ -0,0 +1,368 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.summarizer; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.adk.events.EventCompaction; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.List; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class TailRetentionEventCompactorTest { + + @Rule public final MockitoRule mockito = MockitoJUnit.rule(); + @Mock private BaseSessionService mockSessionService; + @Mock private BaseEventSummarizer mockSummarizer; + @Captor private ArgumentCaptor> eventListCaptor; + + @Test + public void constructor_negativeTokenThreshold_throwsException() { + assertThat( + assertThrows( + IllegalArgumentException.class, + () -> new TailRetentionEventCompactor(mockSummarizer, 2, -1))) + .hasMessageThat() + .contains("tokenThreshold must be non-negative"); + } + + @Test + public void constructor_negativeRetentionSize_throwsException() { + assertThat( + assertThrows( + IllegalArgumentException.class, + () -> new TailRetentionEventCompactor(mockSummarizer, -1, 100))) + .hasMessageThat() + .contains("retentionSize must be non-negative"); + } + + @Test + public void compaction_skippedWhenEstimatedTokenUsageBelowThreshold() { + // Threshold is 100. + // Event1: "Event1" -> length 6. + // Retain1: "Retain1" -> length 7. + // Retain2: "Retain2" -> length 7. + // Total length = 20. Estimated tokens = 20 / 4 = 5. + // 5 <= 100 -> Skip. + EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100); + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + createEvent(3, "Retain2")); // No usage metadata + Session session = Session.builder("id").events(events).build(); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer, never()).summarizeEvents(any()); + verify(mockSessionService, never()).appendEvent(any(), any()); + } + + @Test + public void compaction_happensWhenEstimatedTokenUsageAboveThreshold() { + // Threshold is 2. + // Event1: "Event1" -> length 6. + // Retain1: "Retain1" -> length 7. + // Retain2: "Retain2" -> length 7. + // Total eligible for estimation (including retained ones as per current logic): + // Logic: getCompactionEvents returns [Event1, Retain1, Retain2] for estimation. + // Total length = 20. Estimated tokens = 20 / 4 = 5. + // 5 > 2 -> Compact. + EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 2); + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + createEvent(3, "Retain2")); // No usage metadata + Session session = Session.builder("id").events(events).build(); + Event summaryEvent = createEvent(4, "Summary"); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(summaryEvent)); + when(mockSessionService.appendEvent(any(), any())).thenReturn(Single.just(summaryEvent)); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer).summarizeEvents(any()); + verify(mockSessionService).appendEvent(eq(session), eq(summaryEvent)); + } + + @Test + public void compaction_skippedWhenTokenUsageBelowThreshold() { + // Threshold is 300, usage is 200. + EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 300); + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + withUsage(createEvent(3, "Retain2"), 200)); + Session session = Session.builder("id").events(events).build(); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer, never()).summarizeEvents(any()); + verify(mockSessionService, never()).appendEvent(any(), any()); + } + + @Test + public void compaction_happensWhenTokenUsageAboveThreshold() { + // Threshold is 300, usage is 400. + EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 300); + Event event3 = withUsage(createEvent(3, "Retain2"), 400); + ImmutableList events = + ImmutableList.of(createEvent(1, "Event1"), createEvent(2, "Retain1"), event3); + Session session = Session.builder("id").events(events).build(); + Event summaryEvent = createEvent(4, "Summary"); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(summaryEvent)); + when(mockSessionService.appendEvent(any(), any())).thenReturn(Single.just(summaryEvent)); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer).summarizeEvents(any()); + verify(mockSessionService).appendEvent(eq(session), eq(summaryEvent)); + } + + @Test + public void compact_notEnoughEvents_doesNothing() { + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), + createEvent(2, "Event2"), + withUsage(createEvent(3, "Event3"), 200)); + Session session = Session.builder("id").events(events).build(); + + // Retention size 5 > 3 events. Token usage 200 > threshold 100. + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 5, 100); + + compactor.compact(session, mockSessionService).test().assertComplete(); + + verify(mockSummarizer, never()).summarizeEvents(any()); + verify(mockSessionService, never()).appendEvent(any(), any()); + } + + @Test + public void compact_respectRetentionSize_summarizesCorrectEvents() { + // Retention size is 2. + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + withUsage(createEvent(3, "Retain2"), 200)); + Session session = Session.builder("id").events(events).build(); + Event compactedEvent = createCompactedEvent(1, 1, "Summary", 4); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); + when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); + + // Token usage 200 > threshold 100. + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100); + + compactor.compact(session, mockSessionService).test().assertComplete(); + + verify(mockSummarizer).summarizeEvents(eventListCaptor.capture()); + List summarizedEvents = eventListCaptor.getValue(); + assertThat(summarizedEvents).hasSize(1); + assertThat(getPromptText(summarizedEvents.get(0))).isEqualTo("Event1"); + + verify(mockSessionService).appendEvent(eq(session), eq(compactedEvent)); + } + + @Test + public void compact_withRetainedEventsPhysicallyBeforeCompaction_includesThem() { + // Simulating the user's specific case with retention size 1: + // "event1, event2, event3, compaction1-2 ... event3 is retained so it is before compaction + // event" + // + // Timeline: + // T=1: E1 + // T=2: E2 + // T=3: E3 + // T=4: C1 (Covers T=1 to T=2). + // + // Note: C1 was inserted *after* E3 in the list. + // List order: E1, E2, E3, C1. + // + // If we have more events: + // T=5: E5 + // T=6: E6 + // + // Retained: E6. + // Summary Input: C1, E3, E5. (E1, E2 covered by C1). + ImmutableList events = + ImmutableList.of( + createEvent(1, "E1"), + createEvent(2, "E2"), + createEvent(3, "E3"), + createCompactedEvent( + /* startTimestamp= */ 1, /* endTimestamp= */ 2, "C1", /* eventTimestamp= */ 4), + createEvent(5, "E5"), + withUsage(createEvent(6, "E6"), 200)); + Session session = Session.builder("id").events(events).build(); + Event compactedEvent = createCompactedEvent(1, 5, "Summary C1-E5", 7); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); + when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); + + // Token usage 200 > threshold 100. + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 1, 100); + + compactor.compact(session, mockSessionService).test().assertComplete(); + + verify(mockSummarizer).summarizeEvents(eventListCaptor.capture()); + List summarizedEvents = eventListCaptor.getValue(); + assertThat(summarizedEvents).hasSize(3); + + // Check first event is reconstructed C1 + Event reconstructedC1 = summarizedEvents.get(0); + assertThat(getPromptText(reconstructedC1)).isEqualTo("C1"); + // Verify timestamp is reset to startTimestamp (1) + assertThat(reconstructedC1.timestamp()).isEqualTo(1); + + // Check second event is E3 + Event e3 = summarizedEvents.get(1); + assertThat(getPromptText(e3)).isEqualTo("E3"); + assertThat(e3.timestamp()).isEqualTo(3); + + // Check third event is E5 + Event e5 = summarizedEvents.get(2); + assertThat(getPromptText(e5)).isEqualTo("E5"); + assertThat(e5.timestamp()).isEqualTo(5); + } + + @Test + public void compact_withMultipleCompactionEvents_respectsCompactionBoundary() { + // T=1: E1 + // T=2: E2, retained by C1 + // T=3: E3, retained by C1 + // T=4: E4, retained by C1 and C2 + // T=5: C1 (Covers T=1) + // T=6: E6, retained by C2 + // T=7: E7, retained by C2 + // T=8: C2 (Covers T=1 to T=3) since it covers C1 which starts at T=1. + // T=9: E9 + + // Retention = 3. + // Expected to summarize: C2, E4. (E1 covered by C1 - ignored, E2, E3 covered by C2). + // E6, E7, E9 are retained. + + ImmutableList events = + ImmutableList.of( + createEvent(1, "E1"), + createEvent(2, "E2"), + createEvent(3, "E3"), + createEvent(4, "E4"), + createCompactedEvent( + /* startTimestamp= */ 1, /* endTimestamp= */ 1, "C1", /* eventTimestamp= */ 5), + createEvent(6, "E6"), + createEvent(7, "E7"), + createCompactedEvent( + /* startTimestamp= */ 1, /* endTimestamp= */ 3, "C2", /* eventTimestamp= */ 8), + withUsage(createEvent(9, "E9"), 200)); + Session session = Session.builder("id").events(events).build(); + Event compactedEvent = createCompactedEvent(1, 4, "Summary C2-E4", 10); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(compactedEvent)); + when(mockSessionService.appendEvent(any(), any())).then(i -> Single.just(i.getArgument(1))); + + // Token usage 200 > threshold 100. + TailRetentionEventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 3, 100); + + compactor.compact(session, mockSessionService).test().assertComplete(); + + verify(mockSummarizer).summarizeEvents(eventListCaptor.capture()); + List summarizedEvents = eventListCaptor.getValue(); + + assertThat(summarizedEvents).hasSize(2); + + // Check first event is reconstructed C2 + Event reconstructedC2 = summarizedEvents.get(0); + assertThat(getPromptText(reconstructedC2)).isEqualTo("C2"); + // Verify timestamp is reset to startTimestamp (1), not event timestamp (8) or end timestamp (3) + assertThat(reconstructedC2.timestamp()).isEqualTo(1); + + // Check second event is E4 + Event e4 = summarizedEvents.get(1); + assertThat(e4.timestamp()).isEqualTo(4); + } + + private static Event createEvent(long timestamp, String text) { + return Event.builder() + .timestamp(timestamp) + .content(Content.builder().parts(Part.fromText(text)).build()) + .build(); + } + + private static String getPromptText(Event event) { + return event + .content() + .flatMap(Content::parts) + .flatMap(parts -> parts.stream().findFirst()) + .flatMap(Part::text) + .orElseThrow(); + } + + private Event withUsage(Event event, int tokens) { + return event.toBuilder() + .usageMetadata( + GenerateContentResponseUsageMetadata.builder().promptTokenCount(tokens).build()) + .build(); + } + + private Event createCompactedEvent( + long startTimestamp, long endTimestamp, String content, long eventTimestamp) { + return Event.builder() + .timestamp(eventTimestamp) + .actions( + EventActions.builder() + .compaction( + EventCompaction.builder() + .startTimestamp(startTimestamp) + .endTimestamp(endTimestamp) + .compactedContent( + Content.builder() + .role("model") + .parts(Part.builder().text(content).build()) + .build()) + .build()) + .build()) + .build(); + } +} diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 438522493..9439fe718 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -16,22 +16,47 @@ package com.google.adk.telemetry; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; - -import io.opentelemetry.api.GlobalOpenTelemetry; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LiveRequestQueue; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.RunConfig; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.runner.Runner; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FinishReason; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.Part; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; -import io.opentelemetry.sdk.OpenTelemetrySdk; -import io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter; -import io.opentelemetry.sdk.trace.SdkTracerProvider; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; -import io.opentelemetry.sdk.trace.export.SimpleSpanProcessor; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.List; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import java.util.Optional; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; /** * Tests for OpenTelemetry context propagation in ADK. @@ -39,31 +64,32 @@ *

Verifies that spans created by ADK properly link to parent contexts when available, enabling * proper distributed tracing across async boundaries. */ -class ContextPropagationTest { +@RunWith(JUnit4.class) +public class ContextPropagationTest { + @Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); - private InMemorySpanExporter spanExporter; private Tracer tracer; + private Tracer originalTracer; + private LlmAgent agent; + private InMemorySessionService sessionService; + + @Before + public void setup() { + this.originalTracer = Tracing.getTracer(); + Tracing.setTracerForTesting( + openTelemetryRule.getOpenTelemetry().getTracer("ContextPropagationTest")); + tracer = openTelemetryRule.getOpenTelemetry().getTracer("test"); + agent = LlmAgent.builder().name("test_agent").description("test-description").build(); + sessionService = new InMemorySessionService(); + } - @BeforeEach - void setup() { - // Reset GlobalOpenTelemetry state - GlobalOpenTelemetry.resetForTest(); - - spanExporter = InMemorySpanExporter.create(); - - SdkTracerProvider tracerProvider = - SdkTracerProvider.builder() - .addSpanProcessor(SimpleSpanProcessor.create(spanExporter)) - .build(); - - OpenTelemetrySdk sdk = - OpenTelemetrySdk.builder().setTracerProvider(tracerProvider).buildAndRegisterGlobal(); - - tracer = sdk.getTracer("test"); + @After + public void tearDown() { + Tracing.setTracerForTesting(originalTracer); } @Test - void testToolCallSpanLinksToParent() { + public void testToolCallSpanLinksToParent() { // Given: Parent span is active Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -82,58 +108,48 @@ void testToolCallSpanLinksToParent() { } // Then: tool_call should be child of parent - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans: parent and tool_call"); - - SpanData parentSpanData = - spans.stream() - .filter(s -> s.getName().equals("parent")) - .findFirst() - .orElseThrow(() -> new AssertionError("Parent span not found")); - - SpanData toolCallSpanData = - spans.stream() - .filter(s -> s.getName().equals("tool_call [testTool]")) - .findFirst() - .orElseThrow(() -> new AssertionError("Tool call span not found")); + SpanData parentSpanData = findSpanByName("parent"); + SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); // Verify parent-child relationship assertEquals( + "Tool call should have same trace ID as parent", parentSpanData.getSpanContext().getTraceId(), - toolCallSpanData.getSpanContext().getTraceId(), - "Tool call should have same trace ID as parent"); + toolCallSpanData.getSpanContext().getTraceId()); assertEquals( + "Tool call's parent should be the parent span", parentSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId(), - "Tool call's parent should be the parent span"); + toolCallSpanData.getParentSpanContext().getSpanId()); } @Test - void testToolCallWithoutParentCreatesRootSpan() { + public void testToolCallWithoutParentCreatesRootSpan() { // Given: No parent span active // When: ADK creates tool_call span with setParent(Context.current()) - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); + try (Scope s = Context.root().makeCurrent()) { + Span toolCallSpan = + tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - try (Scope scope = toolCallSpan.makeCurrent()) { - // Work - } finally { - toolCallSpan.end(); + try (Scope scope = toolCallSpan.makeCurrent()) { + // Work + } finally { + toolCallSpan.end(); + } } // Then: Should create root span (backward compatible) - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(1, spans.size(), "Should have exactly 1 span"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have exactly 1 span", 1, spans.size()); SpanData toolCallSpanData = spans.get(0); assertFalse( - toolCallSpanData.getParentSpanContext().isValid(), - "Tool call should be root span when no parent exists"); + "Tool call should be root span when no parent exists", + toolCallSpanData.getParentSpanContext().isValid()); } @Test - void testNestedSpanHierarchy() { + public void testNestedSpanHierarchy() { // Test: parent → invocation → tool_call → tool_response hierarchy Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -168,51 +184,46 @@ void testNestedSpanHierarchy() { } // Verify complete hierarchy - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(4, spans.size(), "Should have 4 spans in the hierarchy"); + List spans = openTelemetryRule.getSpans(); + // The 4 spans are: "parent", "invocation", "tool_call [testTool]", and "tool_response + // [testTool]". + assertEquals("Should have 4 spans in the hierarchy", 4, spans.size()); - String parentTraceId = - spans.stream() - .filter(s -> s.getName().equals("parent")) - .findFirst() - .map(s -> s.getSpanContext().getTraceId()) - .orElseThrow(() -> new AssertionError("Parent span not found")); + SpanData parentSpanData = findSpanByName("parent"); + String parentTraceId = parentSpanData.getSpanContext().getTraceId(); // All spans should have same trace ID - spans.forEach( - span -> - assertEquals( - parentTraceId, - span.getSpanContext().getTraceId(), - "All spans should be in same trace")); + for (SpanData span : openTelemetryRule.getSpans()) { + assertEquals( + "All spans should be in same trace", parentTraceId, span.getSpanContext().getTraceId()); + } // Verify parent-child relationships - SpanData parentSpanData = findSpanByName(spans, "parent"); - SpanData invocationSpanData = findSpanByName(spans, "invocation"); - SpanData toolCallSpanData = findSpanByName(spans, "tool_call [testTool]"); - SpanData toolResponseSpanData = findSpanByName(spans, "tool_response [testTool]"); + SpanData invocationSpanData = findSpanByName("invocation"); + SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); + SpanData toolResponseSpanData = findSpanByName("tool_response [testTool]"); // invocation should be child of parent assertEquals( + "Invocation should be child of parent", parentSpanData.getSpanContext().getSpanId(), - invocationSpanData.getParentSpanContext().getSpanId(), - "Invocation should be child of parent"); + invocationSpanData.getParentSpanContext().getSpanId()); // tool_call should be child of invocation assertEquals( + "Tool call should be child of invocation", invocationSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId(), - "Tool call should be child of invocation"); + toolCallSpanData.getParentSpanContext().getSpanId()); // tool_response should be child of tool_call assertEquals( + "Tool response should be child of tool call", toolCallSpanData.getSpanContext().getSpanId(), - toolResponseSpanData.getParentSpanContext().getSpanId(), - "Tool response should be child of tool call"); + toolResponseSpanData.getParentSpanContext().getSpanId()); } @Test - void testMultipleSpansInParallel() { + public void testMultipleSpansInParallel() { // Test: Multiple tool calls in parallel should all link to same parent Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -234,70 +245,66 @@ void testMultipleSpansInParallel() { } // Verify all tool calls link to same parent - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(4, spans.size(), "Should have 4 spans: 1 parent + 3 tool calls"); - - SpanData parentSpanData = findSpanByName(spans, "parent"); + SpanData parentSpanData = findSpanByName("parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); String parentSpanId = parentSpanData.getSpanContext().getSpanId(); // All tool calls should have same trace ID and parent span ID List toolCallSpans = - spans.stream().filter(s -> s.getName().startsWith("tool_call")).toList(); + openTelemetryRule.getSpans().stream() + .filter(s -> s.getName().startsWith("tool_call")) + .toList(); - assertEquals(3, toolCallSpans.size(), "Should have 3 tool call spans"); + assertEquals("Should have 3 tool call spans", 3, toolCallSpans.size()); toolCallSpans.forEach( span -> { assertEquals( + "Tool call should have same trace ID as parent", parentTraceId, - span.getSpanContext().getTraceId(), - "Tool call should have same trace ID as parent"); + span.getSpanContext().getTraceId()); assertEquals( + "Tool call should have parent as parent span", parentSpanId, - span.getParentSpanContext().getSpanId(), - "Tool call should have parent as parent span"); + span.getParentSpanContext().getSpanId()); }); } @Test - void testAgentRunSpanLinksToInvocation() { - // Test: agent_run span should link to invocation span + public void testInvokeAgentSpanLinksToInvocation() { + // Test: invoke_agent span should link to invocation span Span invocationSpan = tracer.spanBuilder("invocation").startSpan(); try (Scope invocationScope = invocationSpan.makeCurrent()) { - Span agentRunSpan = - tracer.spanBuilder("agent_run [test-agent]").setParent(Context.current()).startSpan(); + Span invokeAgentSpan = + tracer.spanBuilder("invoke_agent test-agent").setParent(Context.current()).startSpan(); - try (Scope agentScope = agentRunSpan.makeCurrent()) { + try (Scope agentScope = invokeAgentSpan.makeCurrent()) { // Simulate agent work } finally { - agentRunSpan.end(); + invokeAgentSpan.end(); } } finally { invocationSpan.end(); } - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans: invocation and agent_run"); - - SpanData invocationSpanData = findSpanByName(spans, "invocation"); - SpanData agentRunSpanData = findSpanByName(spans, "agent_run [test-agent]"); + SpanData invocationSpanData = findSpanByName("invocation"); + SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); assertEquals( + "Agent run should be child of invocation", invocationSpanData.getSpanContext().getSpanId(), - agentRunSpanData.getParentSpanContext().getSpanId(), - "Agent run should be child of invocation"); + invokeAgentSpanData.getParentSpanContext().getSpanId()); } @Test - void testCallLlmSpanLinksToAgentRun() { + public void testCallLlmSpanLinksToAgentRun() { // Test: call_llm span should link to agent_run span - Span agentRunSpan = tracer.spanBuilder("agent_run [test-agent]").startSpan(); + Span invokeAgentSpan = tracer.spanBuilder("invoke_agent test-agent").startSpan(); - try (Scope agentScope = agentRunSpan.makeCurrent()) { + try (Scope agentScope = invokeAgentSpan.makeCurrent()) { Span callLlmSpan = tracer.spanBuilder("call_llm").setParent(Context.current()).startSpan(); try (Scope llmScope = callLlmSpan.makeCurrent()) { @@ -306,49 +313,366 @@ void testCallLlmSpanLinksToAgentRun() { callLlmSpan.end(); } } finally { - agentRunSpan.end(); + invokeAgentSpan.end(); } - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans: agent_run and call_llm"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 2 spans", 2, spans.size()); - SpanData agentRunSpanData = findSpanByName(spans, "agent_run [test-agent]"); - SpanData callLlmSpanData = findSpanByName(spans, "call_llm"); + SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); + SpanData callLlmSpanData = findSpanByName("call_llm"); assertEquals( - agentRunSpanData.getSpanContext().getSpanId(), - callLlmSpanData.getParentSpanContext().getSpanId(), - "Call LLM should be child of agent run"); + "Call LLM should be child of agent run", + invokeAgentSpanData.getSpanContext().getSpanId(), + callLlmSpanData.getParentSpanContext().getSpanId()); } @Test - void testSpanCreatedWithinParentScopeIsCorrectlyParented() { + public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { // Test: Simulates creating a span within the scope of a parent Span parentSpan = tracer.spanBuilder("invocation").startSpan(); try (Scope scope = parentSpan.makeCurrent()) { - Span agentSpan = tracer.spanBuilder("agent_run").setParent(Context.current()).startSpan(); + Span agentSpan = tracer.spanBuilder("invoke_agent").setParent(Context.current()).startSpan(); agentSpan.end(); } finally { parentSpan.end(); } - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans"); + SpanData parentSpanData = findSpanByName("invocation"); + SpanData agentSpanData = findSpanByName("invoke_agent"); + + assertEquals( + "Agent span should be a child of the invocation span", + parentSpanData.getSpanContext().getSpanId(), + agentSpanData.getParentSpanContext().getSpanId()); + } + + @Test + public void testTraceFlowable() throws InterruptedException { + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + try (Scope s = parentSpan.makeCurrent()) { + Span flowableSpan = tracer.spanBuilder("flowable").setParent(Context.current()).startSpan(); + Flowable flowable = + Tracing.traceFlowable( + Context.current().with(flowableSpan), + flowableSpan, + () -> + Flowable.just(1, 2, 3) + .map( + i -> { + assertEquals( + flowableSpan.getSpanContext().getSpanId(), + Span.current().getSpanContext().getSpanId()); + return i * 2; + })); + flowable.test().await().assertComplete(); + } finally { + parentSpan.end(); + } + + SpanData parentSpanData = findSpanByName("parent"); + SpanData flowableSpanData = findSpanByName("flowable"); + assertEquals( + parentSpanData.getSpanContext().getSpanId(), + flowableSpanData.getParentSpanContext().getSpanId()); + assertTrue(flowableSpanData.hasEnded()); + } - SpanData parentSpanData = findSpanByName(spans, "invocation"); - SpanData agentSpanData = findSpanByName(spans, "agent_run"); + @Test + public void testTraceTransformer() throws InterruptedException { + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + try (Scope s = parentSpan.makeCurrent()) { + Flowable flowable = + Flowable.just(1, 2, 3) + .map( + i -> { + assertTrue(Span.current().getSpanContext().isValid()); + return i * 2; + }) + .compose(Tracing.trace("transformer")); + flowable.test().await().assertComplete(); + } finally { + parentSpan.end(); + } + SpanData parentSpanData = findSpanByName("parent"); + SpanData transformerSpanData = findSpanByName("transformer"); assertEquals( parentSpanData.getSpanContext().getSpanId(), - agentSpanData.getParentSpanContext().getSpanId(), - "Agent span should be a child of the invocation span"); + transformerSpanData.getParentSpanContext().getSpanId()); + assertTrue(transformerSpanData.hasEnded()); + } + + @Test + public void testTraceAgentInvocation() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Tracing.traceAgentInvocation( + span, "test-agent", "test-description", buildInvocationContext()); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("invoke_agent", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals("test-agent", attrs.get(AttributeKey.stringKey("gen_ai.agent.name"))); + assertEquals("test-description", attrs.get(AttributeKey.stringKey("gen_ai.agent.description"))); + assertEquals("test-session", attrs.get(AttributeKey.stringKey("gen_ai.conversation.id"))); + } + + @Test + public void testTraceToolCall() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Tracing.traceToolCall( + "tool-name", "tool-description", "tool-type", ImmutableMap.of("arg1", "value1")); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals("tool-name", attrs.get(AttributeKey.stringKey("gen_ai.tool.name"))); + assertEquals("tool-description", attrs.get(AttributeKey.stringKey("gen_ai.tool.description"))); + assertEquals("tool-type", attrs.get(AttributeKey.stringKey("gen_ai.tool.type"))); + assertEquals( + "{\"arg1\":\"value1\"}", + attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"))); + assertEquals("{}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_request"))); + assertEquals("{}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_response"))); + } + + @Test + public void testTraceToolResponse() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Event functionResponseEvent = + Event.builder() + .id("event-1") + .content( + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .name("tool-name") + .id("tool-call-id") + .response(ImmutableMap.of("result", "tool-result")) + .build()) + .build())) + .build(); + Tracing.traceToolResponse("event-1", functionResponseEvent); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); + assertEquals("tool-call-id", attrs.get(AttributeKey.stringKey("gen_ai.tool_call.id"))); + assertEquals( + "{\"result\":\"tool-result\"}", + attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_response"))); + } + + @Test + public void testTraceCallLlm() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-pro") + .contents(ImmutableList.of(Content.fromParts(Part.fromText("hello")))) + .config(GenerateContentConfig.builder().topP(0.9f).maxOutputTokens(100).build()) + .build(); + LlmResponse llmResponse = + LlmResponse.builder() + .content(Content.builder().parts(Part.fromText("world")).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build()) + .build(); + Tracing.traceCallLlm(buildInvocationContext(), "event-1", llmRequest, llmResponse); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); + assertEquals("gemini-pro", attrs.get(AttributeKey.stringKey("gen_ai.request.model"))); + assertEquals( + "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); + assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); + assertEquals("test-session", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.session_id"))); + assertEquals(0.9d, attrs.get(AttributeKey.doubleKey("gen_ai.request.top_p")), 0.01); + assertEquals(100L, (long) attrs.get(AttributeKey.longKey("gen_ai.request.max_tokens"))); + assertEquals(10L, (long) attrs.get(AttributeKey.longKey("gen_ai.usage.input_tokens"))); + assertEquals(20L, (long) attrs.get(AttributeKey.longKey("gen_ai.usage.output_tokens"))); + assertEquals( + ImmutableList.of("stop"), + attrs.get(AttributeKey.stringArrayKey("gen_ai.response.finish_reasons"))); + assertTrue( + attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_request")).contains("gemini-pro")); + assertTrue(attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_response")).contains("STOP")); + } + + @Test + public void testTraceSendData() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Tracing.traceSendData( + buildInvocationContext(), + "event-1", + ImmutableList.of(Content.builder().role("user").parts(Part.fromText("hello")).build())); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals( + "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); + assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); + assertEquals("test-session", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.session_id"))); + assertTrue(attrs.get(AttributeKey.stringKey("gcp.vertex.agent.data")).contains("hello")); + } + + // Agent that emits one event on a computation thread. + private static class TestAgent extends BaseAgent { + TestAgent() { + super("test-agent", "test-description", null, null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext context) { + return Flowable.just( + Event.builder().content(Content.fromParts(Part.fromText("test"))).build()) + .subscribeOn(Schedulers.computation()); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.just( + Event.builder().content(Content.fromParts(Part.fromText("test"))).build()) + .subscribeOn(Schedulers.computation()); + } + } + + @Test + public void baseAgentRunAsync_propagatesContext() throws InterruptedException { + BaseAgent agent = new TestAgent(); + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + try (Scope s = parentSpan.makeCurrent()) { + agent.runAsync(buildInvocationContext()).test().await().assertComplete(); + } finally { + parentSpan.end(); + } + SpanData parent = findSpanByName("parent"); + SpanData agentSpan = findSpanByName("invoke_agent test-agent"); + assertEquals(parent.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + } + + @Test + public void runnerRunAsync_propagatesContext() throws InterruptedException { + BaseAgent agent = new TestAgent(); + Runner runner = Runner.builder().agent(agent).appName("test-app").build(); + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + try (Scope s = parentSpan.makeCurrent()) { + Session session = + runner + .sessionService() + .createSession("test-app", "test-user", null, "test-session") + .blockingGet(); + Content newMessage = Content.fromParts(Part.fromText("hi")); + RunConfig runConfig = RunConfig.builder().build(); + runner + .runAsync(session.userId(), session.id(), newMessage, runConfig, null) + .test() + .await() + .assertComplete(); + } finally { + parentSpan.end(); + } + SpanData parent = findSpanByName("parent"); + SpanData invocation = findSpanByName("invocation"); + SpanData agentSpan = findSpanByName("invoke_agent test-agent"); + assertEquals( + parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); + assertEquals( + invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + } + + @Test + public void runnerRunLive_propagatesContext() throws InterruptedException { + BaseAgent agent = new TestAgent(); + Runner runner = Runner.builder().agent(agent).appName("test-app").build(); + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + try (Scope s = parentSpan.makeCurrent()) { + Session session = + Session.builder("test-session").userId("test-user").appName("test-app").build(); + Content newMessage = Content.fromParts(Part.fromText("hi")); + RunConfig runConfig = RunConfig.builder().build(); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + liveRequestQueue.content(newMessage); + liveRequestQueue.close(); + runner.runLive(session, liveRequestQueue, runConfig).test().await().assertComplete(); + } finally { + parentSpan.end(); + } + SpanData parent = findSpanByName("parent"); + SpanData invocation = findSpanByName("invocation"); + SpanData agentSpan = findSpanByName("invoke_agent test-agent"); + assertEquals( + parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); + assertEquals( + invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + } + + /** + * Finds a span by name, polling multiple times. + * + *

This is necessary because spans might be created in separate threads, and we cannot always + * rely on `.await()` to ensure all spans are available immediately. + */ + private SpanData findSpanByName(String name) { + for (int i = 0; i < 15; i++) { + Optional span = + openTelemetryRule.getSpans().stream().filter(s -> s.getName().equals(name)).findFirst(); + if (span.isPresent()) { + return span.get(); + } + try { + Thread.sleep(10 * i); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + throw new AssertionError("Span not found after polling: " + name); } - private SpanData findSpanByName(List spans, String name) { - return spans.stream() - .filter(s -> s.getName().equals(name)) - .findFirst() - .orElseThrow(() -> new AssertionError("Span not found: " + name)); + private InvocationContext buildInvocationContext() { + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + return InvocationContext.builder() + .sessionService(sessionService) + .session(session) + .agent(agent) + .invocationId("test-invocation-id") + .build(); } } diff --git a/core/src/test/java/com/google/adk/testing/TestBaseAgent.java b/core/src/test/java/com/google/adk/testing/TestBaseAgent.java index e3e5a632c..001993a59 100644 --- a/core/src/test/java/com/google/adk/testing/TestBaseAgent.java +++ b/core/src/test/java/com/google/adk/testing/TestBaseAgent.java @@ -26,7 +26,7 @@ import java.util.function.Supplier; /** A test agent that returns events from a supplier. */ -public final class TestBaseAgent extends BaseAgent { +public class TestBaseAgent extends BaseAgent { private final Supplier> eventSupplier; private int invocationCount = 0; private InvocationContext lastInvocationContext; diff --git a/core/src/test/java/com/google/adk/testing/TestCallback.java b/core/src/test/java/com/google/adk/testing/TestCallback.java index 04f83ed9b..6f35f5a3c 100644 --- a/core/src/test/java/com/google/adk/testing/TestCallback.java +++ b/core/src/test/java/com/google/adk/testing/TestCallback.java @@ -102,63 +102,87 @@ public Supplier> asRunAsyncImplSupplier(String contentText) { return asRunAsyncImplSupplier(Content.fromParts(Part.fromText(contentText))); } + /** + * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} + * with an event containing the given content. + */ + public Supplier> asRunLiveImplSupplier(Content content) { + return () -> + Flowable.defer( + () -> { + markAsCalled(); + return Flowable.just(Event.builder().content(content).build()); + }); + } + + /** + * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} + */ + public Supplier> asRunLiveImplSupplier(String contentText) { + return asRunLiveImplSupplier(Content.fromParts(Part.fromText(contentText))); + } + @SuppressWarnings("unchecked") // This cast is safe if T is Content. public BeforeAgentCallback asBeforeAgentCallback() { - return ctx -> (Maybe) callMaybe(); + return (unusedCtx) -> (Maybe) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is Content. public BeforeAgentCallbackSync asBeforeAgentCallbackSync() { - return ctx -> (Optional) callOptional(); + return (unusedCtx) -> (Optional) callOptional(); } @SuppressWarnings("unchecked") // This cast is safe if T is Content. public AfterAgentCallback asAfterAgentCallback() { - return ctx -> (Maybe) callMaybe(); + return (unusedCtx) -> (Maybe) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is Content. public AfterAgentCallbackSync asAfterAgentCallbackSync() { - return ctx -> (Optional) callOptional(); + return (unusedCtx) -> (Optional) callOptional(); } @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. public BeforeModelCallback asBeforeModelCallback() { - return (ctx, req) -> (Maybe) callMaybe(); + return (unusedCtx, unusedReq) -> (Maybe) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. public BeforeModelCallbackSync asBeforeModelCallbackSync() { - return (ctx, req) -> (Optional) callOptional(); + return (unusedCtx, unusedReq) -> (Optional) callOptional(); } @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. public AfterModelCallback asAfterModelCallback() { - return (ctx, res) -> (Maybe) callMaybe(); + return (unusedCtx, unusedRes) -> (Maybe) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. public AfterModelCallbackSync asAfterModelCallbackSync() { - return (ctx, res) -> (Optional) callOptional(); + return (unusedCtx, unusedRes) -> (Optional) callOptional(); } @SuppressWarnings("unchecked") // This cast is safe if T is Map. public BeforeToolCallback asBeforeToolCallback() { - return (invCtx, tool, toolArgs, toolCtx) -> (Maybe>) callMaybe(); + return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx) -> + (Maybe>) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is Map. public BeforeToolCallbackSync asBeforeToolCallbackSync() { - return (invCtx, tool, toolArgs, toolCtx) -> (Optional>) callOptional(); + return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx) -> + (Optional>) callOptional(); } @SuppressWarnings("unchecked") // This cast is safe if T is Map. public AfterToolCallback asAfterToolCallback() { - return (invCtx, tool, toolArgs, toolCtx, res) -> (Maybe>) callMaybe(); + return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx, unusedRes) -> + (Maybe>) callMaybe(); } @SuppressWarnings("unchecked") // This cast is safe if T is Map. public AfterToolCallbackSync asAfterToolCallbackSync() { - return (invCtx, tool, toolArgs, toolCtx, res) -> (Optional>) callOptional(); + return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx, unusedRes) -> + (Optional>) callOptional(); } } diff --git a/core/src/test/java/com/google/adk/testing/TestUtils.java b/core/src/test/java/com/google/adk/testing/TestUtils.java index 2bdcf1fbd..df94b76b2 100644 --- a/core/src/test/java/com/google/adk/testing/TestUtils.java +++ b/core/src/test/java/com/google/adk/testing/TestUtils.java @@ -30,7 +30,9 @@ import com.google.adk.events.EventCompaction; import com.google.adk.models.BaseLlm; import com.google.adk.models.LlmResponse; +import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; @@ -68,6 +70,19 @@ public static InvocationContext createInvocationContext(BaseAgent agent) { return createInvocationContext(agent, RunConfig.builder().build()); } + public static InvocationContext createInvocationContext( + BaseAgent agent, BaseSessionService sessionService, Session session) { + return InvocationContext.builder() + .sessionService(sessionService) + .artifactService(new InMemoryArtifactService()) + .invocationId("invocationId") + .agent(agent) + .session(session) + .userContent(Content.fromParts(Part.fromText("user content"))) + .runConfig(RunConfig.builder().build()) + .build(); + } + public static Event createEvent(String id) { return Event.builder() .id(id) diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index d43d9d03a..3a5390027 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -21,11 +21,14 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.SequentialAgent; import com.google.adk.models.LlmResponse; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; import com.google.adk.utils.ComponentRegistry; @@ -38,6 +41,7 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import java.util.Map; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -46,6 +50,13 @@ @RunWith(JUnit4.class) public final class AgentToolTest { + private InMemorySessionService sessionService; + + @Before + public void setUp() { + sessionService = new InMemorySessionService(); + } + @Test public void fromConfig_withRegisteredAgent_returnsAgentTool() throws Exception { LlmAgent testAgent = @@ -451,12 +462,217 @@ public void call_withStateDeltaInResponse_propagatesStateDelta() throws Exceptio assertThat(toolContext.state()).containsEntry("test_key", "test_value"); } - private static ToolContext createToolContext(LlmAgent agent) { + @Test + public void call_withSkipSummarizationAndStateDelta_propagatesStateAndSetsSkipSummarization() + throws Exception { + AfterAgentCallback afterAgentCallback = + (callbackContext) -> { + callbackContext.state().put("test_key", "test_value"); + return Maybe.empty(); + }; + TestLlm testLlm = + createTestLlm( + LlmResponse.builder() + .content(Content.fromParts(Part.fromText("test response"))) + .build()); + LlmAgent testAgent = + createTestAgentBuilder(testLlm) + .name("agent name") + .description("agent description") + .afterAgentCallback(afterAgentCallback) + .build(); + AgentTool agentTool = AgentTool.create(testAgent, /* skipSummarization= */ true); + ToolContext toolContext = createToolContext(testAgent); + + assertThat(toolContext.state()).doesNotContainKey("test_key"); + + Map unused = + agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet(); + + // Verify that stateDelta is propagated to the ToolContext's EventActions, otherwise + // Function.buildResponseEvent() will not include it in the response event. + assertThat(toolContext.actions().stateDelta()).containsEntry("test_key", "test_value"); + assertThat(toolContext.actions().skipSummarization()).hasValue(true); + } + + @Test + public void + declaration_sequentialAgentWithFirstSubAgentInputSchema_returnsDeclarationWithSchema() { + Schema inputSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "query", + Schema.builder().type("STRING").build(), + "language", + Schema.builder().type("STRING").build())) + .required(ImmutableList.of("query", "language")) + .build(); + LlmAgent firstAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("first_agent") + .inputSchema(inputSchema) + .build(); + LlmAgent secondAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("second_agent") + .build(); + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("sequence") + .description("Process the query through multiple steps") + .subAgents(ImmutableList.of(firstAgent, secondAgent)) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("sequence"); + assertThat(declaration.description().get()) + .isEqualTo("Process the query through multiple steps"); + assertThat(declaration.parameters().get()).isEqualTo(inputSchema); + } + + @Test + public void declaration_sequentialAgentWithoutInputSchema_fallsBackToRequest() { + LlmAgent firstAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("first_agent") + .build(); + LlmAgent secondAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("second_agent") + .build(); + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("sequence") + .description("Process the query through multiple steps") + .subAgents(ImmutableList.of(firstAgent, secondAgent)) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("sequence"); + assertThat(declaration.description().get()) + .isEqualTo("Process the query through multiple steps"); + assertThat(declaration.parameters().get()) + .isEqualTo( + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("request", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("request")) + .build()); + } + + @Test + public void call_sequentialAgentWithLastSubAgentOutputSchema_successful() throws Exception { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "is_valid", + Schema.builder().type("BOOLEAN").build(), + "message", + Schema.builder().type("STRING").build())) + .required(ImmutableList.of("is_valid", "message")) + .build(); + LlmAgent firstAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("first_agent") + .build(); + LlmAgent secondAgent = + createTestAgentBuilder( + createTestLlm( + LlmResponse.builder() + .content( + Content.fromParts( + Part.fromText( + "{\"is_valid\": true, " + "\"message\": \"success\"}"))) + .build())) + .name("second_agent") + .outputSchema(outputSchema) + .build(); + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("sequence") + .description("Process the query through multiple steps") + .subAgents(ImmutableList.of(firstAgent, secondAgent)) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + ToolContext toolContext = createToolContext(sequentialAgent); + + Map result = + agentTool.runAsync(ImmutableMap.of("request", "test"), toolContext).blockingGet(); + + assertThat(result).containsExactly("is_valid", true, "message", "success"); + } + + @Test + public void declaration_nestedSequentialAgentInputSchema_returnsDeclarationWithSchema() { + Schema inputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("deep_query", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("deep_query")) + .build(); + LlmAgent innerAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("inner_agent") + .inputSchema(inputSchema) + .build(); + SequentialAgent innerSequence = + SequentialAgent.builder() + .name("inner_sequence") + .subAgents(ImmutableList.of(innerAgent)) + .build(); + SequentialAgent outerSequence = + SequentialAgent.builder() + .name("outer_sequence") + .description("Nested sequence") + .subAgents(ImmutableList.of(innerSequence)) + .build(); + AgentTool agentTool = AgentTool.create(outerSequence); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("outer_sequence"); + assertThat(declaration.parameters().get()).isEqualTo(inputSchema); + } + + @Test + public void declaration_emptySequentialAgent_fallsBackToRequest() { + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("empty_sequence") + .description("An empty sequence") + .subAgents(ImmutableList.of()) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("empty_sequence"); + assertThat(declaration.parameters().get()) + .isEqualTo( + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("request", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("request")) + .build()); + } + + private ToolContext createToolContext(BaseAgent agent) { + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); return ToolContext.builder( InvocationContext.builder() .invocationId(InvocationContext.newInvocationContextId()) .agent(agent) - .session(Session.builder("123").build()) + .session(session) + .sessionService(sessionService) .build()) .build(); } diff --git a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java index e4d92f197..4e80ed0ff 100644 --- a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java +++ b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java @@ -41,7 +41,7 @@ public final class ExampleToolTest { /** Helper to create a minimal agent & context for testing. */ - private InvocationContext newContext() { + private InvocationContext buildInvocationContext() { TestLlm testLlm = new TestLlm(() -> Flowable.just(LlmResponse.builder().build())); LlmAgent agent = TestUtils.createTestAgent(testLlm); return TestUtils.createInvocationContext(agent); @@ -58,7 +58,7 @@ private static Example makeExample(String in, String out) { public void processLlmRequest_withInlineExamples_appendsFewShot() { ExampleTool tool = ExampleTool.builder().addExample(makeExample("qin", "qout")).build(); - InvocationContext ctx = newContext(); + InvocationContext ctx = buildInvocationContext(); LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait(); @@ -75,7 +75,7 @@ public void processLlmRequest_withInlineExamples_appendsFewShot() { public void processLlmRequest_withProvider_appendsFewShot() { ExampleTool tool = ExampleTool.builder().setExampleProvider(ProviderHolder.EXAMPLES).build(); - InvocationContext ctx = newContext(); + InvocationContext ctx = buildInvocationContext(); LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait(); @@ -91,12 +91,13 @@ public void processLlmRequest_withProvider_appendsFewShot() { @Test public void processLlmRequest_withEmptyUserContent_doesNotAppendFewShot() { ExampleTool tool = ExampleTool.builder().addExample(makeExample("qin", "qout")).build(); - InvocationContext ctxWithContent = newContext(); + InvocationContext ctxWithContent = buildInvocationContext(); InvocationContext ctx = InvocationContext.builder() .invocationId(ctxWithContent.invocationId()) .agent(ctxWithContent.agent()) .session(ctxWithContent.session()) + .sessionService(ctxWithContent.sessionService()) .userContent(Content.fromParts(Part.fromText(""))) .runConfig(ctxWithContent.runConfig()) .build(); @@ -120,7 +121,7 @@ public void fromConfig_withInlineExamples_buildsTool() throws Exception { "output", ImmutableList.of(Content.fromParts(Part.fromText("a")))))); ExampleTool tool = ExampleTool.fromConfig(args); - InvocationContext ctx = newContext(); + InvocationContext ctx = buildInvocationContext(); LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait(); @@ -144,7 +145,7 @@ public void fromConfig_withProviderReference_buildsTool() throws Exception { "examples", ExampleToolTest.ProviderHolder.class.getName() + ".EXAMPLES"); ExampleTool tool = ExampleTool.fromConfig(args); - InvocationContext ctx = newContext(); + InvocationContext ctx = buildInvocationContext(); LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait(); diff --git a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java index 5816c427a..0939c6506 100644 --- a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java +++ b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java @@ -20,7 +20,9 @@ import static org.junit.Assert.assertThrows; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; import com.google.adk.events.ToolConfirmation; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -34,6 +36,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -41,6 +44,25 @@ /** Unit tests for {@link FunctionTool}. */ @RunWith(JUnit4.class) public final class FunctionToolTest { + private LlmAgent agent; + private InMemorySessionService sessionService; + private ToolContext toolContext; + + @Before + public void setUp() { + agent = LlmAgent.builder().name("test-agent").build(); + sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .agent(agent) + .session(session) + .sessionService(sessionService) + .invocationId("invocation-id") + .build(); + toolContext = ToolContext.builder(invocationContext).functionCallId("functionCallId").build(); + } @Test public void create_withNonSerializableParameter_raisesIllegalArgumentException() { @@ -233,11 +255,6 @@ public void create_withParameterizedList() { @Test public void call_withAllSupportedParameterTypes() throws Exception { FunctionTool tool = FunctionTool.create(Functions.class, "returnAllSupportedParametersAsMap"); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync( @@ -579,11 +596,6 @@ public void call_nonStaticWithAllSupportedParameterTypes() throws Exception { Functions functions = new Functions(); FunctionTool tool = FunctionTool.create(functions, "nonStaticReturnAllSupportedParametersAsMap"); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync( @@ -630,11 +642,6 @@ public void runAsync_withRequireConfirmation() throws Exception { Method method = Functions.class.getMethod("returnsMap"); FunctionTool tool = new FunctionTool(null, method, /* isLongRunning= */ false, /* requireConfirmation= */ true); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); // First call, should request confirmation Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); @@ -663,11 +670,6 @@ public void create_instanceMethodWithConfirmation_requestsConfirmation() throws Functions functions = new Functions(); Method method = Functions.class.getMethod("nonStaticVoidReturnWithoutSchema"); FunctionTool tool = FunctionTool.create(functions, method, /* requireConfirmation= */ true); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); assertThat(result) @@ -680,11 +682,6 @@ public void create_instanceMethodWithConfirmation_requestsConfirmation() throws public void create_staticMethodWithConfirmation_requestsConfirmation() throws Exception { Method method = Functions.class.getMethod("voidReturnWithoutSchema"); FunctionTool tool = FunctionTool.create(method, /* requireConfirmation= */ true); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); assertThat(result) @@ -698,11 +695,6 @@ public void create_classMethodNameWithConfirmation_requestsConfirmation() throws FunctionTool tool = FunctionTool.create( Functions.class, "voidReturnWithoutSchema", /* requireConfirmation= */ true); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); assertThat(result) @@ -717,11 +709,6 @@ public void create_instanceMethodNameWithConfirmation_requestsConfirmation() thr FunctionTool tool = FunctionTool.create( functions, "nonStaticVoidReturnWithoutSchema", /* requireConfirmation= */ true); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); assertThat(result) diff --git a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java index c501d6a43..6f04a7ef8 100644 --- a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java +++ b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java @@ -6,7 +6,9 @@ import static org.mockito.Mockito.when; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.tools.ToolContext; import com.google.cloud.aiplatform.v1.RagContexts; @@ -25,6 +27,7 @@ import com.google.genai.types.VertexRagStore; import com.google.genai.types.VertexRagStoreRagResource; import java.util.Map; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -39,6 +42,15 @@ public final class VertexAiRagRetrievalTest { @Rule public final MockitoRule mockito = MockitoJUnit.rule(); @Mock private VertexRagServiceClient vertexRagServiceClient; + private InMemorySessionService sessionService; + private LlmAgent agent; + + @Before + public void setUp() { + sessionService = new InMemorySessionService(); + agent = LlmAgent.builder().name("test-agent").build(); + } + @Test public void runAsync_withResults_returnsContexts() throws Exception { ImmutableList ragResources = @@ -53,11 +65,7 @@ public void runAsync_withResults_returnsContexts() throws Exception { ragResources, vectorDistanceThreshold); String query = "test query"; - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); + ToolContext toolContext = buildToolContext(); RetrieveContextsRequest expectedRequest = RetrieveContextsRequest.newBuilder() .setParent("projects/test-project/locations/us-central1") @@ -97,11 +105,7 @@ public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception { ragResources, vectorDistanceThreshold); String query = "test query"; - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); + ToolContext toolContext = buildToolContext(); RetrieveContextsRequest expectedRequest = RetrieveContextsRequest.newBuilder() .setParent("projects/test-project/locations/us-central1") @@ -143,11 +147,7 @@ public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() { ragResources, vectorDistanceThreshold); LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-2-pro"); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); + ToolContext toolContext = buildToolContext(); tool.processLlmRequest(llmRequestBuilder, toolContext).blockingAwait(); @@ -209,11 +209,7 @@ public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() { ragResources, vectorDistanceThreshold); LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-1-pro"); - ToolContext toolContext = - ToolContext.builder( - InvocationContext.builder().session(Session.builder("123").build()).build()) - .functionCallId("functionCallId") - .build(); + ToolContext toolContext = buildToolContext(); GenerateContentConfig initialConfig = GenerateContentConfig.builder().build(); llmRequestBuilder.config(initialConfig); @@ -241,4 +237,17 @@ public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() { .build())) .build()); } + + private ToolContext buildToolContext() { + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + return ToolContext.builder( + InvocationContext.builder() + .invocationId(InvocationContext.newInvocationContextId()) + .agent(agent) + .session(session) + .sessionService(sessionService) + .build()) + .build(); + } } diff --git a/dev/INTENRAL_TODOS.md b/dev/INTENRAL_TODOS.md new file mode 100644 index 000000000..ec18f572c --- /dev/null +++ b/dev/INTENRAL_TODOS.md @@ -0,0 +1,22 @@ +# Dev TODOs + +This file contains TODOs for the dev ADK module based on +[Recommendations for making ADK Java more idiomatic](http://go/idiomatic-adk-java). + +## Dev UI + +- [ ] **Conditional UI**: Add a configuration property (e.g., + `adk.web.ui.enabled`) to conditionally enable/disable serving Dev UI static + assets (in `AdkWebServer`). +- [ ] **Integration Tests**: Add E2E tests (Selenium/Playwright/HtmlUnit) for + Dev UI to verify interaction between frontend assets and Spring Boot + backend. +- [ ] **Integration Tests**: Test critical paths like loading UI, WebSocket + connection, sending/receiving messages, and rich content handling (images). + +## Production Readiness + +- [ ] **Actuators**: Enable and configure Spring Boot Actuator endpoints for + monitoring and management. +- [ ] **Actuators**: Configure startup and readiness probes for production + environments. diff --git a/pom.xml b/pom.xml index 6a1aa5af5..01dbd4201 100644 --- a/pom.xml +++ b/pom.xml @@ -37,7 +37,6 @@ tutorials/city-time-weather tutorials/live-audio-single-agent a2a - a2a/webservice @@ -406,15 +405,6 @@ - - org.sonatype.central - central-publishing-maven-plugin - 0.8.0 - true - - central - - com.spotify.fmt fmt-maven-plugin @@ -472,6 +462,34 @@ + + release-sonatype + + + + org.sonatype.central + central-publishing-maven-plugin + 0.8.0 + true + + central + + + + + + + central + Maven Central Repository + https://central.sonatype.com/api/v1/publisher + + + central + Maven Central Repository Snapshots + https://central.sonatype.com/repository/maven-snapshots/ + + + release @@ -481,14 +499,7 @@ maven-gpg-plugin 3.2.7 - ${gpg.keyname} - ${gpg.passphrase} - - --batch - --yes - --pinentry-mode - loopback - + bc @@ -546,16 +557,4 @@ https://www.google.com - - - central - Maven Central Repository - https://central.sonatype.com/api/v1/publisher - - - central - Maven Central Repository Snapshots - https://central.sonatype.com/repository/maven-snapshots/ - - - \ No newline at end of file + diff --git a/release-please-config.json b/release-please-config.json new file mode 100644 index 000000000..6b3a1302c --- /dev/null +++ b/release-please-config.json @@ -0,0 +1,11 @@ +{ + "packages": { + ".": {} + }, + "include-component-in-tag": false, + "release-type": "maven", + "extra-files": [ + "core/src/main/java/com/google/adk/Version.java", + "README.md" + ] +} diff --git a/tutorials/live-audio-single-agent/src/main/java/com/google/adk/tutorials/LiveAudioSingleAgent.java b/tutorials/live-audio-single-agent/src/main/java/com/google/adk/tutorials/LiveAudioSingleAgent.java index 85e8914a7..a1342a936 100644 --- a/tutorials/live-audio-single-agent/src/main/java/com/google/adk/tutorials/LiveAudioSingleAgent.java +++ b/tutorials/live-audio-single-agent/src/main/java/com/google/adk/tutorials/LiveAudioSingleAgent.java @@ -30,13 +30,12 @@ public class LiveAudioSingleAgent { .model("gemini-2.0-flash-live-001") .description("A helpful weather assistant that provides weather information.") .instruction( - "You are a friendly weather assistant. When users ask about weather, " - + "you MUST call the getWeather tool with the location name. " - + "Extract the location from the user's question. " - + "ALWAYS use the getWeather tool to get accurate information - never make up weather data. " - + "After getting the tool result, provide a friendly and descriptive response. " - + "For general conversation or greetings, respond naturally and helpfully. " - + "Do NOT use code execution for anything.") + "You are a friendly weather assistant. When users ask about weather, you MUST call" + + " the getWeather tool with the location name. Extract the location from the" + + " user's question. ALWAYS use the getWeather tool to get accurate information -" + + " never make up weather data. After getting the tool result, provide a friendly" + + " and descriptive response. For general conversation or greetings, respond" + + " naturally and helpfully. Do NOT use code execution for anything.") .tools(FunctionTool.create(LiveAudioSingleAgent.class, "getWeather")) .build(); @@ -89,16 +88,17 @@ public static Map getWeather( String normalizedLocation = location.toLowerCase().trim(); - return weatherData.getOrDefault( + return weatherData.computeIfAbsent( normalizedLocation, - Map.of( - "status", - "error", - "report", - String.format( - "Weather information for '%s' is not available. Try New York, London, Tokyo, or" - + " Sydney.", - location))); + unused -> + Map.of( + "status", + "error", + "report", + String.format( + "Weather information for '%s' is not available. Try New York, London, Tokyo," + + " or Sydney.", + location))); } public static void main(String[] args) {