diff --git a/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java index 6df01694a..0fbeb0a72 100644 --- a/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java +++ b/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java @@ -2,8 +2,13 @@ import com.google.adk.a2a.converters.EventConverter; import com.google.adk.a2a.converters.PartConverter; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.RunConfig; +import com.google.adk.apps.App; +import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; +import com.google.adk.memory.BaseMemoryService; +import com.google.adk.plugins.Plugin; import com.google.adk.runner.Runner; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; @@ -21,6 +26,7 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.disposables.CompositeDisposable; import io.reactivex.rxjava3.disposables.Disposable; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; @@ -41,29 +47,98 @@ public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor private static final RunConfig DEFAULT_RUN_CONFIG = RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build(); - private final Runner runner; private final Map activeTasks = new ConcurrentHashMap<>(); + private final Runner.Builder runnerBuilder; + private final RunConfig runConfig; - private AgentExecutor(Runner runner) { - this.runner = runner; + private AgentExecutor( + App app, + BaseAgent agent, + String appName, + BaseArtifactService artifactService, + BaseSessionService sessionService, + BaseMemoryService memoryService, + List plugins, + RunConfig runConfig) { + this.runnerBuilder = + Runner.builder() + .agent(agent) + .appName(appName) + .artifactService(artifactService) + .sessionService(sessionService) + .memoryService(memoryService) + 
.plugins(plugins); + if (app != null) { + this.runnerBuilder.app(app); + } + // Check that the runner is configured correctly and can be built. + var unused = runnerBuilder.build(); + this.runConfig = runConfig == null ? DEFAULT_RUN_CONFIG : runConfig; } /** Builder for {@link AgentExecutor}. */ public static class Builder { - private Runner runner; + private App app; + private BaseAgent agent; + private String appName; + private BaseArtifactService artifactService; + private BaseSessionService sessionService; + private BaseMemoryService memoryService; + private List plugins = ImmutableList.of(); + private RunConfig runConfig; + + @CanIgnoreReturnValue + public Builder app(App app) { + this.app = app; + return this; + } + + @CanIgnoreReturnValue + public Builder agent(BaseAgent agent) { + this.agent = agent; + return this; + } + + @CanIgnoreReturnValue + public Builder appName(String appName) { + this.appName = appName; + return this; + } + + @CanIgnoreReturnValue + public Builder artifactService(BaseArtifactService artifactService) { + this.artifactService = artifactService; + return this; + } + + @CanIgnoreReturnValue + public Builder sessionService(BaseSessionService sessionService) { + this.sessionService = sessionService; + return this; + } @CanIgnoreReturnValue - public Builder runner(Runner runner) { - this.runner = runner; + public Builder memoryService(BaseMemoryService memoryService) { + this.memoryService = memoryService; + return this; + } + + @CanIgnoreReturnValue + public Builder plugins(List plugins) { + this.plugins = plugins; + return this; + } + + @CanIgnoreReturnValue + public Builder runConfig(RunConfig runConfig) { + this.runConfig = runConfig; return this; } @CanIgnoreReturnValue public AgentExecutor build() { - if (runner == null) { - throw new IllegalStateException("Runner must be provided."); - } - return new AgentExecutor(runner); + return new AgentExecutor( + app, agent, appName, artifactService, sessionService, memoryService, plugins, 
runConfig); } } @@ -96,13 +171,14 @@ public void execute(RequestContext ctx, EventQueue eventQueue) { EventProcessor p = new EventProcessor(); Content content = PartConverter.messageToContent(message); + Runner runner = runnerBuilder.build(); taskDisposables.add( - prepareSession(ctx, runner.sessionService()) + prepareSession(ctx, runner.appName(), runner.sessionService()) .flatMapPublisher( session -> { updater.startWork(); - return runner.runAsync(getUserId(ctx), session.id(), content, DEFAULT_RUN_CONFIG); + return runner.runAsync(getUserId(ctx), session.id(), content, runConfig); }) .subscribe( event -> { @@ -130,13 +206,14 @@ private String getUserId(RequestContext ctx) { return USER_ID_PREFIX + ctx.getContextId(); } - private Maybe prepareSession(RequestContext ctx, BaseSessionService service) { + private Maybe prepareSession( + RequestContext ctx, String appName, BaseSessionService service) { return service - .getSession(runner.appName(), getUserId(ctx), ctx.getContextId(), Optional.empty()) + .getSession(appName, getUserId(ctx), ctx.getContextId(), Optional.empty()) .switchIfEmpty( Maybe.defer( () -> { - return service.createSession(runner.appName(), getUserId(ctx)).toMaybe(); + return service.createSession(appName, getUserId(ctx)).toMaybe(); })); } diff --git a/a2a/src/test/java/com/google/adk/a2a/AgentExecutorTest.java b/a2a/src/test/java/com/google/adk/a2a/AgentExecutorTest.java new file mode 100644 index 000000000..44daf13d1 --- /dev/null +++ b/a2a/src/test/java/com/google/adk/a2a/AgentExecutorTest.java @@ -0,0 +1,82 @@ +package com.google.adk.a2a; + +import static org.junit.Assert.assertThrows; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.InvocationContext; +import com.google.adk.apps.App; +import com.google.adk.artifacts.InMemoryArtifactService; +import com.google.adk.events.Event; +import com.google.adk.sessions.InMemorySessionService; +import com.google.common.collect.ImmutableList; +import 
io.reactivex.rxjava3.core.Flowable; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class AgentExecutorTest { + + private TestAgent testAgent; + + @Before + public void setUp() { + testAgent = new TestAgent(); + } + + @Test + public void createAgentExecutor_noAgent_succeeds() { + var unused = + new AgentExecutor.Builder() + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + } + + @Test + public void createAgentExecutor_withAgentAndApp_throwsException() { + assertThrows( + IllegalStateException.class, + () -> { + new AgentExecutor.Builder() + .agent(testAgent) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + }); + } + + @Test + public void createAgentExecutor_withEmptyAgentAndApp_throwsException() { + assertThrows( + IllegalStateException.class, + () -> { + new AgentExecutor.Builder() + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + }); + } + + private static final class TestAgent extends BaseAgent { + private final Flowable eventsToEmit = Flowable.empty(); + + TestAgent() { + // BaseAgent constructor: name, description, examples, tools, model + super("test_agent", "test", ImmutableList.of(), null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return eventsToEmit; + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return eventsToEmit; + } + } +} diff --git a/core/src/main/java/com/google/adk/agents/CallbackContext.java b/core/src/main/java/com/google/adk/agents/CallbackContext.java index 49298451b..a29783769 100644 --- 
a/core/src/main/java/com/google/adk/agents/CallbackContext.java +++ b/core/src/main/java/com/google/adk/agents/CallbackContext.java @@ -92,14 +92,20 @@ public Single> listArtifacts() { .map(ListArtifactsResponse::filenames); } + /** Loads the latest version of an artifact from the service. */ + public Maybe loadArtifact(String filename) { + return loadArtifact(filename, Optional.empty()); + } + + /** Loads a specific version of an artifact from the service. */ + public Maybe loadArtifact(String filename, int version) { + return loadArtifact(filename, Optional.of(version)); + } + /** - * Loads an artifact from the artifact service associated with the current session. - * - * @param filename Artifact file name. - * @param version Artifact version (optional). - * @return loaded part, or empty if not found. - * @throws IllegalStateException if the artifact service is not initialized. + * @deprecated Use {@link #loadArtifact(String)} or {@link #loadArtifact(String, int)} instead. */ + @Deprecated public Maybe loadArtifact(String filename, Optional version) { if (invocationContext.artifactService() == null) { throw new IllegalStateException("Artifact service is not initialized."); diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 1893fb162..ee24cae4d 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -755,21 +755,36 @@ public Single> canonicalGlobalInstruction(ReadonlyCon throw new IllegalStateException("Unknown Instruction subtype: " + globalInstruction.getClass()); } + /** + * @deprecated Use {@link #canonicalTools(ReadonlyContext)} instead. + */ + @Deprecated + public Flowable canonicalTools(Optional context) { + return canonicalTools(context.orElse(null)); + } + /** * Constructs the list of tools for this agent based on the {@link #tools} field. * - *

This method is only for use by Agent Development Kit. + * @return The resolved tools as a {@link Flowable} of {@link BaseTool}. + */ + public Flowable canonicalTools() { + return canonicalTools((ReadonlyContext) null); + } + + /** + * Constructs the list of tools for this agent based on the {@link #tools} field. * * @param context The context to retrieve the session state. * @return The resolved list of tools as a {@link Single} wrapped list of {@link BaseTool}. */ - public Flowable canonicalTools(Optional context) { + public Flowable canonicalTools(@Nullable ReadonlyContext context) { List> toolFlowables = new ArrayList<>(); for (Object toolOrToolset : toolsUnion) { if (toolOrToolset instanceof BaseTool baseTool) { toolFlowables.add(Flowable.just(baseTool)); } else if (toolOrToolset instanceof BaseToolset baseToolset) { - toolFlowables.add(baseToolset.getTools(context.orElse(null))); + toolFlowables.add(baseToolset.getTools(context)); } else { throw new IllegalArgumentException( "Object in tools list is not of a supported type: " @@ -779,16 +794,6 @@ public Flowable canonicalTools(Optional context) { return Flowable.concat(toolFlowables); } - /** Overload of canonicalTools that defaults to an empty context. */ - public Flowable canonicalTools() { - return canonicalTools(Optional.empty()); - } - - /** Convenience overload of canonicalTools that accepts a non-optional ReadonlyContext. 
*/ - public Flowable canonicalTools(ReadonlyContext context) { - return canonicalTools(Optional.ofNullable(context)); - } - public Instruction instruction() { return instruction; } diff --git a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java index b6a3cee23..32ef9ff4d 100644 --- a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java @@ -55,22 +55,26 @@ Single saveArtifact( default Single saveAndReloadArtifact( String appName, String userId, String sessionId, String filename, Part artifact) { return saveArtifact(appName, userId, sessionId, filename, artifact) - .flatMap( - version -> - loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .toSingle()); + .flatMap(version -> loadArtifact(appName, userId, sessionId, filename, version).toSingle()); + } + + /** Loads the latest version of an artifact from the service. */ + default Maybe loadArtifact( + String appName, String userId, String sessionId, String filename) { + return loadArtifact(appName, userId, sessionId, filename, Optional.empty()); + } + + /** Loads a specific version of an artifact from the service. */ + default Maybe loadArtifact( + String appName, String userId, String sessionId, String filename, int version) { + return loadArtifact(appName, userId, sessionId, filename, Optional.of(version)); } /** - * Gets an artifact. - * - * @param appName the app name - * @param userId the user ID - * @param sessionId the session ID - * @param filename the filename - * @param version Optional version number. If null, loads the latest version. - * @return the artifact or empty if not found + * @deprecated Use {@link #loadArtifact(String, String, String, String)} or {@link + * #loadArtifact(String, String, String, String, int)} instead. 
*/ + @Deprecated Maybe loadArtifact( String appName, String userId, String sessionId, String filename, Optional version); diff --git a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java index b9bc49a02..e31d50327 100644 --- a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java @@ -28,12 +28,12 @@ import com.google.common.base.Splitter; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Streams; import com.google.genai.types.FileData; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; -import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -135,22 +135,25 @@ public Maybe loadArtifact( .flatMapMaybe( versions -> versions.isEmpty() ? 
Maybe.empty() : Maybe.just(max(versions)))) .flatMap( - versionToLoad -> { - String blobName = getBlobName(appName, userId, sessionId, filename, versionToLoad); - BlobId blobId = BlobId.of(bucketName, blobName); + versionToLoad -> + Maybe.fromCallable( + () -> { + String blobName = + getBlobName(appName, userId, sessionId, filename, versionToLoad); + BlobId blobId = BlobId.of(bucketName, blobName); - try { - Blob blob = storageClient.get(blobId); - if (blob == null || !blob.exists()) { - return Maybe.empty(); - } - byte[] data = blob.getContent(); - String mimeType = blob.getContentType(); - return Maybe.just(Part.fromBytes(data, mimeType)); - } catch (StorageException e) { - return Maybe.empty(); - } - }); + try { + Blob blob = storageClient.get(blobId); + if (blob == null || !blob.exists()) { + return null; + } + byte[] data = blob.getContent(); + String mimeType = blob.getContentType(); + return Part.fromBytes(data, mimeType); + } catch (StorageException e) { + return null; + } + })); } /** @@ -164,34 +167,38 @@ public Maybe loadArtifact( @Override public Single listArtifactKeys( String appName, String userId, String sessionId) { - Set filenames = new HashSet<>(); + return Single.fromCallable( + () -> { + Set filenames = new HashSet<>(); - // List session-specific files - String sessionPrefix = String.format("%s/%s/%s/", appName, userId, sessionId); - try { - for (Blob blob : - storageClient.list(bucketName, BlobListOption.prefix(sessionPrefix)).iterateAll()) { - List parts = Splitter.on('/').splitToList(blob.getName()); - filenames.add(parts.get(3)); // appName/userId/sessionId/filename/version - } - } catch (StorageException e) { - throw new VerifyException("Failed to list session artifacts from GCS", e); - } + // List session-specific files + String sessionPrefix = String.format("%s/%s/%s/", appName, userId, sessionId); + try { + for (Blob blob : + storageClient.list(bucketName, BlobListOption.prefix(sessionPrefix)).iterateAll()) { + List parts = 
Splitter.on('/').splitToList(blob.getName()); + filenames.add(parts.get(3)); // appName/userId/sessionId/filename/version + } + } catch (StorageException e) { + throw new VerifyException("Failed to list session artifacts from GCS", e); + } - // List user-namespace files - String userPrefix = String.format("%s/%s/user/", appName, userId); - try { - for (Blob blob : - storageClient.list(bucketName, BlobListOption.prefix(userPrefix)).iterateAll()) { - List parts = Splitter.on('/').splitToList(blob.getName()); - filenames.add(parts.get(3)); // appName/userId/user/filename/version - } - } catch (StorageException e) { - throw new VerifyException("Failed to list user artifacts from GCS", e); - } + // List user-namespace files + String userPrefix = String.format("%s/%s/user/", appName, userId); + try { + for (Blob blob : + storageClient.list(bucketName, BlobListOption.prefix(userPrefix)).iterateAll()) { + List parts = Splitter.on('/').splitToList(blob.getName()); + filenames.add(parts.get(3)); // appName/userId/user/filename/version + } + } catch (StorageException e) { + throw new VerifyException("Failed to list user artifacts from GCS", e); + } - return Single.just( - ListArtifactsResponse.builder().filenames(ImmutableList.sortedCopyOf(filenames)).build()); + return ListArtifactsResponse.builder() + .filenames(ImmutableList.sortedCopyOf(filenames)) + .build(); + }); } /** @@ -206,22 +213,30 @@ public Single listArtifactKeys( @Override public Completable deleteArtifact( String appName, String userId, String sessionId, String filename) { - ImmutableList versions = - listVersions(appName, userId, sessionId, filename).blockingGet(); - List blobIdsToDelete = new ArrayList<>(); - for (int version : versions) { - String blobName = getBlobName(appName, userId, sessionId, filename, version); - blobIdsToDelete.add(BlobId.of(bucketName, blobName)); - } + return listVersions(appName, userId, sessionId, filename) + .flatMapCompletable( + versions -> { + if (versions.isEmpty()) { + 
return Completable.complete(); + } + ImmutableList blobIdsToDelete = + versions.stream() + .map( + version -> + BlobId.of( + bucketName, + getBlobName(appName, userId, sessionId, filename, version))) + .collect(ImmutableList.toImmutableList()); - if (!blobIdsToDelete.isEmpty()) { - try { - var unused = storageClient.delete(blobIdsToDelete); - } catch (StorageException e) { - throw new VerifyException("Failed to delete artifact versions from GCS", e); - } - } - return Completable.complete(); + return Completable.fromAction( + () -> { + try { + var unused = storageClient.delete(blobIdsToDelete); + } catch (StorageException e) { + throw new VerifyException("Failed to delete artifact versions from GCS", e); + } + }); + }); } /** @@ -236,20 +251,29 @@ public Completable deleteArtifact( @Override public Single> listVersions( String appName, String userId, String sessionId, String filename) { - String prefix = getBlobPrefix(appName, userId, sessionId, filename); - List versions = new ArrayList<>(); - try { - for (Blob blob : storageClient.list(bucketName, BlobListOption.prefix(prefix)).iterateAll()) { - String name = blob.getName(); - int versionDelimiterIndex = name.lastIndexOf('/'); // immediately before the version number - if (versionDelimiterIndex != -1 && versionDelimiterIndex < name.length() - 1) { - versions.add(Integer.parseInt(name.substring(versionDelimiterIndex + 1))); - } - } - return Single.just(ImmutableList.sortedCopyOf(versions)); - } catch (StorageException e) { - return Single.just(ImmutableList.of()); - } + return Single.fromCallable( + () -> { + String prefix = getBlobPrefix(appName, userId, sessionId, filename); + try { + return Streams.stream( + storageClient.list(bucketName, BlobListOption.prefix(prefix)).iterateAll()) + .map(Blob::getName) + .map( + name -> { + int versionDelimiterIndex = name.lastIndexOf('/'); + return versionDelimiterIndex != -1 + && versionDelimiterIndex < name.length() - 1 + ? 
Optional.of(name.substring(versionDelimiterIndex + 1)) + : Optional.empty(); + }) + .flatMap(Optional::stream) + .map(Integer::parseInt) + .sorted() + .collect(ImmutableList.toImmutableList()); + } catch (StorageException e) { + return ImmutableList.of(); + } + }); } @Override @@ -291,35 +315,39 @@ private Single saveArtifactAndReturnBlob( String appName, String userId, String sessionId, String filename, Part artifact) { return listVersions(appName, userId, sessionId, filename) .map(versions -> versions.isEmpty() ? 0 : max(versions) + 1) - .map( - nextVersion -> { - if (artifact.inlineData().isEmpty()) { - throw new IllegalArgumentException("Saveable artifact must have inline data."); - } + .flatMap( + nextVersion -> + Single.fromCallable( + () -> { + if (artifact.inlineData().isEmpty()) { + throw new IllegalArgumentException( + "Saveable artifact must have inline data."); + } - String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion); - BlobId blobId = BlobId.of(bucketName, blobName); + String blobName = + getBlobName(appName, userId, sessionId, filename, nextVersion); + BlobId blobId = BlobId.of(bucketName, blobName); - BlobInfo blobInfo = - BlobInfo.newBuilder(blobId) - .setContentType(artifact.inlineData().get().mimeType().orElse(null)) - .build(); + BlobInfo blobInfo = + BlobInfo.newBuilder(blobId) + .setContentType(artifact.inlineData().get().mimeType().orElse(null)) + .build(); - try { - byte[] dataToSave = - artifact - .inlineData() - .get() - .data() - .orElseThrow( - () -> - new IllegalArgumentException( - "Saveable artifact data must be non-empty.")); - Blob blob = storageClient.create(blobInfo, dataToSave); - return SaveResult.create(blob, nextVersion); - } catch (StorageException e) { - throw new VerifyException("Failed to save artifact to GCS", e); - } - }); + try { + byte[] dataToSave = + artifact + .inlineData() + .get() + .data() + .orElseThrow( + () -> + new IllegalArgumentException( + "Saveable artifact data must be 
non-empty.")); + Blob blob = storageClient.create(blobInfo, dataToSave); + return SaveResult.create(blob, nextVersion); + } catch (StorageException e) { + throw new VerifyException("Failed to save artifact to GCS", e); + } + })); } } diff --git a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java index 5808f7083..8c8ec2af8 100644 --- a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java @@ -129,10 +129,7 @@ public Single> listVersions( public Single saveAndReloadArtifact( String appName, String userId, String sessionId, String filename, Part artifact) { return saveArtifact(appName, userId, sessionId, filename, artifact) - .flatMap( - version -> - loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .toSingle()); + .flatMap(version -> loadArtifact(appName, userId, sessionId, filename, version).toSingle()); } private Map> getArtifactsMap(String appName, String userId, String sessionId) { diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 6d8c698dd..bf25acfc7 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -22,6 +22,7 @@ import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.HashSet; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -383,7 +384,7 @@ public Builder compaction(EventCompaction value) { @CanIgnoreReturnValue public Builder merge(EventActions other) { other.skipSummarization().ifPresent(this::skipSummarization); - this.stateDelta.putAll(other.stateDelta()); + other.stateDelta().forEach((key, value) -> stateDelta.merge(key, value, Builder::deepMerge)); 
this.artifactDelta.putAll(other.artifactDelta()); this.deletedArtifactIds.addAll(other.deletedArtifactIds()); other.transferToAgent().ifPresent(this::transferToAgent); @@ -395,6 +396,34 @@ public Builder merge(EventActions other) { return this; } + private static Object deepMerge(Object target, Object source) { + if (!(target instanceof Map) || !(source instanceof Map)) { + // If one of them is not a map, the source value overwrites the target. + return source; + } + + Map targetMap = (Map) target; + Map sourceMap = (Map) source; + + if (!targetMap.isEmpty() && !sourceMap.isEmpty()) { + Object targetKey = targetMap.keySet().iterator().next(); + Object sourceKey = sourceMap.keySet().iterator().next(); + if (targetKey != null + && sourceKey != null + && !targetKey.getClass().equals(sourceKey.getClass())) { + throw new IllegalArgumentException( + String.format( + "Cannot merge maps with different key types: %s vs %s", + targetKey.getClass().getName(), sourceKey.getClass().getName())); + } + } + + // Create a new map to prevent UnsupportedOperationException from immutable maps + Map mergedMap = new ConcurrentHashMap<>(targetMap); + sourceMap.forEach((key, value) -> mergedMap.merge(key, value, Builder::deepMerge)); + return mergedMap; + } + public EventActions build() { return new EventActions(this); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 6ca49ee62..549652e86 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -611,7 +611,8 @@ private Flowable buildPostprocessingEvents( if (updatedResponse.content().isEmpty() && updatedResponse.errorCode().isEmpty() && !updatedResponse.interrupted().orElse(false) - && !updatedResponse.turnComplete().orElse(false)) { + && !updatedResponse.turnComplete().orElse(false) + && updatedResponse.usageMetadata().isEmpty()) { return 
processorEvents; } diff --git a/core/src/main/java/com/google/adk/models/GeminiUtil.java b/core/src/main/java/com/google/adk/models/GeminiUtil.java index 319226d69..0ef0a50fd 100644 --- a/core/src/main/java/com/google/adk/models/GeminiUtil.java +++ b/core/src/main/java/com/google/adk/models/GeminiUtil.java @@ -59,7 +59,9 @@ public static LlmRequest prepareGenenerateContentRequest( * Prepares an {@link LlmRequest} for the GenerateContent API. * *

This method can optionally sanitize the request and ensures that the last content part is - * from the user to prompt a model response. It also strips out any parts marked as "thoughts". + * from the user to prompt a model response. It also strips out any parts marked as "thoughts" and + * removes client-side function call IDs as some LLM APIs reject requests with client-side + * function call IDs. * * @param llmRequest The original {@link LlmRequest}. * @param sanitize Whether to sanitize the request to be compatible with the Gemini API backend. @@ -70,6 +72,7 @@ public static LlmRequest prepareGenenerateContentRequest( if (sanitize) { llmRequest = sanitizeRequestForGeminiApi(llmRequest); } + llmRequest = removeClientFunctionCallId(llmRequest); List contents = ensureModelResponse(llmRequest.contents()); if (stripThoughts) { contents = stripThoughts(contents); @@ -136,6 +139,50 @@ public static LlmRequest sanitizeRequestForGeminiApi(LlmRequest llmRequest) { return requestBuilder.contents(updatedContents).build(); } + /** + * Removes client-side function call IDs from the request. + * + *

Client-side function call IDs are internal to the ADK and should not be sent to the model. + * This method iterates through the contents and parts, removing the ID from any {@link + * com.google.genai.types.FunctionCall} or {@link com.google.genai.types.FunctionResponse} parts. + * + * @param llmRequest The request to process. + * @return A new {@link LlmRequest} with function call IDs removed. + */ + public static LlmRequest removeClientFunctionCallId(LlmRequest llmRequest) { + if (llmRequest.contents().isEmpty()) { + return llmRequest; + } + + ImmutableList updatedContents = + llmRequest.contents().stream() + .map( + content -> + content.toBuilder() + .parts( + content.parts().orElse(ImmutableList.of()).stream() + .map(GeminiUtil::removeClientFunctionCallIdFromPart) + .collect(toImmutableList())) + .build()) + .collect(toImmutableList()); + + return llmRequest.toBuilder().contents(updatedContents).build(); + } + + private static Part removeClientFunctionCallIdFromPart(Part part) { + if (part.functionCall().isPresent() && part.functionCall().get().id().isPresent()) { + return part.toBuilder() + .functionCall(part.functionCall().get().toBuilder().clearId().build()) + .build(); + } + if (part.functionResponse().isPresent() && part.functionResponse().get().id().isPresent()) { + return part.toBuilder() + .functionResponse(part.functionResponse().get().toBuilder().clearId().build()) + .build(); + } + return part; + } + /** * Ensures that the content is conducive to prompting a model response by ensuring the last * content part is from the user. @@ -213,7 +260,7 @@ public static boolean shouldEmitAccumulatedText(LlmResponse currentLlmResponse) } /** Removes any `Part` that contains only a `thought` from the content list. 
*/ - public static List stripThoughts(List originalContents) { + public static ImmutableList stripThoughts(List originalContents) { return originalContents.stream() .map( content -> { diff --git a/core/src/main/java/com/google/adk/tools/GoogleMapsTool.java b/core/src/main/java/com/google/adk/tools/GoogleMapsTool.java index 8689849c2..12ec27169 100644 --- a/core/src/main/java/com/google/adk/tools/GoogleMapsTool.java +++ b/core/src/main/java/com/google/adk/tools/GoogleMapsTool.java @@ -79,15 +79,8 @@ public Completable processLlmRequest( List existingTools = configBuilder.build().tools().orElse(ImmutableList.of()); ImmutableList.Builder updatedToolsBuilder = ImmutableList.builder(); updatedToolsBuilder.addAll(existingTools); - - String model = llmRequestBuilder.build().model().orElse(null); - if (model != null && !model.startsWith("gemini-1")) { - updatedToolsBuilder.add(Tool.builder().googleMaps(GoogleMaps.builder().build()).build()); - configBuilder.tools(updatedToolsBuilder.build()); - } else { - return Completable.error( - new IllegalArgumentException("Google Maps tool is not supported for model " + model)); - } + updatedToolsBuilder.add(Tool.builder().googleMaps(GoogleMaps.builder().build()).build()); + configBuilder.tools(updatedToolsBuilder.build()); llmRequestBuilder.config(configBuilder.build()); return Completable.complete(); diff --git a/core/src/main/java/com/google/adk/tools/UrlContextTool.java b/core/src/main/java/com/google/adk/tools/UrlContextTool.java index 5fe072d76..fe7f9c77e 100644 --- a/core/src/main/java/com/google/adk/tools/UrlContextTool.java +++ b/core/src/main/java/com/google/adk/tools/UrlContextTool.java @@ -25,8 +25,8 @@ import java.util.List; /** - * A built-in tool that is automatically invoked by Gemini 2 models to retrieve information from the - * given URLs. + * A built-in tool that is automatically invoked by Gemini 2 and 3 models to retrieve information + * from the given URLs. * *

This tool operates internally within the model and does not require or perform local code * execution. @@ -62,7 +62,7 @@ public Completable processLlmRequest( updatedToolsBuilder.addAll(existingTools); String model = llmRequestBuilder.build().model().get(); - if (model != null && model.startsWith("gemini-2")) { + if (model != null && (model.startsWith("gemini-2") || model.startsWith("gemini-3"))) { updatedToolsBuilder.add(Tool.builder().urlContext(UrlContext.builder().build()).build()); configBuilder.tools(updatedToolsBuilder.build()); } else { diff --git a/core/src/main/java/com/google/adk/tools/computeruse/BaseComputer.java b/core/src/main/java/com/google/adk/tools/computeruse/BaseComputer.java new file mode 100644 index 000000000..3ddb91963 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/computeruse/BaseComputer.java @@ -0,0 +1,99 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import com.google.adk.tools.Annotations.Schema; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Single; +import java.time.Duration; +import java.util.List; + +/** + * Defines an interface for computer environments. + * + *

This interface defines the standard methods for controlling computer environments, including + * web browsers and other interactive systems. + */ +public interface BaseComputer { + + /** Returns the screen size of the environment. */ + Single screenSize(); + + /** Opens the web browser. */ + Single openWebBrowser(); + + /** Clicks at a specific x, y coordinate on the webpage. */ + Single clickAt(@Schema(name = "x") int x, @Schema(name = "y") int y); + + /** Hovers at a specific x, y coordinate on the webpage. */ + Single hoverAt(@Schema(name = "x") int x, @Schema(name = "y") int y); + + /** Types text at a specific x, y coordinate. */ + Single typeTextAt( + @Schema(name = "x") int x, + @Schema(name = "y") int y, + @Schema(name = "text") String text, + @Schema(name = "press_enter", optional = true) Boolean pressEnter, + @Schema(name = "clear_before_typing", optional = true) Boolean clearBeforeTyping); + + /** Scrolls the entire webpage in a direction. */ + Single scrollDocument(@Schema(name = "direction") String direction); + + /** Scrolls at a specific x, y coordinate by magnitude. */ + Single scrollAt( + @Schema(name = "x") int x, + @Schema(name = "y") int y, + @Schema(name = "direction") String direction, + @Schema(name = "magnitude") int magnitude); + + /** Waits for specified duration. */ + Single wait(@Schema(name = "duration") Duration duration); + + /** Navigates back. */ + Single goBack(); + + /** Navigates forward. */ + Single goForward(); + + /** Jumps to search. */ + Single search(); + + /** Navigates to URL. */ + Single navigate(@Schema(name = "url") String url); + + /** Presses key combination. */ + Single keyCombination(@Schema(name = "keys") List keys); + + /** Drag and drop. */ + Single dragAndDrop( + @Schema(name = "x") int x, + @Schema(name = "y") int y, + @Schema(name = "destination_x") int destinationX, + @Schema(name = "destination_y") int destinationY); + + /** Returns current state. */ + Single currentState(); + + /** Initialize the computer. 
*/ + Completable initialize(); + + /** Cleanup resources. */ + Completable close(); + + /** Returns the environment. */ + Single environment(); +} diff --git a/core/src/main/java/com/google/adk/tools/computeruse/ComputerEnvironment.java b/core/src/main/java/com/google/adk/tools/computeruse/ComputerEnvironment.java new file mode 100644 index 000000000..2c897c794 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/computeruse/ComputerEnvironment.java @@ -0,0 +1,23 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +/** Enum for computer environments. */ +public enum ComputerEnvironment { + ENVIRONMENT_UNSPECIFIED, + ENVIRONMENT_BROWSER +} diff --git a/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java b/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java new file mode 100644 index 000000000..4f3be46c2 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java @@ -0,0 +1,108 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import java.util.Arrays; +import java.util.Objects; +import java.util.Optional; + +/** + * Represents the current state of the computer environment. + * + *

Attributes: screenshot: The screenshot in PNG format as bytes. url: The current URL of the + * webpage being displayed. + */ +public final class ComputerState { + private final byte[] screenshot; + private final Optional url; + + @JsonCreator + private ComputerState( + @JsonProperty("screenshot") byte[] screenshot, @JsonProperty("url") Optional url) { + this.screenshot = screenshot.clone(); + this.url = url; + } + + @JsonProperty("screenshot") + public byte[] screenshot() { + return screenshot.clone(); + } + + @JsonProperty("url") + public Optional url() { + return url; + } + + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link ComputerState}. */ + public static final class Builder { + private byte[] screenshot; + private Optional url = Optional.empty(); + + @CanIgnoreReturnValue + public Builder screenshot(byte[] screenshot) { + this.screenshot = screenshot.clone(); + return this; + } + + @CanIgnoreReturnValue + public Builder url(Optional url) { + this.url = url; + return this; + } + + @CanIgnoreReturnValue + public Builder url(String url) { + this.url = Optional.ofNullable(url); + return this; + } + + public ComputerState build() { + return new ComputerState(screenshot, url); + } + } + + public static ComputerState create(byte[] screenshot, String url) { + return builder().screenshot(screenshot).url(url).build(); + } + + public static ComputerState create(byte[] screenshot) { + return builder().screenshot(screenshot).build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ComputerState that)) { + return false; + } + return Objects.deepEquals(screenshot, that.screenshot) && Objects.equals(url, that.url); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(screenshot), url); + } +} diff --git a/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseTool.java 
b/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseTool.java new file mode 100644 index 000000000..cedf7f35c --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseTool.java @@ -0,0 +1,125 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import static java.lang.String.format; + +import com.google.adk.tools.FunctionTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableMap; +import io.reactivex.rxjava3.core.Single; +import java.lang.reflect.Method; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A tool that wraps computer control functions for use with LLMs. + * + *

This tool automatically normalizes coordinates from a virtual coordinate space (by default + * 1000x1000) to the actual screen size. + */ +public class ComputerUseTool extends FunctionTool { + + private static final Logger logger = LoggerFactory.getLogger(ComputerUseTool.class); + + private final int[] screenSize; + private final int[] coordinateSpace; + + public ComputerUseTool(Object instance, Method func, int[] screenSize, int[] virtualScreenSize) { + super(instance, func, /* isLongRunning= */ false); + this.screenSize = screenSize; + this.coordinateSpace = virtualScreenSize; + } + + private int normalize(Object object, String coordinateName, int index) { + if (!(object instanceof Number number)) { + throw new IllegalArgumentException(format("%s coordinate must be numeric", coordinateName)); + } + double coordinate = number.doubleValue(); + int normalized = (int) (coordinate / coordinateSpace[index] * screenSize[index]); + // Clamp to screen bounds + int clamped = Math.max(0, Math.min(normalized, screenSize[index] - 1)); + logger.atDebug().log( + format( + "%s: %.2f, normalized %s: %d, screen %s size: %d, coordinate-space %s size: %d, " + + "clamped %s: %d", + coordinateName, + coordinate, + coordinateName, + normalized, + coordinateName, + screenSize[index], + coordinateName, + coordinateSpace[index], + coordinateName, + clamped)); + return clamped; + } + + private int normalizeX(Object xObj) { + return normalize(xObj, "x", 0); + } + + private int normalizeY(Object yObj) { + return normalize(yObj, "y", 1); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + Map normalizedArgs = new HashMap<>(args); + + if (args.containsKey("x")) { + normalizedArgs.put("x", normalizeX(args.get("x"))); + } + if (args.containsKey("y")) { + normalizedArgs.put("y", normalizeY(args.get("y"))); + } + if (args.containsKey("destination_x")) { + normalizedArgs.put("destination_x", normalizeX(args.get("destination_x"))); + } + if 
(args.containsKey("destination_y")) { + normalizedArgs.put("destination_y", normalizeY(args.get("destination_y"))); + } + + return super.runAsync(normalizedArgs, toolContext) + .map( + result -> { + // If the underlying tool method returned a structure containing a "screenshot" field + // (e.g., a ComputerState object), FunctionTool.runAsync will have converted it to a + // Map. This post-processing step transforms the byte array "screenshot" field into + // an "image" map with a mimetype and Base64 encoded data, as expected by some + // consuming systems. + if (result.containsKey("screenshot") && result.get("screenshot") instanceof byte[]) { + byte[] screenshot = (byte[]) result.get("screenshot"); + ImmutableMap imageMap = + ImmutableMap.of( + "mimetype", + "image/png", + "data", + Base64.getEncoder().encodeToString(screenshot)); + Map finalResult = new HashMap<>(result); + finalResult.remove("screenshot"); + finalResult.put("image", imageMap); + return finalResult; + } + return result; + }); + } +} diff --git a/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseToolset.java b/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseToolset.java new file mode 100644 index 000000000..6984f02fd --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseToolset.java @@ -0,0 +1,181 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.tools.computeruse; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.adk.agents.ReadonlyContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.genai.types.ComputerUse; +import com.google.genai.types.Environment; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Tool; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A toolset that provides computer use capabilities. + * + *

It automatically discovers and wraps methods from a {@link BaseComputer} implementation. + */ +public class ComputerUseToolset implements BaseToolset { + + private static final Logger logger = LoggerFactory.getLogger(ComputerUseToolset.class); + + private static final ImmutableSet EXCLUDED_METHODS = + ImmutableSet.of( + "screenSize", + "environment", + "close", + "initialize", + "currentState", + "getClass", + "equals", + "hashCode", + "toString", + "wait", + "notify", + "notifyAll"); + + private final BaseComputer computer; + private final int[] virtualScreenSize; + private List tools; + private boolean initialized = false; + + public ComputerUseToolset(BaseComputer computer) { + this(computer, new int[] {1000, 1000}); + } + + public ComputerUseToolset(BaseComputer computer, int[] virtualScreenSize) { + this.computer = computer; + this.virtualScreenSize = virtualScreenSize; + } + + private synchronized Completable ensureInitialized() { + if (initialized) { + return Completable.complete(); + } + return computer + .initialize() + .doOnComplete( + () -> { + initialized = true; + }); + } + + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return ensureInitialized() + .andThen(computer.screenSize()) + .flatMapPublisher( + actualScreenSize -> { + if (tools == null) { + tools = new ArrayList<>(); + for (Method method : BaseComputer.class.getMethods()) { + if (!EXCLUDED_METHODS.contains(method.getName())) { + tools.add( + new ComputerUseTool(computer, method, actualScreenSize, virtualScreenSize)); + } + } + } + return Flowable.fromIterable(tools); + }); + } + + @Override + public void close() throws Exception { + computer.close().blockingAwait(); + } + + /** Adds computer use configuration to the LLM request. 
*/ + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return getTools(null) // Fetch tools to ensure they are added to the list + .toList() + .flatMapCompletable( + tools -> { + return Completable.concat( + tools.stream() + .map(t -> t.processLlmRequest(llmRequestBuilder, toolContext)) + .collect(toImmutableList())) + .andThen( + computer + .environment() + .flatMapCompletable( + env -> { + configureComputerUseIfNeeded(llmRequestBuilder, env); + return Completable.complete(); + })); + }); + } + + /** + * Returns the {@link Environment.Known} enum for the given {@link ComputerEnvironment}. If the + * computer environment is not found or not supported, defaults to {@link + * Environment.Known.ENVIRONMENT_BROWSER}. + * + * @param computerEnvironment The {@link ComputerEnvironment} to convert. + * @return The corresponding {@link Environment.Known} enum. + */ + private static Environment.Known getEnvironment(ComputerEnvironment computerEnvironment) { + try { + return Environment.Known.valueOf(computerEnvironment.name()); + } catch (IllegalArgumentException e) { + return Environment.Known.ENVIRONMENT_BROWSER; + } + } + + /** + * Configures the computer use tool in the LLM request if it is not already configured. + * + * @param computerEnvironment The environment to configure the computer use tool for. + * @param llmRequestBuilder The LLM request builder to add the computer use tool to. 
+ */ + private static void configureComputerUseIfNeeded( + LlmRequest.Builder llmRequestBuilder, ComputerEnvironment computerEnvironment) { + // Get the current config from the LLM request + GenerateContentConfig config = + llmRequestBuilder.config().orElse(GenerateContentConfig.builder().build()); + + // Check if computer use is already configured + if (config.tools().orElse(ImmutableList.of()).stream() + .anyMatch(t -> t.computerUse().isPresent())) { + logger.debug("Computer use already configured"); + return; + } + + // Configure the computer + Environment.Known knownEnv = getEnvironment(computerEnvironment); + Tool computerUseTool = + Tool.builder().computerUse(ComputerUse.builder().environment(knownEnv).build()).build(); + // Add the computer use tool to the list of tools in the config + List currentTools = new ArrayList<>(config.tools().orElse(ImmutableList.of())); + currentTools.add(computerUseTool); + llmRequestBuilder.config(config.toBuilder().tools(ImmutableList.copyOf(currentTools)).build()); + logger.debug("Added computer use tool with environment: {}", knownEnv); + } +} diff --git a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java index 40493bf3a..88abd60c4 100644 --- a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java @@ -16,6 +16,7 @@ package com.google.adk.artifacts; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -28,6 +29,8 @@ import com.google.cloud.storage.BlobInfo; import com.google.cloud.storage.Storage; import com.google.cloud.storage.Storage.BlobListOption; +import com.google.cloud.storage.StorageException; +import com.google.common.base.VerifyException; 
import com.google.common.collect.ImmutableList; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Maybe; @@ -41,6 +44,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -233,8 +237,10 @@ public void list_noFiles_returnsEmpty() { String sessionPrefix = String.format("%s/%s/%s/", APP_NAME, USER_ID, SESSION_ID); String userPrefix = String.format("%s/%s/user/", APP_NAME, USER_ID); + // Mocking generic Page class requires unchecked suppression. @SuppressWarnings("unchecked") Page mockSessionPage = mock(Page.class); + // Mocking generic Page class requires unchecked suppression. @SuppressWarnings("unchecked") Page mockUserPage = mock(Page.class); when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(sessionPrefix))) @@ -262,8 +268,10 @@ public void list_withFiles_returnsCorrectFilenames() { Blob blobS2V0 = mockBlob(sessionPrefix + sessionFile2 + "/0", "text/log", new byte[0]); Blob blobU1V0 = mockBlob(userPrefix + userFile1 + "/0", "app/json", new byte[0]); + // Mocking generic Page class requires unchecked suppression. @SuppressWarnings("unchecked") Page mockSessionPage = mock(Page.class); + // Mocking generic Page class requires unchecked suppression. 
@SuppressWarnings("unchecked") Page mockUserPage = mock(Page.class); when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(sessionPrefix))) @@ -363,6 +371,143 @@ public void saveAndReloadArtifact_savesAndReturnsFileData() { verify(mockStorage).create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3})); } + @Test + public void save_noInlineData_throwsException() { + Part artifact = Part.builder().build(); // No inline data + assertThrows( + IllegalArgumentException.class, + () -> + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet()); + } + + @Test + public void save_storageException_throwsVerifyException() { + Part artifact = Part.fromBytes(new byte[] {1}, "text/plain"); + when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + when(mockStorage.create(any(BlobInfo.class), any(byte[].class))) + .thenThrow(new StorageException(500, "Induced error")); + + assertThrows( + VerifyException.class, + () -> + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet()); + } + + @Test + public void load_storageException_returnsEmpty() { + String blobNameV0 = String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + BlobId blobIdV0 = BlobId.of(BUCKET_NAME, blobNameV0); + when(mockStorage.get(blobIdV0)).thenThrow(new StorageException(500, "Induced error")); + + Optional loadedArtifact = + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.of(0))); + + assertThat(loadedArtifact).isEmpty(); + } + + @Test + public void list_sessionStorageException_throwsVerifyException() { + String sessionPrefix = String.format("%s/%s/%s/", APP_NAME, USER_ID, SESSION_ID); + when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(sessionPrefix))) + .thenThrow(new StorageException(500, "Induced error")); + + assertThrows( + VerifyException.class, + () -> service.listArtifactKeys(APP_NAME, USER_ID, SESSION_ID).blockingGet()); + } + + @Test + public void 
list_userStorageException_throwsVerifyException() { + String sessionPrefix = String.format("%s/%s/%s/", APP_NAME, USER_ID, SESSION_ID); + String userPrefix = String.format("%s/%s/user/", APP_NAME, USER_ID); + + // Mocking generic Page class requires unchecked suppression. + @SuppressWarnings("unchecked") + Page mockSessionPage = mock(Page.class); + when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(sessionPrefix))) + .thenReturn(mockSessionPage); + when(mockSessionPage.iterateAll()).thenReturn(ImmutableList.of()); + + when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(userPrefix))) + .thenThrow(new StorageException(500, "Induced error")); + + assertThrows( + VerifyException.class, + () -> service.listArtifactKeys(APP_NAME, USER_ID, SESSION_ID).blockingGet()); + } + + @Test + public void delete_storageException_throwsVerifyException() { + String blobNameV0 = String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + Blob blobV0 = mockBlob(blobNameV0, "text/plain", new byte[] {1}); + + when(mockBlobPage.iterateAll()).thenReturn(Collections.singletonList(blobV0)); + when(mockStorage.delete(ArgumentMatchers.>any())) + .thenThrow(new StorageException(500, "Induced error")); + + assertThrows( + VerifyException.class, + () -> service.deleteArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME).blockingAwait()); + } + + @Test + public void listVersions_storageException_returnsEmptyList() { + String prefix = String.format("%s/%s/%s/%s/", APP_NAME, USER_ID, SESSION_ID, FILENAME); + when(mockStorage.list(BUCKET_NAME, BlobListOption.prefix(prefix))) + .thenThrow(new StorageException(500, "Induced error")); + + ImmutableList versions = + service.listVersions(APP_NAME, USER_ID, SESSION_ID, FILENAME).blockingGet(); + + assertThat(versions).isEmpty(); + } + + @Test + public void saveAndReload_noContentTypeAnywhere_defaultsToOctetStream() { + // Artifact with no mime type + Part artifact = + Part.builder() + 
.inlineData(com.google.genai.types.Blob.builder().data(new byte[] {1}).build()) + .build(); + String expectedBlobName = + String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + + when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mock(Blob.class); + when(savedBlob.getName()).thenReturn(expectedBlobName); + when(savedBlob.getBucket()).thenReturn(BUCKET_NAME); + when(savedBlob.getContentType()).thenReturn(null); + when(mockStorage.create(any(BlobInfo.class), any(byte[].class))).thenReturn(savedBlob); + + Part result = + service + .saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact) + .blockingGet(); + + assertThat(result.fileData().get().mimeType()).hasValue("application/octet-stream"); + } + + @Test + public void saveAndReload_blobMissingContentType_usesArtifactContentType() { + Part artifact = Part.fromBytes(new byte[] {1}, "application/pdf"); + String expectedBlobName = + String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + + when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mock(Blob.class); + when(savedBlob.getName()).thenReturn(expectedBlobName); + when(savedBlob.getBucket()).thenReturn(BUCKET_NAME); + when(savedBlob.getContentType()).thenReturn(null); + when(mockStorage.create(any(BlobInfo.class), any(byte[].class))).thenReturn(savedBlob); + + Part result = + service + .saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact) + .blockingGet(); + + assertThat(result.fileData().get().mimeType()).hasValue("application/pdf"); + } + private static Optional asOptional(Maybe maybe) { return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet(); } diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 94cd399df..28123bab8 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ 
b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -17,12 +17,14 @@ package com.google.adk.events; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.adk.sessions.State; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.Part; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.junit.Test; import org.junit.runner.RunWith; @@ -130,4 +132,37 @@ public void jsonSerialization_works() throws Exception { assertThat(deserialized).isEqualTo(eventActions); assertThat(deserialized.deletedArtifactIds()).containsExactly("d1", "d2"); } + + @Test + @SuppressWarnings("unchecked") // the nested map is known to be Map + public void merge_deeplyMergesStateDelta() { + EventActions eventActions1 = EventActions.builder().build(); + eventActions1.stateDelta().put("a", 1); + eventActions1.stateDelta().put("b", ImmutableMap.of("nested1", 10, "nested2", 20)); + eventActions1.stateDelta().put("c", 100); + EventActions eventActions2 = EventActions.builder().build(); + eventActions2.stateDelta().put("a", 2); + eventActions2.stateDelta().put("b", ImmutableMap.of("nested2", 22, "nested3", 30)); + eventActions2.stateDelta().put("d", 200); + + EventActions merged = eventActions1.toBuilder().merge(eventActions2).build(); + + assertThat(merged.stateDelta().keySet()).containsExactly("a", "b", "c", "d"); + assertThat(merged.stateDelta()).containsEntry("a", 2); + assertThat((Map) merged.stateDelta().get("b")) + .containsExactly("nested1", 10, "nested2", 22, "nested3", 30); + assertThat(merged.stateDelta()).containsEntry("c", 100); + assertThat(merged.stateDelta()).containsEntry("d", 200); + } + + @Test + public void merge_failsOnMismatchedKeyTypesNestedInStateDelta() { + EventActions eventActions1 = EventActions.builder().build(); + eventActions1.stateDelta().put("nested", 
ImmutableMap.of("a", 1)); + EventActions eventActions2 = EventActions.builder().build(); + eventActions2.stateDelta().put("nested", ImmutableMap.of(1, 2)); + + assertThrows( + IllegalArgumentException.class, () -> eventActions1.toBuilder().merge(eventActions2)); + } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 657d1c670..ff151a0b2 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -17,6 +17,7 @@ package com.google.adk.flows.llmflows; import static com.google.adk.testing.TestUtils.assertEqualIgnoringFunctionIds; +import static com.google.adk.testing.TestUtils.createGenerateContentResponseUsageMetadata; import static com.google.adk.testing.TestUtils.createInvocationContext; import static com.google.adk.testing.TestUtils.createLlmResponse; import static com.google.adk.testing.TestUtils.createTestAgent; @@ -575,4 +576,32 @@ public Single> runAsync(Map args, ToolContex return Single.just(response); } } + + @Test + public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { + GenerateContentResponseUsageMetadata usageMetadata = + createGenerateContentResponseUsageMetadata().build(); + LlmResponse llmResponse = LlmResponse.builder().usageMetadata(usageMetadata).build(); + InvocationContext invocationContext = + createInvocationContext(createTestAgent(createTestLlm(llmResponse))); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + Event baseEvent = + Event.builder() + .invocationId(invocationContext.invocationId()) + .author(invocationContext.agent().name()) + .build(); + + List events = + baseLlmFlow + .postprocess(invocationContext, baseEvent, LlmRequest.builder().build(), llmResponse) + .toList() + .blockingGet(); + + assertThat(events).hasSize(1); + Event event = getOnlyElement(events); + 
assertThat(event.content()).isEmpty(); + assertThat(event.usageMetadata()).hasValue(usageMetadata); + assertThat(event.author()).isEqualTo(invocationContext.agent().name()); + assertThat(event.invocationId()).isEqualTo(invocationContext.invocationId()); + } } diff --git a/core/src/test/java/com/google/adk/models/GeminiUtilTest.java b/core/src/test/java/com/google/adk/models/GeminiUtilTest.java index 49e73511d..31cbe76de 100644 --- a/core/src/test/java/com/google/adk/models/GeminiUtilTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiUtilTest.java @@ -24,6 +24,8 @@ import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FileData; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.Part; import java.util.Arrays; @@ -451,6 +453,43 @@ public void prepareGenenerateContentRequest_emptyRequest_returnsRequestWithConti .inOrder(); } + @Test + public void removeClientFunctionCallId_stripsIds() { + Part partWithFunctionCall = + Part.builder() + .functionCall( + FunctionCall.builder() + .name("foo") + .id("id1") + .args(ImmutableMap.of("key", "value")) + .build()) + .build(); + Part partWithFunctionResponse = + Part.builder() + .functionResponse( + FunctionResponse.builder() + .name("bar") + .id("id2") + .response(ImmutableMap.of("key", "value")) + .build()) + .build(); + LlmRequest request = toRequest(partWithFunctionCall, partWithFunctionResponse); + + LlmRequest result = GeminiUtil.removeClientFunctionCallId(request); + + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0).parts()).isPresent(); + assertThat(result.contents().get(0).parts().get()).hasSize(2); + Part resultPart1 = result.contents().get(0).parts().get().get(0); + assertThat(resultPart1.functionCall()).isPresent(); + assertThat(resultPart1.functionCall().get().id()).isEmpty(); + 
assertThat(resultPart1.functionCall().get().name()).hasValue("foo"); + Part resultPart2 = result.contents().get(0).parts().get().get(1); + assertThat(resultPart2.functionResponse()).isPresent(); + assertThat(resultPart2.functionResponse().get().id()).isEmpty(); + assertThat(resultPart2.functionResponse().get().name()).hasValue("bar"); + } + private static Content toContent(Part... parts) { return Content.builder().parts(ImmutableList.copyOf(parts)).build(); } diff --git a/core/src/test/java/com/google/adk/testing/TestUtils.java b/core/src/test/java/com/google/adk/testing/TestUtils.java index df94b76b2..70ae14bf1 100644 --- a/core/src/test/java/com/google/adk/testing/TestUtils.java +++ b/core/src/test/java/com/google/adk/testing/TestUtils.java @@ -41,6 +41,7 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Single; @@ -253,6 +254,13 @@ public static LlmResponse createFunctionCallLlmResponse( return createLlmResponse(content); } + public static GenerateContentResponseUsageMetadata.Builder + createGenerateContentResponseUsageMetadata() { + return GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20); + } + public static class EchoTool extends BaseTool { public EchoTool() { super("echo_tool", "description"); diff --git a/core/src/test/java/com/google/adk/tools/BaseToolTest.java b/core/src/test/java/com/google/adk/tools/BaseToolTest.java index dde1d73ea..16418657d 100644 --- a/core/src/test/java/com/google/adk/tools/BaseToolTest.java +++ b/core/src/test/java/com/google/adk/tools/BaseToolTest.java @@ -6,6 +6,7 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.FunctionDeclaration; import 
com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GoogleMaps; import com.google.genai.types.GoogleSearch; import com.google.genai.types.GoogleSearchRetrieval; import com.google.genai.types.Tool; @@ -207,4 +208,22 @@ public void processLlmRequestWithBuiltInCodeExecutionToolAddsToolToConfig() { assertThat(updatedLlmRequest.config().get().tools().get()) .containsExactly(Tool.builder().codeExecution(ToolCodeExecution.builder().build()).build()); } + + @Test + public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() { + GoogleMapsTool googleMapsTool = new GoogleMapsTool(); + LlmRequest llmRequest = + LlmRequest.builder() + .config(GenerateContentConfig.builder().build()) + .model("gemini-2") + .build(); + LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); + Completable unused = + googleMapsTool.processLlmRequest(llmRequestBuilder, /* toolContext= */ null); + LlmRequest updatedLlmRequest = llmRequestBuilder.build(); + assertThat(updatedLlmRequest.config()).isPresent(); + assertThat(updatedLlmRequest.config().get().tools()).isPresent(); + assertThat(updatedLlmRequest.config().get().tools().get()) + .containsExactly(Tool.builder().googleMaps(GoogleMaps.builder().build()).build()); + } } diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerEnvironmentTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerEnvironmentTest.java new file mode 100644 index 000000000..ed22819ec --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerEnvironmentTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ComputerEnvironment}. */ +@RunWith(JUnit4.class) +public final class ComputerEnvironmentTest { + + @Test + public void testEnumValues() { + assertThat(ComputerEnvironment.values()) + .asList() + .containsAtLeast( + ComputerEnvironment.ENVIRONMENT_UNSPECIFIED, ComputerEnvironment.ENVIRONMENT_BROWSER); + } +} diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerStateTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerStateTest.java new file mode 100644 index 000000000..736f9be0e --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerStateTest.java @@ -0,0 +1,79 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.tools.computeruse; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ComputerState}. */ +@RunWith(JUnit4.class) +public final class ComputerStateTest { + + @Test + public void testBuilder() { + byte[] screenshot = new byte[] {1, 2, 3}; + String url = "https://google.com"; + ComputerState state = ComputerState.builder().screenshot(screenshot).url(url).build(); + + assertThat(state.screenshot()).isEqualTo(screenshot); + assertThat(state.url()).hasValue(url); + } + + @Test + public void testBuilder_noUrl() { + byte[] screenshot = new byte[] {1, 2, 3}; + ComputerState state = ComputerState.builder().screenshot(screenshot).build(); + + assertThat(state.screenshot()).isEqualTo(screenshot); + assertThat(state.url()).isEmpty(); + } + + @Test + public void testEqualsAndHashCode() { + byte[] screenshot1 = new byte[] {1, 2, 3}; + byte[] screenshot2 = new byte[] {1, 2, 3}; + byte[] screenshot3 = new byte[] {4, 5, 6}; + + ComputerState state1 = ComputerState.builder().screenshot(screenshot1).url("url1").build(); + ComputerState state2 = ComputerState.builder().screenshot(screenshot2).url("url1").build(); + ComputerState state3 = ComputerState.builder().screenshot(screenshot3).url("url1").build(); + ComputerState state4 = ComputerState.builder().screenshot(screenshot1).url("url2").build(); + + assertThat(state1).isEqualTo(state2); + assertThat(state1.hashCode()).isEqualTo(state2.hashCode()); + + assertThat(state1).isNotEqualTo(state3); + assertThat(state1).isNotEqualTo(state4); + } + + @Test + public void testScreenshotImmutability() { + byte[] screenshot = new byte[] {1, 2, 3}; + ComputerState state = ComputerState.builder().screenshot(screenshot).build(); + + // Modify original array + screenshot[0] = 9; + assertThat(state.screenshot()[0]).isEqualTo(1); + + // Modify returned array + state.screenshot()[0] = 9; + 
assertThat(state.screenshot()[0]).isEqualTo(1); + } +} diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java new file mode 100644 index 000000000..20fb146cf --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java @@ -0,0 +1,258 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; +import com.google.adk.tools.Annotations.Schema; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableMap; +import io.reactivex.rxjava3.core.Single; +import java.lang.reflect.Method; +import java.util.Base64; +import java.util.Map; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ComputerUseTool}. 
*/ +@RunWith(JUnit4.class) +public final class ComputerUseToolTest { + + private LlmAgent agent; + private InMemorySessionService sessionService; + private ToolContext toolContext; + private ComputerMock computerMock; + + @Before + public void setUp() { + agent = LlmAgent.builder().name("test-agent").build(); + sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .agent(agent) + .session(session) + .sessionService(sessionService) + .invocationId("invocation-id") + .build(); + toolContext = ToolContext.builder(invocationContext).functionCallId("functionCallId").build(); + computerMock = new ComputerMock(); + } + + @Test + public void testNormalizeX() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + assertThat(tool.runAsync(ImmutableMap.of("x", 0, "y", 0), toolContext).blockingGet()) + .isNotNull(); + assertThat(computerMock.lastX).isEqualTo(0); + + assertThat(tool.runAsync(ImmutableMap.of("x", 500, "y", 300), toolContext).blockingGet()) + .isNotNull(); + assertThat(computerMock.lastX).isEqualTo(960); // 500/1000 * 1920 + + assertThat(tool.runAsync(ImmutableMap.of("x", 1000, "y", 300), toolContext).blockingGet()) + .isNotNull(); + assertThat(computerMock.lastX).isEqualTo(1919); // Clamped + } + + @Test + public void testNormalizeY() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + assertThat(tool.runAsync(ImmutableMap.of("x", 0, "y", 500), toolContext).blockingGet()) + .isNotNull(); + 
assertThat(computerMock.lastY).isEqualTo(540); // 500/1000 * 1080 + } + + @Test + public void testNormalizeWithCustomVirtualScreenSize() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {2000, 2000}); + + assertThat(tool.runAsync(ImmutableMap.of("x", 1000, "y", 1000), toolContext).blockingGet()) + .isNotNull(); + assertThat(computerMock.lastX).isEqualTo(960); // 1000/2000 * 1920 + assertThat(computerMock.lastY).isEqualTo(540); // 1000/2000 * 1080 + } + + @Test + public void testNormalizeDragAndDrop() throws NoSuchMethodException { + Method method = + ComputerMock.class.getMethod("dragAndDrop", int.class, int.class, int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + Map result = + tool.runAsync( + ImmutableMap.of("x", 100, "y", 200, "destination_x", 800, "destination_y", 600), + toolContext) + .blockingGet(); + assertThat(result).isNotNull(); + + assertThat(computerMock.lastX).isEqualTo(192); + assertThat(computerMock.lastY).isEqualTo(216); + assertThat(computerMock.lastDestX).isEqualTo(1536); + assertThat(computerMock.lastDestY).isEqualTo(648); + } + + @Test + public void testResultFormatting() throws NoSuchMethodException { + byte[] screenshot = new byte[] {1, 2, 3}; + computerMock.nextState = + ComputerState.builder() + .screenshot(screenshot) + .url(Optional.of("https://example.com")) + .build(); + + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + Map result = + tool.runAsync(ImmutableMap.of("x", 500, "y", 500), toolContext).blockingGet(); + assertThat(result).containsKey("image"); + Object imageData = result.get("image"); + 
assertThat(imageData).isInstanceOf(Map.class); + ((Map) imageData) + .forEach( + (key, value) -> { + assertThat(key).isInstanceOf(String.class); + assertThat(value).isInstanceOf(String.class); + }); + @SuppressWarnings("unchecked") // The types of the key and value are checked above. + Map imageMap = (Map) imageData; + assertThat(imageMap.get("mimetype")).isEqualTo("image/png"); + assertThat(imageMap.get("data")).isEqualTo(Base64.getEncoder().encodeToString(screenshot)); + assertThat(result.get("url")).isEqualTo("https://example.com"); + assertThat(result).containsKey("image"); + assertThat(result).doesNotContainKey("screenshot"); + } + + @Test + public void testResultFormatting_noScreenshot() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("noScreenshot"); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); + assertThat(result).doesNotContainKey("image"); + assertThat(result.get("url")).isEqualTo("https://example.com"); + } + + @Test + public void testResultFormatting_nonByteArrayScreenshot() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("nonByteArrayScreenshot"); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); + assertThat(result).doesNotContainKey("image"); + assertThat(result.get("screenshot")).isEqualTo("not-a-byte-array"); + } + + @Test + public void testNormalizeWithInvalidInputs() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + assertThrows( + IllegalArgumentException.class, + () -> tool.runAsync(ImmutableMap.of("x", "invalid", 
"y", 500), toolContext).blockingGet()); + } + + @Test + public void testRunAsyncWithNoCoordinates() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + // Arguments without x, y, etc. should be passed as is. + ImmutableMap args = ImmutableMap.of("other", "value"); + var unused = tool.runAsync(args, toolContext).blockingGet(); + assertThat(computerMock.lastX).isEqualTo(0); + assertThat(computerMock.lastY).isEqualTo(0); + } + + @Test + public void testCoordinateClamping() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + // Test clamping to 0 + var unused1 = tool.runAsync(ImmutableMap.of("x", -100, "y", -50), toolContext).blockingGet(); + assertThat(computerMock.lastX).isEqualTo(0); + assertThat(computerMock.lastY).isEqualTo(0); + + // Test clamping to max + var unused2 = tool.runAsync(ImmutableMap.of("x", 2000, "y", 1500), toolContext).blockingGet(); + assertThat(computerMock.lastX).isEqualTo(1919); + assertThat(computerMock.lastY).isEqualTo(1079); + } + + /** A mock class for Computer actions. 
*/ + public static class ComputerMock { + public int lastX; + public int lastY; + public int lastDestX; + public int lastDestY; + public ComputerState nextState = + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build(); + + public Single clickAt(@Schema(name = "x") int x, @Schema(name = "y") int y) { + this.lastX = x; + this.lastY = y; + return Single.just(nextState); + } + + public Single dragAndDrop( + @Schema(name = "x") int x, + @Schema(name = "y") int y, + @Schema(name = "destination_x") int destinationX, + @Schema(name = "destination_y") int destinationY) { + this.lastX = x; + this.lastY = y; + this.lastDestX = destinationX; + this.lastDestY = destinationY; + return Single.just(nextState); + } + + public Single> noScreenshot() { + return Single.just(ImmutableMap.of("url", "https://example.com")); + } + + public Single> nonByteArrayScreenshot() { + return Single.just(ImmutableMap.of("screenshot", "not-a-byte-array")); + } + } +} diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java new file mode 100644 index 000000000..1ed49419e --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java @@ -0,0 +1,264 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.tools.computeruse; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.genai.types.Environment; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Tool; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Single; +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ComputerUseToolset}. */ +@RunWith(JUnit4.class) +public final class ComputerUseToolsetTest { + + private LlmAgent agent; + private InMemorySessionService sessionService; + private ToolContext toolContext; + private MockComputer mockComputer; + private ComputerUseToolset toolset; + + @Before + public void setUp() { + agent = LlmAgent.builder().name("test-agent").build(); + sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .agent(agent) + .session(session) + .sessionService(sessionService) + .invocationId("invocation-id") + .build(); + toolContext = ToolContext.builder(invocationContext).functionCallId("functionCallId").build(); + + mockComputer = new MockComputer(); + toolset = new ComputerUseToolset(mockComputer); + } + + @Test + public void testGetTools() { + List tools = toolset.getTools(null).toList().blockingGet(); + + assertThat(mockComputer.initializeCallCount).isEqualTo(1); + 
assertThat(tools).isNotEmpty(); + + // Verify method filtering + assertThat(tools.stream().anyMatch(t -> t.name().equals("clickAt"))).isTrue(); + assertThat(tools.stream().noneMatch(t -> t.name().equals("screenSize"))).isTrue(); + assertThat(tools.stream().noneMatch(t -> t.name().equals("environment"))).isTrue(); + } + + @Test + public void testEnsureInitializedOnlyCalledOnce() { + var unused1 = toolset.getTools(null).toList().blockingGet(); + var unused2 = toolset.getTools(null).toList().blockingGet(); + + assertThat(mockComputer.initializeCallCount).isEqualTo(1); + } + + @Test + public void testGetTools_cachesTools() { + List tools1 = toolset.getTools(null).toList().blockingGet(); + List tools2 = toolset.getTools(null).toList().blockingGet(); + + assertThat(tools1).hasSize(tools2.size()); + for (int i = 0; i < tools1.size(); i++) { + assertThat(tools1.get(i)).isSameInstanceAs(tools2.get(i)); + } + } + + @Test + public void testProcessLlmRequest() { + LlmRequest.Builder builder = + LlmRequest.builder().model("test-model").config(GenerateContentConfig.builder().build()); + + toolset.processLlmRequest(builder, toolContext).blockingAwait(); + + LlmRequest request = builder.build(); + assertThat(request.config()).isPresent(); + GenerateContentConfig config = request.config().get(); + + assertThat(config.tools()).isPresent(); + List tools = config.tools().get(); + + // Find the computer use tool + Optional computerUseTool = + tools.stream().filter(t -> t.computerUse().isPresent()).findFirst(); + assertThat(computerUseTool).isPresent(); + assertThat(computerUseTool.get().computerUse().get().environment().get().knownEnum()) + .isEqualTo(Environment.Known.ENVIRONMENT_BROWSER); + + // Verify computer actions were added as function declarations + Optional functionTool = + tools.stream().filter(t -> t.functionDeclarations().isPresent()).findFirst(); + assertThat(functionTool).isPresent(); + assertThat( + functionTool.get().functionDeclarations().get().stream() + .anyMatch(fd 
-> fd.name().orElse("").equals("clickAt"))) + .isTrue(); + } + + @Test + public void testProcessLlmRequest_withComputerError() { + mockComputer.nextError = new RuntimeException("Computer failure"); + LlmRequest.Builder builder = + LlmRequest.builder().model("test-model").config(GenerateContentConfig.builder().build()); + + assertThrows( + RuntimeException.class, + () -> toolset.processLlmRequest(builder, toolContext).blockingAwait()); + } + + private static class MockComputer implements BaseComputer { + int initializeCallCount = 0; + Throwable nextError = null; + + @Override + public Completable initialize() { + if (nextError != null) { + return Completable.error(nextError); + } + this.initializeCallCount++; + return Completable.complete(); + } + + @Override + public Single screenSize() { + if (nextError != null) { + return Single.error(nextError); + } + return Single.just(new int[] {1920, 1080}); + } + + @Override + public Single environment() { + if (nextError != null) { + return Single.error(nextError); + } + return Single.just(ComputerEnvironment.ENVIRONMENT_BROWSER); + } + + @Override + public Single openWebBrowser() { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single clickAt(int x, int y) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single hoverAt(int x, int y) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single typeTextAt( + int x, int y, String text, Boolean pressEnter, Boolean clearBeforeTyping) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single scrollDocument(String direction) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public 
Single scrollAt(int x, int y, String direction, int magnitude) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single wait(Duration duration) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single goBack() { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single goForward() { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single search() { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single navigate(String url) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.of(url)).build()); + } + + @Override + public Single keyCombination(List keys) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single dragAndDrop(int x, int y, int destinationX, int destinationY) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single currentState() { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Completable close() { + return Completable.complete(); + } + } +}