From 89ae4206f174cf4fe842e66fbbe8dee5113e623f Mon Sep 17 00:00:00 2001 From: westerberg Date: Wed, 3 Dec 2025 13:49:01 +0000 Subject: [PATCH 1/2] Modify handleFunctionCalls to execute beforeToolCallbacks despite tool existing. Add retry callback example. --- agent/llmagent/llmagent.go | 7 +- examples/tools/retrytool/main.go | 114 +++++++++++++++++ internal/llminternal/base_flow.go | 43 +++---- internal/llminternal/base_flow_test.go | 170 ++++++++++++++++++++++++- 4 files changed, 306 insertions(+), 28 deletions(-) create mode 100644 examples/tools/retrytool/main.go diff --git a/agent/llmagent/llmagent.go b/agent/llmagent/llmagent.go index aee848ba..c15cdd47 100644 --- a/agent/llmagent/llmagent.go +++ b/agent/llmagent/llmagent.go @@ -262,6 +262,7 @@ type AfterModelCallback func(ctx agent.CallbackContext, llmResponse *model.LLMRe // Parameters: // - ctx: The tool.Context for the current tool execution. // - tool: The tool.Tool instance that is about to be executed. +// This value can be nil if the tool name invoked by the LLM function call does not correspond to an existing tool. // - args: The original arguments provided to the tool. type BeforeToolCallback func(ctx tool.Context, tool tool.Tool, args map[string]any) (map[string]any, error) @@ -271,9 +272,11 @@ type BeforeToolCallback func(ctx tool.Context, tool tool.Tool, args map[string]a // Parameters: // - ctx: The tool.Context for the tool execution. // - tool: The tool.Tool instance that was executed. +// This value can be nil if the tool name invoked by the LLM function call does not correspond to an existing tool, +// and a BeforeToolCallback generated result or an error // - args: The arguments originally passed to the tool. -// - result: The result returned by the tool's Run method. -// - err: The error returned by the tool's Run method. +// - result: The result returned by the tool's Run method or by a BeforeToolCallback. +// - err: The error returned by the tool's Run method or by a BeforeToolCallback. type AfterToolCallback func(ctx tool.Context, tool tool.Tool, args, result map[string]any, err error) (map[string]any, error) // IncludeContents controls what parts of prior conversation history is received by llmagent. diff --git a/examples/tools/retrytool/main.go b/examples/tools/retrytool/main.go new file mode 100644 index 00000000..24b7e605 --- /dev/null +++ b/examples/tools/retrytool/main.go @@ -0,0 +1,114 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package demonstrates a workaround for using Google Search tool with other tools. +package main + +import ( + "context" + "fmt" + "log" + "os" + + "google.golang.org/genai" + + "google.golang.org/adk/agent" + "google.golang.org/adk/agent/llmagent" + "google.golang.org/adk/cmd/launcher" + "google.golang.org/adk/cmd/launcher/full" + "google.golang.org/adk/model/gemini" + "google.golang.org/adk/tool" + "google.golang.org/adk/tool/functiontool" +) + +// Package main demonstrates a workaround for using multiple tool types (e.g., +// Google Search and custom functions) in a single agent. This is necessary +// due to limitations in the genai API. The approach is to wrap agents with +// different tool types into sub-agents, which are then managed by a root agent. +func main() { + ctx := context.Background() + + gemminiModel, err := gemini.NewModel(ctx, "gemini-2.5-flash", &genai.ClientConfig{ + APIKey: os.Getenv("GOOGLE_API_KEY"), + }) + if err != nil { + log.Fatalf("Failed to create model: %v", err) + } + + type Input struct { + Id int `json:"id"` + } + type Output struct { + Text string `json:"text"` + } + handler := func(ctx tool.Context, input Input) (Output, error) { + var text string + switch input.Id { + case 1: + text = "path/to/build/1\n" + case 2: + text = "different/path/to/build/2\n" + default: + text = fmt.Sprintf("path not found for build %d\n", input.Id) + } + return Output{ + Text: text, + }, nil + } + getCloudBuild, err := functiontool.New(functiontool.Config{ + Name: "get_cloud_build", + Description: "Returns the cloud build", + }, handler) + if err != nil { + log.Fatalf("Failed to create tool: %v", err) + } + + a, err := llmagent.New(llmagent.Config{ + Name: "root_agent", + Model: gemminiModel, + Description: "You can get a cloud build.", + Instruction: `You can get information about a cloud build. + Use the available tools for that but instead of 'get_cloud_build' call 'get_cloud-_build', it's ok if it does not exist the client will know what to do with it. + Never confirm with the user in case of a retry.`, + Tools: []tool.Tool{ + getCloudBuild, + }, + BeforeToolCallbacks: []llmagent.BeforeToolCallback{ + func(ctx tool.Context, tool tool.Tool, args map[string]any) (map[string]any, error) { + toolNames := []string{getCloudBuild.Name()} + if tool == nil { + fmt.Print("beforeToolCallbacks called with empty tool\n") + return map[string]any{ + "error": fmt.Sprintf("Invalid function name. Correct names are :%v. Please retry with a new function call without asking the user.", toolNames), + }, nil + } else { + fmt.Printf("beforeToolCallbacks called with %s tool\n", tool.Name()) + } + return nil, nil + }, + }, + }) + if err != nil { + log.Fatalf("Failed to create agent: %v", err) + } + + config := &launcher.Config{ + AgentLoader: agent.NewSingleLoader(a), + } + + l := full.NewLauncher() + if err = l.Execute(ctx, config, os.Args[1:]); err != nil { + log.Fatalf("Run failed: %v\n\n%s", err, l.CommandLineSyntax()) + } +} diff --git a/internal/llminternal/base_flow.go b/internal/llminternal/base_flow.go index 15937e23..8b90406a 100644 --- a/internal/llminternal/base_flow.go +++ b/internal/llminternal/base_flow.go @@ -372,19 +372,25 @@ func (f *Flow) handleFunctionCalls(ctx agent.InvocationContext, toolsDict map[st fnCalls := utils.FunctionCalls(resp.Content) for _, fnCall := range fnCalls { - curTool, ok := toolsDict[fnCall.Name] - if !ok { - return nil, fmt.Errorf("unknown tool: %q", fnCall.Name) - } - funcTool, ok := curTool.(toolinternal.FunctionTool) - if !ok { - return nil, fmt.Errorf("tool %q is not a function tool", curTool.Name()) - } + curTool, toolExists := toolsDict[fnCall.Name] + funcTool, toolIsFunction := curTool.(toolinternal.FunctionTool) toolCtx := toolinternal.NewToolContext(ctx, fnCall.ID, &session.EventActions{StateDelta: make(map[string]any)}) - // toolCtx := tool. spans := telemetry.StartTrace(ctx, "execute_tool "+fnCall.Name) - result := f.callTool(funcTool, fnCall.Args, toolCtx) + result, err := f.invokeBeforeToolCallbacks(funcTool, fnCall.Args, toolCtx) + if result == nil && err == nil { + if !toolExists { + return nil, fmt.Errorf("unknown tool: %q", fnCall.Name) + } + if !toolIsFunction { + return nil, fmt.Errorf("tool %q is not a function tool", curTool.Name()) + } + result, err = funcTool.Run(toolCtx, fnCall.Args) + } + result, err = f.invokeAfterToolCallbacks(funcTool, fnCall.Args, toolCtx, result, err) + if err != nil { + result = map[string]any{"error": err.Error()} + } // TODO: agent.canonical_after_tool_callbacks // TODO: handle long-running tool. @@ -406,7 +412,10 @@ func (f *Flow) handleFunctionCalls(ctx agent.InvocationContext, toolsDict map[st ev.Author = ctx.Agent().Name() ev.Branch = ctx.Branch() ev.Actions = *toolCtx.Actions() - telemetry.TraceToolCall(spans, curTool, fnCall.Args, ev) + // TODO trace else for events created by callbacks + if curTool != nil { + telemetry.TraceToolCall(spans, curTool, fnCall.Args, ev) + } fnResponseEvents = append(fnResponseEvents, ev) } mergedEvent, err := mergeParallelFunctionResponseEvents(fnResponseEvents) @@ -419,18 +428,6 @@ func (f *Flow) handleFunctionCalls(ctx agent.InvocationContext, toolsDict map[st return mergedEvent, nil } -func (f *Flow) callTool(tool toolinternal.FunctionTool, fArgs map[string]any, toolCtx tool.Context) map[string]any { - result, err := f.invokeBeforeToolCallbacks(tool, fArgs, toolCtx) - if result == nil && err == nil { - result, err = tool.Run(toolCtx, fArgs) - } - result, err = f.invokeAfterToolCallbacks(tool, fArgs, toolCtx, result, err) - if err != nil { - return map[string]any{"error": err.Error()} - } - return result -} - func (f *Flow) invokeBeforeToolCallbacks(tool toolinternal.FunctionTool, fArgs map[string]any, toolCtx tool.Context) (map[string]any, error) { for _, callback := range f.BeforeToolCallbacks { result, err := callback(toolCtx, tool, fArgs) diff --git a/internal/llminternal/base_flow_test.go b/internal/llminternal/base_flow_test.go index 65651560..dab5a21d 100644 --- a/internal/llminternal/base_flow_test.go +++ b/internal/llminternal/base_flow_test.go @@ -21,6 +21,8 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/genai" + "google.golang.org/adk/agent" + icontext "google.golang.org/adk/internal/context" "google.golang.org/adk/internal/toolinternal" "google.golang.org/adk/model" "google.golang.org/adk/tool" @@ -66,7 +68,7 @@ func (m *mockFunctionTool) Declaration() *genai.FunctionDeclaration { return nil } -func TestCallTool(t *testing.T) { +func TestHandleFunctionCalls(t *testing.T) { tests := []struct { name string tool toolinternal.FunctionTool @@ -259,6 +261,8 @@ func TestCallTool(t *testing.T) { }, } + simpleAgent, _ := agent.New(agent.Config{Name: "agentName"}) + ctx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{Agent: simpleAgent}) for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { f := &Flow{ @@ -266,8 +270,168 @@ func TestCallTool(t *testing.T) { AfterToolCallbacks: tc.afterToolCallbacks, } - got := f.callTool(tc.tool, tc.args, nil) - if diff := cmp.Diff(tc.want, got); diff != "" { + toolsDict := map[string]tool.Tool{tc.tool.Name(): tc.tool} + resp := &model.LLMResponse{ + Content: genai.NewContentFromFunctionCall(tc.tool.Name(), tc.args, genai.RoleModel), + } + got, err := f.handleFunctionCalls(ctx, toolsDict, resp) + if err != nil { + t.Fatalf("encountered unnexpected error: %s", err) + } + if got.Content == nil || len(got.Content.Parts) == 0 || got.Content.Parts[0].FunctionResponse == nil { + t.Errorf("invalid function call result") + return + } + gotFunctionResponse := got.Content.Parts[0].FunctionResponse.Response + if diff := cmp.Diff(tc.want, gotFunctionResponse); diff != "" { + t.Errorf("callTool() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestHandleWrongNameFunctionCalls(t *testing.T) { + tests := []struct { + name string + toolName string + tool toolinternal.FunctionTool + args map[string]any + beforeToolCallbacks []BeforeToolCallback + afterToolCallbacks []AfterToolCallback + want map[string]any + wantErr bool + wantErrMessage string + }{ + { + name: "tool run fails with wrong name", + toolName: "wrongTool", + tool: &mockFunctionTool{ + name: "testTool", + runFunc: func(ctx tool.Context, args map[string]any) (map[string]any, error) { + return map[string]any{"result": "success"}, nil + }, + }, + args: map[string]any{"key": "value"}, + wantErr: true, + wantErrMessage: `unknown tool: "wrongTool"`, + }, + { + name: "tool fails with wrong name and beforeToolCallback without modify", + toolName: "wrongTool", + tool: &mockFunctionTool{ + name: "testTool", + runFunc: func(ctx tool.Context, args map[string]any) (map[string]any, error) { + return map[string]any{"result": "success"}, nil + }, + }, + args: map[string]any{"key": "value"}, + beforeToolCallbacks: []BeforeToolCallback{ + func(ctx tool.Context, tool tool.Tool, args map[string]any) (map[string]any, error) { + return nil, nil + }, + }, + wantErr: true, + wantErrMessage: `unknown tool: "wrongTool"`, + }, + { + name: "tool doesn't fails with wrong name and beforeToolCallback", + toolName: "wrongTool", + tool: &mockFunctionTool{ + name: "testTool", + runFunc: func(ctx tool.Context, args map[string]any) (map[string]any, error) { + return map[string]any{"result": "success"}, nil + }, + }, + args: map[string]any{"key": "value"}, + beforeToolCallbacks: []BeforeToolCallback{ + func(ctx tool.Context, tool tool.Tool, args map[string]any) (map[string]any, error) { + if tool == nil { + return map[string]any{"result": "from_before"}, nil + } + return nil, nil + }, + }, + want: map[string]any{"result": "from_before"}, + wantErr: false, + }, + { + name: "tool doesn't fails with wrong name and beforeToolCallback returning error", + toolName: "wrongTool", + tool: &mockFunctionTool{ + name: "testTool", + runFunc: func(ctx tool.Context, args map[string]any) (map[string]any, error) { + return map[string]any{"result": "success"}, nil + }, + }, + args: map[string]any{"key": "value"}, + beforeToolCallbacks: []BeforeToolCallback{ + func(ctx tool.Context, tool tool.Tool, args map[string]any) (map[string]any, error) { + if tool == nil { + return nil, errors.New("tool is nil") + } + return nil, nil + }, + }, + want: map[string]any{"error": "tool is nil"}, + wantErr: false, + }, + { + name: "tool doesn't fails with wrong name and beforeToolCallback, with afterToolCallback also being called", + toolName: "wrongTool", + tool: &mockFunctionTool{ + name: "testTool", + runFunc: func(ctx tool.Context, args map[string]any) (map[string]any, error) { + return map[string]any{"result": "success"}, nil + }, + }, + args: map[string]any{"key": "value"}, + beforeToolCallbacks: []BeforeToolCallback{ + func(ctx tool.Context, tool tool.Tool, args map[string]any) (map[string]any, error) { + if tool == nil { + return map[string]any{"result": "from_before"}, nil + } + return nil, nil + }, + }, + afterToolCallbacks: []AfterToolCallback{ + func(ctx tool.Context, tool tool.Tool, args, result map[string]any, err error) (map[string]any, error) { + return map[string]any{"result": "modified"}, nil + }, + }, + want: map[string]any{"result": "modified"}, + wantErr: false, + }, + } + + simpleAgent, _ := agent.New(agent.Config{Name: "agentName"}) + ctx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{Agent: simpleAgent}) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + f := &Flow{ + BeforeToolCallbacks: tc.beforeToolCallbacks, + AfterToolCallbacks: tc.afterToolCallbacks, + } + + toolsDict := map[string]tool.Tool{tc.tool.Name(): tc.tool} + resp := &model.LLMResponse{ + Content: genai.NewContentFromFunctionCall(tc.toolName, tc.args, genai.RoleModel), + } + got, err := f.handleFunctionCalls(ctx, toolsDict, resp) + if err != nil && !tc.wantErr { + t.Fatalf("encountered unnexpected error: %s", err) + } + if err != nil && tc.wantErr { + if diff := cmp.Diff(tc.wantErrMessage, err.Error()); diff != "" { + t.Errorf("callTool() mismatch error message (-want +got):\n%s", diff) + } + return + } + if got.Content == nil || len(got.Content.Parts) == 0 || got.Content.Parts[0].FunctionResponse == nil { + t.Errorf("invalid function call result") + return + } + gotFunctionResponse := got.Content.Parts[0].FunctionResponse.Response + if diff := cmp.Diff(tc.want, gotFunctionResponse); diff != "" { t.Errorf("callTool() mismatch (-want +got):\n%s", diff) } }) From 4b7642b8662c367f7d3215133f857483929f385b Mon Sep 17 00:00:00 2001 From: westerberg Date: Wed, 3 Dec 2025 14:25:00 +0000 Subject: [PATCH 2/2] Fix typos --- examples/tools/retrytool/main.go | 4 ++-- internal/llminternal/base_flow_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/tools/retrytool/main.go b/examples/tools/retrytool/main.go index 24b7e605..577d6925 100644 --- a/examples/tools/retrytool/main.go +++ b/examples/tools/retrytool/main.go @@ -39,7 +39,7 @@ import ( func main() { ctx := context.Background() - gemminiModel, err := gemini.NewModel(ctx, "gemini-2.5-flash", &genai.ClientConfig{ + geminiModel, err := gemini.NewModel(ctx, "gemini-2.5-flash", &genai.ClientConfig{ APIKey: os.Getenv("GOOGLE_API_KEY"), }) if err != nil { @@ -76,7 +76,7 @@ func main() { a, err := llmagent.New(llmagent.Config{ Name: "root_agent", - Model: gemminiModel, + Model: geminiModel, Description: "You can get a cloud build.", Instruction: `You can get information about a cloud build. Use the available tools for that but instead of 'get_cloud_build' call 'get_cloud-_build', it's ok if it does not exist the client will know what to do with it. diff --git a/internal/llminternal/base_flow_test.go b/internal/llminternal/base_flow_test.go index dab5a21d..667a0a49 100644 --- a/internal/llminternal/base_flow_test.go +++ b/internal/llminternal/base_flow_test.go @@ -276,7 +276,7 @@ func TestHandleFunctionCalls(t *testing.T) { } got, err := f.handleFunctionCalls(ctx, toolsDict, resp) if err != nil { - t.Fatalf("encountered unnexpected error: %s", err) + t.Fatalf("encountered unexpected error: %s", err) } if got.Content == nil || len(got.Content.Parts) == 0 || got.Content.Parts[0].FunctionResponse == nil { t.Errorf("invalid function call result") @@ -418,7 +418,7 @@ func TestHandleWrongNameFunctionCalls(t *testing.T) { } got, err := f.handleFunctionCalls(ctx, toolsDict, resp) if err != nil && !tc.wantErr { - t.Fatalf("encountered unnexpected error: %s", err) + t.Fatalf("encountered unexpected error: %s", err) } if err != nil && tc.wantErr { if diff := cmp.Diff(tc.wantErrMessage, err.Error()); diff != "" {