Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 120 additions & 1 deletion src/__tests__/mcp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import { JsonBlock, type TextBlock, type ToolResultBlock } from '../types/messag
import type { AgentData } from '../types/agent.js'
import type { ToolContext } from '../tools/tool.js'

vi.mock('@modelcontextprotocol/sdk/types.js', () => ({
ElicitRequestSchema: { method: 'elicitation/create' },
}))

/**
* Helper to create a mock async generator that yields a result message.
* This simulates the behavior of callToolStream returning a stream that ends with a result.
Expand All @@ -23,6 +27,7 @@ vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
connect: vi.fn(),
close: vi.fn(),
listTools: vi.fn(),
setRequestHandler: vi.fn(),
experimental: {
tasks: {
callToolStream: vi.fn(),
Expand Down Expand Up @@ -84,7 +89,7 @@ describe('MCP Integration', () => {
})

it('initializes SDK client with correct configuration', () => {
expect(Client).toHaveBeenCalledWith({ name: 'TestApp', version: '0.0.1' })
expect(Client).toHaveBeenCalledWith({ name: 'TestApp', version: '0.0.1' }, undefined)
})

it('manages connection state lazily', async () => {
Expand Down Expand Up @@ -139,6 +144,120 @@ describe('MCP Integration', () => {
expect(sdkClientMock.close).toHaveBeenCalled()
expect(mockTransport.close).toHaveBeenCalled()
})

describe('elicitation callback', () => {
it('registers callback when provided', async () => {
const callback = vi.fn()
const clientWithCallback = new McpClient({
applicationName: 'TestApp',
transport: mockTransport,
elicitationCallback: callback,
})
const sdkClient = vi.mocked(Client).mock.results[1]!.value

await clientWithCallback.connect()

expect(sdkClient.setRequestHandler).toHaveBeenCalled()
})

it('does not register callback when not provided', async () => {
await client.connect()

expect(sdkClientMock.setRequestHandler).not.toHaveBeenCalled()
})

it('invokes callback and returns all action types correctly', async () => {
const callback = vi.fn()
const clientWithCallback = new McpClient({
applicationName: 'TestApp',
transport: mockTransport,
elicitationCallback: callback,
})
const sdkClient = vi.mocked(Client).mock.results[vi.mocked(Client).mock.results.length - 1]!.value

await clientWithCallback.connect()

const handler = sdkClient.setRequestHandler.mock.calls[0]![1]
const mockExtra = { sessionId: 'test-session' }

// Test accept action (form mode)
callback.mockResolvedValueOnce({ action: 'accept', content: { response: 'yes' } })
const acceptResult = await handler(
{
params: {
message: 'Do you want to continue?',
requestedSchema: { type: 'object' },
},
},
mockExtra
)

expect(callback).toHaveBeenCalledWith(mockExtra, {
message: 'Do you want to continue?',
requestedSchema: { type: 'object' },
})
expect(acceptResult).toEqual({
action: 'accept',
content: { response: 'yes' },
})

// Test decline action
callback.mockResolvedValueOnce({ action: 'decline' })
const declineResult = await handler({ params: { message: 'Proceed?' } }, mockExtra)
expect(declineResult).toEqual({ action: 'decline', content: undefined })

// Test cancel action
callback.mockResolvedValueOnce({ action: 'cancel' })
const cancelResult = await handler({ params: { message: 'Cancel operation?' } }, mockExtra)
expect(cancelResult).toEqual({ action: 'cancel', content: undefined })

// Test URL mode
callback.mockResolvedValueOnce({ action: 'accept' })
const urlResult = await handler(
{
params: {
mode: 'url',
message: 'Please authorize via OAuth',
url: 'https://example.com/oauth',
elicitationId: 'elicit-123',
},
},
mockExtra
)

expect(callback).toHaveBeenCalledWith(mockExtra, {
mode: 'url',
message: 'Please authorize via OAuth',
url: 'https://example.com/oauth',
elicitationId: 'elicit-123',
})
expect(urlResult).toEqual({
action: 'accept',
content: undefined,
})
})

it('handles callback errors gracefully', async () => {
const callback = vi.fn().mockRejectedValue(new Error('User cancelled'))

const clientWithCallback = new McpClient({
applicationName: 'TestApp',
transport: mockTransport,
elicitationCallback: callback,
})
const sdkClient = vi.mocked(Client).mock.results[vi.mocked(Client).mock.results.length - 1]!.value

await clientWithCallback.connect()

const handler = sdkClient.setRequestHandler.mock.calls[0]![1]
const mockRequest = {
params: { message: 'Continue?' },
}
const mockExtra = { sessionId: 'test-session' }

await expect(handler(mockRequest, mockExtra)).rejects.toThrow('User cancelled')
})
})
})

describe('McpTool', () => {
Expand Down
3 changes: 3 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,6 @@ export type { Logger } from './logging/types.js'

// MCP Client types and implementations
export { type McpClientConfig, McpClient } from './mcp.js'

// Elicitation types
export type { ElicitationCallback } from './types/elicitation.js'
44 changes: 39 additions & 5 deletions src/mcp.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
import { ElicitRequestSchema } from '@modelcontextprotocol/sdk/types.js'
import { takeResult } from '@modelcontextprotocol/sdk/shared/responseMessage.js'
import type { JSONSchema, JSONValue } from './types/json.js'
import type { ElicitationCallback } from './types/elicitation.js'
import { McpTool } from './tools/mcp-tool.js'

/** Temporary placeholder for RuntimeConfig */
Expand All @@ -11,7 +13,10 @@ export interface RuntimeConfig {
}

/** Arguments for configuring an MCP Client. */
export type McpClientConfig = RuntimeConfig & { transport: Transport }
export type McpClientConfig = RuntimeConfig & {
transport: Transport
elicitationCallback?: ElicitationCallback
}

/** MCP Client for interacting with Model Context Protocol servers. */
export class McpClient {
Expand All @@ -20,16 +25,33 @@ export class McpClient {
private _transport: Transport
private _connected: boolean
private _client: Client
private _elicitationCallback?: ElicitationCallback

constructor(args: McpClientConfig) {
this._clientName = args.applicationName || 'strands-agents-ts-sdk'
this._clientVersion = args.applicationVersion || '0.0.1'
this._transport = args.transport
this._connected = false
this._client = new Client({
name: this._clientName,
version: this._clientVersion,
})

if (args.elicitationCallback !== undefined) {
this._elicitationCallback = args.elicitationCallback
}

const clientOptions = this._elicitationCallback
? {
capabilities: {
elicitation: { form: {}, url: {} },
},
}
: undefined

this._client = new Client(
{
name: this._clientName,
version: this._clientVersion,
},
clientOptions
)
}

get client(): Client {
Expand All @@ -55,6 +77,18 @@ export class McpClient {

await this._client.connect(this._transport)

if (this._elicitationCallback) {
const callback = this._elicitationCallback
this._client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
const result = await callback(extra, request.params)

return {
action: result.action,
content: result.content,
}
})
}

this._connected = true
}

Expand Down
33 changes: 33 additions & 0 deletions src/types/elicitation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import type {
ElicitResult,
ElicitRequestParams,
ClientRequest,
ClientNotification,
} from '@modelcontextprotocol/sdk/types.js'
import type { RequestHandlerExtra } from '@modelcontextprotocol/sdk/shared/protocol.js'

/**
* Context provided to the elicitation callback from the MCP SDK.
*/
type ElicitationContext = RequestHandlerExtra<ClientRequest, ClientNotification>

/**
* Callback function invoked when an MCP server requests additional input during tool execution.
*
* @param context - Context information about the elicitation request
* @param params - Parameters including the message and optional schema
* @returns A promise that resolves with the user's response
*
* @example
* ```typescript
* const elicitationCallback: ElicitationCallback = async (_context, params) => {
* console.log(`Server is asking: ${params.message}`)
* const userInput = await getUserInput()
* return {
* action: 'accept',
* content: { response: userInput }
* }
* }
* ```
*/
export type ElicitationCallback = (context: ElicitationContext, params: ElicitRequestParams) => Promise<ElicitResult>
59 changes: 59 additions & 0 deletions test/integ/__fixtures__/test-mcp-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ function createTestServer(): McpServer {
{
capabilities: {
tools: {},
elicitation: { form: {} },
},
}
)
Expand Down Expand Up @@ -125,6 +126,63 @@ function createTestServer(): McpServer {
}
)

// Register elicitation tool
server.registerTool(
'confirm_action',
{
title: 'Confirm Action Tool',
description: 'Requests user confirmation before performing an action',
inputSchema: {
action: z.string(),
},
outputSchema: {
confirmed: z.boolean(),
action: z.string(),
},
},
async ({ action }) => {
// Request user confirmation via elicitation
const result = await server.server.elicitInput({
message: `Do you want to proceed with: ${action}?`,
requestedSchema: {
type: 'object',
properties: {
confirmed: {
type: 'boolean',
title: 'Confirm action',
description: 'Confirm whether to proceed',
},
},
required: ['confirmed'],
},
})

if (result.action === 'accept' && result.content?.confirmed) {
const output = { confirmed: true, action }
return {
content: [
{
type: 'text',
text: `Action confirmed: ${action}`,
},
],
structuredContent: output,
}
}

const output = { confirmed: false, action }
return {
content: [
{
type: 'text',
text: `Action declined: ${action}`,
},
],
structuredContent: output,
}
}
)

return server
}

Expand Down Expand Up @@ -164,6 +222,7 @@ export async function startHTTPServer(): Promise<HttpServerInfo> {
// Create a new transport for each request (stateless mode)
const transport = new StreamableHTTPServerTransport({
enableJsonResponse: true,
sessionIdGenerator: undefined,
})

res.on('close', async () => {
Expand Down
2 changes: 1 addition & 1 deletion test/integ/__fixtures__/test-mcp-task-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -324,13 +324,13 @@ export async function startTaskHTTPServer(): Promise<TaskHttpServerInfo> {
const mcpServer = createTaskTestServer(taskStore)
const transport = new StreamableHTTPServerTransport({
enableJsonResponse: true,
sessionIdGenerator: undefined,
})

res.on('close', async () => {
await transport.close()
})

// @ts-expect-error - MCP SDK doesn't support exactOptionalPropertyTypes
await mcpServer.connect(transport)
await transport.handleRequest(req, res, parsedBody)
} catch (error) {
Expand Down
Loading
Loading