diff --git a/tool/mcptoolset/set.go b/tool/mcptoolset/set.go index e27ecdb6..d517e039 100644 --- a/tool/mcptoolset/set.go +++ b/tool/mcptoolset/set.go @@ -102,6 +102,7 @@ func (s *set) Tools(ctx agent.ReadonlyContext) ([]tool.Tool, error) { } var adkTools []tool.Tool + reconnected := false cursor := "" for { @@ -109,11 +110,31 @@ func (s *set) Tools(ctx agent.ReadonlyContext) ([]tool.Tool, error) { Cursor: cursor, }) if err != nil { - return nil, fmt.Errorf("failed to list MCP tools: %w", err) + if reconnected { + return nil, fmt.Errorf("failed to list MCP tools after reconnection: %w", err) + } + // On any error, attempt to refresh the connection. + // refreshConnection uses ping to verify if reconnection is actually needed. + var refreshErr error + session, refreshErr = s.refreshConnection(ctx) + if refreshErr != nil { + return nil, fmt.Errorf("failed to list MCP tools: %w", err) + } + reconnected = true + // Per MCP spec, cursors should not persist across sessions. + // Start listing from scratch after reconnection. + cursor = "" + adkTools = nil + resp, err = session.ListTools(ctx, &mcp.ListToolsParams{ + Cursor: cursor, + }) + if err != nil { + return nil, fmt.Errorf("failed to list MCP tools: %w", err) + } } for _, mcpTool := range resp.Tools { - t, err := convertTool(mcpTool, s.getSession) + t, err := convertTool(mcpTool, s) if err != nil { return nil, fmt.Errorf("failed to convert MCP tool %q to adk tool: %w", mcpTool.Name, err) } @@ -150,3 +171,26 @@ func (s *set) getSession(ctx context.Context) (*mcp.ClientSession, error) { s.session = session return s.session, nil } + +func (s *set) refreshConnection(ctx context.Context) (*mcp.ClientSession, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // First, try ping to confirm connection is dead + if s.session != nil { + if err := s.session.Ping(ctx, &mcp.PingParams{}); err == nil { + // Connection is actually alive, don't refresh + return s.session, nil + } + s.session.Close() + s.session = nil + } + + session, err := s.client.Connect(ctx, s.transport, nil) + if err != nil { + return nil, fmt.Errorf("failed to refresh MCP session: %w", err) + } + + s.session = session + return s.session, nil +} diff --git a/tool/mcptoolset/set_test.go b/tool/mcptoolset/set_test.go index 87f14b7c..f1cad1c1 100644 --- a/tool/mcptoolset/set_test.go +++ b/tool/mcptoolset/set_test.go @@ -34,6 +34,7 @@ import ( icontext "google.golang.org/adk/internal/context" "google.golang.org/adk/internal/httprr" "google.golang.org/adk/internal/testutil" + "google.golang.org/adk/internal/toolinternal" "google.golang.org/adk/model" "google.golang.org/adk/model/gemini" "google.golang.org/adk/runner" @@ -307,3 +308,109 @@ func TestToolFilter(t *testing.T) { t.Errorf("tools mismatch (-want +got):\n%s", diff) } } + +func TestListToolsReconnection(t *testing.T) { + server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "get_weather", Description: "returns weather in the given city"}, weatherFunc) + + rt := &reconnectableTransport{server: server} + spyTransport := &spyTransport{Transport: rt} + + ts, err := mcptoolset.New(mcptoolset.Config{ + Transport: spyTransport, + }) + if err != nil { + t.Fatalf("Failed to create MCP tool set: %v", err) + } + + ctx := icontext.NewReadonlyContext(icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{})) + + // First call to Tools should create a session. + _, err = ts.Tools(ctx) + if err != nil { + t.Fatalf("First Tools call failed: %v", err) + } + + // Kill the transport by closing the connection. + spyTransport.lastConn.Close() + + // Second call should detect the closed connection and reconnect. + _, err = ts.Tools(ctx) + if err != nil { + t.Fatalf("Second Tools call failed: %v", err) + } + + // Verify that we reconnected (should have 2 connections). + if spyTransport.connectCount != 2 { + t.Errorf("Expected 2 Connect calls (reconnect after close), got %d", spyTransport.connectCount) + } +} + +func TestCallToolReconnection(t *testing.T) { + server := mcp.NewServer(&mcp.Implementation{Name: "test_server", Version: "v1.0.0"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "get_weather", Description: "returns weather in the given city"}, weatherFunc) + + rt := &reconnectableTransport{server: server} + spyTransport := &spyTransport{Transport: rt} + + ts, err := mcptoolset.New(mcptoolset.Config{ + Transport: spyTransport, + }) + if err != nil { + t.Fatalf("Failed to create MCP tool set: %v", err) + } + + invCtx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{}) + ctx := icontext.NewReadonlyContext(invCtx) + toolCtx := toolinternal.NewToolContext(invCtx, "", nil) + + // Get tools first to establish a session. + tools, err := ts.Tools(ctx) + if err != nil { + t.Fatalf("Tools call failed: %v", err) + } + + // Kill the transport by closing the connection. + spyTransport.lastConn.Close() + + // Call the tool - should reconnect and succeed. + fnTool := tools[0].(toolinternal.FunctionTool) + result, err := fnTool.Run(toolCtx, map[string]any{"city": "Paris"}) + if err != nil { + t.Fatalf("Tool call after reconnect failed: %v", err) + } + if result == nil { + t.Fatal("Expected non-nil result after reconnect") + } + + // Verify that we reconnected (should have 2 connections). + if spyTransport.connectCount != 2 { + t.Errorf("Expected 2 Connect calls (reconnect after close), got %d", spyTransport.connectCount) + } +} + +type spyTransport struct { + mcp.Transport + connectCount int + lastConn mcp.Connection +} + +func (t *spyTransport) Connect(ctx context.Context) (mcp.Connection, error) { + t.connectCount++ + conn, err := t.Transport.Connect(ctx) + t.lastConn = conn + return conn, err +} + +type reconnectableTransport struct { + server *mcp.Server +} + +func (rt *reconnectableTransport) Connect(ctx context.Context) (mcp.Connection, error) { + ct, st := mcp.NewInMemoryTransports() + _, err := rt.server.Connect(context.Background(), st, nil) + if err != nil { + return nil, err + } + return ct.Connect(ctx) +} diff --git a/tool/mcptoolset/tool.go b/tool/mcptoolset/tool.go index ee2354f0..f8964452 100644 --- a/tool/mcptoolset/tool.go +++ b/tool/mcptoolset/tool.go @@ -15,7 +15,6 @@ package mcptoolset import ( - "context" "errors" "fmt" "strings" @@ -29,9 +28,7 @@ import ( "google.golang.org/adk/tool" ) -type getSessionFunc func(ctx context.Context) (*mcp.ClientSession, error) - -func convertTool(t *mcp.Tool, getSessionFunc getSessionFunc) (tool.Tool, error) { +func convertTool(t *mcp.Tool, s *set) (tool.Tool, error) { mcp := &mcpTool{ name: t.Name, description: t.Description, @@ -39,7 +36,7 @@ func convertTool(t *mcp.Tool, getSessionFunc getSessionFunc) (tool.Tool, error) Name: t.Name, Description: t.Description, }, - getSessionFunc: getSessionFunc, + set: s, } // Since t.InputSchema and t.OutputSchema are pointers (*jsonschema.Schema) and the destination ResponseJsonSchema @@ -61,7 +58,7 @@ type mcpTool struct { description string funcDeclaration *genai.FunctionDeclaration - getSessionFunc getSessionFunc + set *set } // Name implements the tool.Tool. @@ -88,7 +85,7 @@ func (t *mcpTool) Declaration() *genai.FunctionDeclaration { } func (t *mcpTool) Run(ctx tool.Context, args any) (map[string]any, error) { - session, err := t.getSessionFunc(ctx) + session, err := t.set.getSession(ctx) if err != nil { return nil, fmt.Errorf("failed to get session: %w", err) } @@ -98,6 +95,18 @@ func (t *mcpTool) Run(ctx tool.Context, args any) (map[string]any, error) { Name: t.name, Arguments: args, }) + + if err != nil { + // On any error, attempt to refresh the connection. + // refreshConnection uses ping to verify if reconnection is actually needed. + if session, refreshErr := t.set.refreshConnection(ctx); refreshErr == nil { + res, err = session.CallTool(ctx, &mcp.CallToolParams{ + Name: t.name, + Arguments: args, + }) + } + } + if err != nil { return nil, fmt.Errorf("failed to call MCP tool %q with err: %w", t.name, err) }