Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions tool/mcptoolset/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,39 @@ func (s *set) Tools(ctx agent.ReadonlyContext) ([]tool.Tool, error) {
}

var adkTools []tool.Tool
reconnected := false

cursor := ""
for {
resp, err := session.ListTools(ctx, &mcp.ListToolsParams{
Cursor: cursor,
})
if err != nil {
return nil, fmt.Errorf("failed to list MCP tools: %w", err)
if reconnected {
return nil, fmt.Errorf("failed to list MCP tools after reconnection: %w", err)
}
// On any error, attempt to refresh the connection.
// refreshConnection uses ping to verify if reconnection is actually needed.
var refreshErr error
session, refreshErr = s.refreshConnection(ctx)
if refreshErr != nil {
return nil, fmt.Errorf("failed to list MCP tools: %w", err)
}
reconnected = true
// Per MCP spec, cursors should not persist across sessions.
// Start listing from scratch after reconnection.
cursor = ""
adkTools = nil
resp, err = session.ListTools(ctx, &mcp.ListToolsParams{
Cursor: cursor,
})
if err != nil {
return nil, fmt.Errorf("failed to list MCP tools: %w", err)
}
}

for _, mcpTool := range resp.Tools {
t, err := convertTool(mcpTool, s.getSession)
t, err := convertTool(mcpTool, s)
if err != nil {
return nil, fmt.Errorf("failed to convert MCP tool %q to adk tool: %w", mcpTool.Name, err)
}
Expand Down Expand Up @@ -150,3 +171,26 @@ func (s *set) getSession(ctx context.Context) (*mcp.ClientSession, error) {
s.session = session
return s.session, nil
}

func (s *set) refreshConnection(ctx context.Context) (*mcp.ClientSession, error) {
s.mu.Lock()
defer s.mu.Unlock()

// First, try ping to confirm connection is dead
if s.session != nil {
if err := s.session.Ping(ctx, &mcp.PingParams{}); err == nil {
// Connection is actually alive, don't refresh
return s.session, nil
}
s.session.Close()
s.session = nil
}

session, err := s.client.Connect(ctx, s.transport, nil)
if err != nil {
return nil, fmt.Errorf("failed to refresh MCP session: %w", err)
}

s.session = session
return s.session, nil
}
107 changes: 107 additions & 0 deletions tool/mcptoolset/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
icontext "google.golang.org/adk/internal/context"
"google.golang.org/adk/internal/httprr"
"google.golang.org/adk/internal/testutil"
"google.golang.org/adk/internal/toolinternal"
"google.golang.org/adk/model"
"google.golang.org/adk/model/gemini"
"google.golang.org/adk/runner"
Expand Down Expand Up @@ -307,3 +308,109 @@ func TestToolFilter(t *testing.T) {
t.Errorf("tools mismatch (-want +got):\n%s", diff)
}
}

func TestListToolsReconnection(t *testing.T) {
server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil)
mcp.AddTool(server, &mcp.Tool{Name: "get_weather", Description: "returns weather in the given city"}, weatherFunc)

rt := &reconnectableTransport{server: server}
spyTransport := &spyTransport{Transport: rt}

ts, err := mcptoolset.New(mcptoolset.Config{
Transport: spyTransport,
})
if err != nil {
t.Fatalf("Failed to create MCP tool set: %v", err)
}

ctx := icontext.NewReadonlyContext(icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{}))

// First call to Tools should create a session.
_, err = ts.Tools(ctx)
if err != nil {
t.Fatalf("First Tools call failed: %v", err)
}

// Kill the transport by closing the connection.
spyTransport.lastConn.Close()

// Second call should detect the closed connection and reconnect.
_, err = ts.Tools(ctx)
if err != nil {
t.Fatalf("Second Tools call failed: %v", err)
}

// Verify that we reconnected (should have 2 connections).
if spyTransport.connectCount != 2 {
t.Errorf("Expected 2 Connect calls (reconnect after close), got %d", spyTransport.connectCount)
}
}

func TestCallToolReconnection(t *testing.T) {
server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil)
mcp.AddTool(server, &mcp.Tool{Name: "get_weather", Description: "returns weather in the given city"}, weatherFunc)

rt := &reconnectableTransport{server: server}
spyTransport := &spyTransport{Transport: rt}

ts, err := mcptoolset.New(mcptoolset.Config{
Transport: spyTransport,
})
if err != nil {
t.Fatalf("Failed to create MCP tool set: %v", err)
}

invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{})
ctx := icontext.NewReadonlyContext(invCtx)
toolCtx := toolinternal.NewToolContext(invCtx, "", nil)

// Get tools first to establish a session.
tools, err := ts.Tools(ctx)
if err != nil {
t.Fatalf("Tools call failed: %v", err)
}

// Kill the transport by closing the connection.
spyTransport.lastConn.Close()

// Call the tool - should reconnect and succeed.
fnTool := tools[0].(toolinternal.FunctionTool)
result, err := fnTool.Run(toolCtx, map[string]any{"city": "Paris"})
if err != nil {
t.Fatalf("Tool call after reconnect failed: %v", err)
}
if result == nil {
t.Fatal("Expected non-nil result after reconnect")
}

// Verify that we reconnected (should have 2 connections).
if spyTransport.connectCount != 2 {
t.Errorf("Expected 2 Connect calls (reconnect after close), got %d", spyTransport.connectCount)
}
}

type spyTransport struct {
mcp.Transport
connectCount int
lastConn mcp.Connection
}

func (t *spyTransport) Connect(ctx context.Context) (mcp.Connection, error) {
t.connectCount++
conn, err := t.Transport.Connect(ctx)
t.lastConn = conn
return conn, err
}

type reconnectableTransport struct {
server *mcp.Server
}

func (rt *reconnectableTransport) Connect(ctx context.Context) (mcp.Connection, error) {
ct, st := mcp.NewInMemoryTransports()
_, err := rt.server.Connect(context.Background(), st, nil)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's better to use the ctx from the function arguments here instead of context.Background(). This ensures that if the client-side connection attempt is cancelled, the server-side connection setup is also cancelled. This prevents potential resource leaks in the test and is a better practice for context propagation.

Suggested change
_, err := rt.server.Connect(context.Background(), st, nil)
_, err := rt.server.Connect(ctx, st, nil)

if err != nil {
return nil, err
}
return ct.Connect(ctx)
}
23 changes: 16 additions & 7 deletions tool/mcptoolset/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package mcptoolset

import (
"context"
"errors"
"fmt"
"strings"
Expand All @@ -29,17 +28,15 @@ import (
"google.golang.org/adk/tool"
)

type getSessionFunc func(ctx context.Context) (*mcp.ClientSession, error)

func convertTool(t *mcp.Tool, getSessionFunc getSessionFunc) (tool.Tool, error) {
func convertTool(t *mcp.Tool, s *set) (tool.Tool, error) {
mcp := &mcpTool{
name: t.Name,
description: t.Description,
funcDeclaration: &genai.FunctionDeclaration{
Name: t.Name,
Description: t.Description,
},
getSessionFunc: getSessionFunc,
set: s,
}

// Since t.InputSchema and t.OutputSchema are pointers (*jsonschema.Schema) and the destination ResponseJsonSchema
Expand All @@ -61,7 +58,7 @@ type mcpTool struct {
description string
funcDeclaration *genai.FunctionDeclaration

getSessionFunc getSessionFunc
set *set
}

// Name implements the tool.Tool.
Expand All @@ -88,7 +85,7 @@ func (t *mcpTool) Declaration() *genai.FunctionDeclaration {
}

func (t *mcpTool) Run(ctx tool.Context, args any) (map[string]any, error) {
session, err := t.getSessionFunc(ctx)
session, err := t.set.getSession(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get session: %w", err)
}
Expand All @@ -98,6 +95,18 @@ func (t *mcpTool) Run(ctx tool.Context, args any) (map[string]any, error) {
Name: t.name,
Arguments: args,
})

if err != nil {
// On any error, attempt to refresh the connection.
// refreshConnection uses ping to verify if reconnection is actually needed.
if session, refreshErr := t.set.refreshConnection(ctx); refreshErr == nil {
res, err = session.CallTool(ctx, &mcp.CallToolParams{
Name: t.name,
Arguments: args,
})
}
}

if err != nil {
return nil, fmt.Errorf("failed to call MCP tool %q with err: %w", t.name, err)
}
Expand Down