diff --git a/src/agent/__tests__/agent.test.ts b/src/agent/__tests__/agent.test.ts index 1b2f0681..e081aba2 100644 --- a/src/agent/__tests__/agent.test.ts +++ b/src/agent/__tests__/agent.test.ts @@ -3,6 +3,7 @@ import { Agent, type ToolList } from '../agent.js' import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' import { collectGenerator } from '../../__fixtures__/model-test-helpers.js' import { createMockTool, createRandomTool } from '../../__fixtures__/tool-helpers.js' +import { MockHookProvider } from '../../__fixtures__/mock-hook-provider.js' import { ConcurrentInvocationError } from '../../errors.js' import { MaxTokensError, @@ -17,6 +18,8 @@ import { ImageBlock, VideoBlock, DocumentBlock, + type ToolContext, + type Tool, } from '../../index.js' import { AgentPrinter } from '../printer.js' import { BeforeInvocationEvent, BeforeToolsEvent } from '../../hooks/events.js' @@ -878,4 +881,116 @@ describe('Agent', () => { }) }) }) + + describe('invocationState', () => { + it('passes invocationState to BeforeInvocationEvent hook', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + + const hookProvider = new MockHookProvider({ includeModelEvents: false }) + + const agent = new Agent({ + model, + printer: false, + hooks: [hookProvider], + }) + + const testState = { userId: '123', sessionId: 'abc' } + await agent.invoke('Hello', { invocationState: testState }) + + const beforeEvent = hookProvider.invocations.find((e) => e instanceof BeforeInvocationEvent) + expect(beforeEvent).toBeDefined() + expect((beforeEvent as BeforeInvocationEvent)?.invocationState).toEqual(testState) + }) + + it('passes invocationState to tool context', async () => { + let capturedContext: ToolContext | undefined + + const mockTool: Tool = { + name: 'testTool', + description: 'A test tool', + toolSpec: { + name: 'testTool', + description: 'A test tool', + inputSchema: { type: 'object', properties: {} }, + }, + stream: async function* (context: ToolContext) { + capturedContext = context + yield { type: 'toolStreamEvent' as const, data: 'processing' } + return new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'success', + content: [new TextBlock('done')], + }) + }, + } + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done!' }) + + const agent = new Agent({ + model, + tools: [mockTool], + printer: false, + }) + + const testState = { customData: 'value' } + await agent.invoke('Use tool', { invocationState: testState }) + + expect(capturedContext).toBeDefined() + expect(capturedContext?.invocationState).toEqual(testState) + }) + + it('handles undefined invocationState', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ + model, + printer: false, + }) + + const result = await agent.invoke('Hello') + expect(result.stopReason).toBe('endTurn') + }) + + it('invocationState is not persisted between invocations', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'First response' }) + .addTurn({ type: 'textBlock', text: 'Second response' }) + + const hookProvider = new MockHookProvider({ includeModelEvents: false }) + + const agent = new Agent({ + model, + printer: false, + hooks: [hookProvider], + }) + + await agent.invoke('First', { invocationState: { request: 1 } }) + await agent.invoke('Second') // No invocationState + + const events = hookProvider.invocations.filter((e) => e instanceof BeforeInvocationEvent) as BeforeInvocationEvent[] + expect(events).toHaveLength(2) + expect(events[0]?.invocationState).toEqual({ request: 1 }) + expect(events[1]?.invocationState).toBeUndefined() + }) + + it('stream method accepts invocationState option', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + + const hookProvider = new MockHookProvider({ includeModelEvents: false }) + + const agent = new Agent({ + model, + printer: false, + hooks: [hookProvider], + }) + + const testState = { streamTest: true } + const { result } = await collectGenerator(agent.stream('Hello', { invocationState: testState })) + + const beforeEvent = hookProvider.invocations.find((e) => e instanceof BeforeInvocationEvent) as BeforeInvocationEvent | undefined + expect(beforeEvent?.invocationState).toEqual(testState) + expect(result.stopReason).toBe('endTurn') + }) + }) }) diff --git a/src/agent/agent.ts b/src/agent/agent.ts index 4f00483b..332604be 100644 --- a/src/agent/agent.ts +++ b/src/agent/agent.ts @@ -115,6 +115,17 @@ export type AgentConfig = { */ export type InvokeArgs = string | ContentBlock[] | ContentBlockData[] | Message[] | MessageData[] +/** + * Options for invoking an agent. + */ +export type InvokeOptions = { + /** + * Per-invocation state passed to hooks and tools. + * Not persisted in agent state between invocations. + */ + invocationState?: Record +} + /** * Orchestrates the interaction between a model, a set of tools, and MCP clients. * The Agent is responsible for managing the lifecycle of tools and clients @@ -250,17 +261,20 @@ export class Agent implements AgentData { * streaming events. * * @param args - Arguments for invoking the agent + * @param options - Optional invocation options including invocationState * @returns Promise that resolves to the final AgentResult * * @example * ```typescript * const agent = new Agent({ model, tools }) - * const result = await agent.invoke('What is 2 + 2?') + * const result = await agent.invoke('What is 2 + 2?', { + * invocationState: { userId: '123', sessionId: 'abc' } + * }) * console.log(result.lastMessage) // Agent's response * ``` */ - public async invoke(args: InvokeArgs): Promise { - const gen = this.stream(args) + public async invoke(args: InvokeArgs, options?: InvokeOptions): Promise { + const gen = this.stream(args, options) let result = await gen.next() while (!result.done) { result = await gen.next() @@ -285,25 +299,28 @@ export class Agent implements AgentData { * with valid toolResponses * * @param args - Arguments for invoking the agent + * @param options - Optional invocation options including invocationState * @returns Async generator that yields AgentStreamEvent objects and returns AgentResult * * @example * ```typescript * const agent = new Agent({ model, tools }) * - * for await (const event of agent.stream('Hello')) { + * for await (const event of agent.stream('Hello', { + * invocationState: { userId: '123' } + * })) { * console.log('Event:', event.type) * } * // Messages array is mutated in place and contains the full conversation * ``` */ - public async *stream(args: InvokeArgs): AsyncGenerator { + public async *stream(args: InvokeArgs, options?: InvokeOptions): AsyncGenerator { using _lock = this.acquireLock() await this.initialize() // Delegate to _stream and process events through printer and hooks - const streamGenerator = this._stream(args) + const streamGenerator = this._stream(args, options?.invocationState) let result = await streamGenerator.next() while (!result.done) { @@ -330,18 +347,26 @@ export class Agent implements AgentData { * Separated to centralize printer event processing in the public stream method. * * @param args - Arguments for invoking the agent + * @param invocationState - Per-invocation state passed to hooks and tools * @returns Async generator that yields AgentStreamEvent objects and returns AgentResult */ - private async *_stream(args: InvokeArgs): AsyncGenerator { + private async *_stream( + args: InvokeArgs, + invocationState?: Record + ): AsyncGenerator { let currentArgs: InvokeArgs | undefined = args // Emit event before the loop starts - yield new BeforeInvocationEvent({ agent: this }) + const eventData: { agent: AgentData; invocationState?: Record } = { agent: this } + if (invocationState !== undefined) { + eventData.invocationState = invocationState + } + yield new BeforeInvocationEvent(eventData) try { // Main agent loop - continues until model stops without requesting tools while (true) { - const modelResult = yield* this.invokeModel(currentArgs) + const modelResult = yield* this.invokeModel(currentArgs, invocationState) currentArgs = undefined // Only pass args on first invocation if (modelResult.stopReason !== 'toolUse') { // Loop terminates - no tool use requested @@ -354,7 +379,7 @@ export class Agent implements AgentData { } // Execute tools sequentially - const toolResultMessage = yield* this.executeTools(modelResult.message, this._toolRegistry) + const toolResultMessage = yield* this.executeTools(modelResult.message, this._toolRegistry, invocationState) // Add assistant message with tool uses right before adding tool results // This ensures we don't have dangling tool use messages if tool execution fails @@ -427,10 +452,12 @@ export class Agent implements AgentData { * Invokes the model provider and streams all events. * * @param args - Optional arguments for invoking the model + * @param invocationState - Per-invocation state passed to hooks * @returns Object containing the assistant message and stop reason */ private async *invokeModel( - args?: InvokeArgs + args?: InvokeArgs, + invocationState?: Record ): AsyncGenerator { // Normalize input and append messages to conversation const messagesToAppend = this._normalizeInput(args) @@ -444,26 +471,44 @@ export class Agent implements AgentData { streamOptions.systemPrompt = this.systemPrompt } - yield new BeforeModelCallEvent({ agent: this }) + const modelEventData: { agent: AgentData; invocationState?: Record } = { agent: this } + if (invocationState !== undefined) { + modelEventData.invocationState = invocationState + } + yield new BeforeModelCallEvent(modelEventData) try { const { message, stopReason } = yield* this._streamFromModel(this.messages, streamOptions) - yield new AfterModelCallEvent({ agent: this, stopData: { message, stopReason } }) + const afterModelData: { agent: AgentData; stopData: { message: Message; stopReason: string }; invocationState?: Record } = { + agent: this as AgentData, + stopData: { message, stopReason }, + } + if (invocationState !== undefined) { + afterModelData.invocationState = invocationState + } + yield new AfterModelCallEvent(afterModelData) return { message, stopReason } } catch (error) { const modelError = normalizeError(error) // Create error event - const errorEvent = new AfterModelCallEvent({ agent: this, error: modelError }) + const errorEventData: { agent: AgentData; error: Error; invocationState?: Record } = { + agent: this as AgentData, + error: modelError, + } + if (invocationState !== undefined) { + errorEventData.invocationState = invocationState + } + const errorEvent = new AfterModelCallEvent(errorEventData) // Yield error event - stream will invoke hooks yield errorEvent // After yielding, hooks have been invoked and may have set retryModelCall if (errorEvent.retryModelCall) { - return yield* this.invokeModel(args) + return yield* this.invokeModel(args, invocationState) } // Re-throw error @@ -505,11 +550,13 @@ export class Agent implements AgentData { * * @param assistantMessage - The assistant message containing tool use blocks * @param toolRegistry - Registry containing available tools + * @param invocationState - Per-invocation state passed to tool contexts * @returns User message containing tool results */ private async *executeTools( assistantMessage: Message, - toolRegistry: ToolRegistry + toolRegistry: ToolRegistry, + invocationState?: Record ): AsyncGenerator { yield new BeforeToolsEvent({ agent: this, message: assistantMessage }) @@ -526,7 +573,7 @@ export class Agent implements AgentData { const toolResultBlocks: ToolResultBlock[] = [] for (const toolUseBlock of toolUseBlocks) { - const toolResultBlock = yield* this.executeTool(toolUseBlock, toolRegistry) + const toolResultBlock = yield* this.executeTool(toolUseBlock, toolRegistry, invocationState) toolResultBlocks.push(toolResultBlock) // Yield the tool result block as it's created @@ -552,11 +599,13 @@ export class Agent implements AgentData { * * @param toolUseBlock - Tool use block to execute * @param toolRegistry - Registry containing available tools + * @param invocationState - Per-invocation state passed to tool context * @returns Tool result block */ private async *executeTool( toolUseBlock: ToolUseBlock, - toolRegistry: ToolRegistry + toolRegistry: ToolRegistry, + invocationState?: Record ): AsyncGenerator { const tool = toolRegistry.find((t) => t.name === toolUseBlock.name) @@ -567,7 +616,15 @@ export class Agent implements AgentData { input: toolUseBlock.input, } - yield new BeforeToolCallEvent({ agent: this, toolUse, tool }) + const beforeToolData: { agent: AgentData; toolUse: { name: string; toolUseId: string; input: JSONValue }; tool: Tool | undefined; invocationState?: Record } = { + agent: this as AgentData, + toolUse, + tool, + } + if (invocationState !== undefined) { + beforeToolData.invocationState = invocationState + } + yield new BeforeToolCallEvent(beforeToolData) if (!tool) { // Tool not found - return error result instead of throwing @@ -577,7 +634,16 @@ export class Agent implements AgentData { content: [new TextBlock(`Tool '${toolUseBlock.name}' not found in registry`)], }) - yield new AfterToolCallEvent({ agent: this, toolUse, tool, result: errorResult }) + const afterToolData: { agent: AgentData; toolUse: { name: string; toolUseId: string; input: JSONValue }; tool: undefined; result: ToolResultBlock; invocationState?: Record } = { + agent: this as AgentData, + toolUse, + tool, + result: errorResult, + } + if (invocationState !== undefined) { + afterToolData.invocationState = invocationState + } + yield new AfterToolCallEvent(afterToolData) return errorResult } @@ -591,6 +657,9 @@ export class Agent implements AgentData { }, agent: this, } + if (invocationState !== undefined) { + toolContext.invocationState = invocationState + } try { const toolGenerator = tool.stream(toolContext) @@ -606,12 +675,30 @@ export class Agent implements AgentData { content: [new TextBlock(`Tool '${toolUseBlock.name}' did not return a result`)], }) - yield new AfterToolCallEvent({ agent: this, toolUse, tool, result: errorResult }) + const afterToolData: { agent: AgentData; toolUse: { name: string; toolUseId: string; input: JSONValue }; tool: Tool; result: ToolResultBlock; invocationState?: Record } = { + agent: this as AgentData, + toolUse, + tool, + result: errorResult, + } + if (invocationState !== undefined) { + afterToolData.invocationState = invocationState + } + yield new AfterToolCallEvent(afterToolData) return errorResult } - yield new AfterToolCallEvent({ agent: this, toolUse, tool, result: toolResult }) + const afterSuccessData: { agent: AgentData; toolUse: { name: string; toolUseId: string; input: JSONValue }; tool: Tool; result: ToolResultBlock; invocationState?: Record } = { + agent: this as AgentData, + toolUse, + tool, + result: toolResult, + } + if (invocationState !== undefined) { + afterSuccessData.invocationState = invocationState + } + yield new AfterToolCallEvent(afterSuccessData) // Tool already returns ToolResultBlock directly return toolResult @@ -625,7 +712,17 @@ export class Agent implements AgentData { error: toolError, }) - yield new AfterToolCallEvent({ agent: this, toolUse, tool, result: errorResult, error: toolError }) + const afterErrorData: { agent: AgentData; toolUse: { name: string; toolUseId: string; input: JSONValue }; tool: Tool; result: ToolResultBlock; error: Error; invocationState?: Record } = { + agent: this as AgentData, + toolUse, + tool, + result: errorResult, + error: toolError, + } + if (invocationState !== undefined) { + afterErrorData.invocationState = invocationState + } + yield new AfterToolCallEvent(afterErrorData) return errorResult } diff --git a/src/hooks/__tests__/events.test.ts b/src/hooks/__tests__/events.test.ts index c516e7f0..75f8d31d 100644 --- a/src/hooks/__tests__/events.test.ts +++ b/src/hooks/__tests__/events.test.ts @@ -28,6 +28,18 @@ describe('BeforeInvocationEvent', () => { event.agent = new Agent() }) + it('creates instance with invocationState', () => { + const agent = new Agent() + const invocationState = { userId: '123', sessionId: 'abc' } + const event = new BeforeInvocationEvent({ agent, invocationState }) + + expect(event).toEqual({ + type: 'beforeInvocationEvent', + agent: agent, + invocationState: { userId: '123', sessionId: 'abc' }, + }) + }) + it('returns false for _shouldReverseCallbacks', () => { const agent = new Agent() const event = new BeforeInvocationEvent({ agent }) diff --git a/src/hooks/events.ts b/src/hooks/events.ts index 2fd7ab83..e4f262c7 100644 --- a/src/hooks/events.ts +++ b/src/hooks/events.ts @@ -26,10 +26,14 @@ export abstract class HookEvent { export class BeforeInvocationEvent extends HookEvent { readonly type = 'beforeInvocationEvent' as const readonly agent: AgentData + readonly invocationState?: Record - constructor(data: { agent: AgentData }) { + constructor(data: { agent: AgentData; invocationState?: Record }) { super() this.agent = data.agent + if (data.invocationState !== undefined) { + this.invocationState = data.invocationState + } } } @@ -82,16 +86,21 @@ export class BeforeToolCallEvent extends HookEvent { input: JSONValue } readonly tool: Tool | undefined + readonly invocationState?: Record constructor(data: { agent: AgentData toolUse: { name: string; toolUseId: string; input: JSONValue } tool: Tool | undefined + invocationState?: Record }) { super() this.agent = data.agent this.toolUse = data.toolUse this.tool = data.tool + if (data.invocationState !== undefined) { + this.invocationState = data.invocationState + } } } @@ -111,6 +120,7 @@ export class AfterToolCallEvent extends HookEvent { readonly tool: Tool | undefined readonly result: ToolResultBlock readonly error?: Error + readonly invocationState?: Record constructor(data: { agent: AgentData @@ -118,6 +128,7 @@ export class AfterToolCallEvent extends HookEvent { tool: Tool | undefined result: ToolResultBlock error?: Error + invocationState?: Record }) { super() this.agent = data.agent @@ -127,6 +138,9 @@ export class AfterToolCallEvent extends HookEvent { if (data.error !== undefined) { this.error = data.error } + if (data.invocationState !== undefined) { + this.invocationState = data.invocationState + } } override _shouldReverseCallbacks(): boolean { @@ -141,10 +155,14 @@ export class AfterToolCallEvent extends HookEvent { export class BeforeModelCallEvent extends HookEvent { readonly type = 'beforeModelCallEvent' as const readonly agent: AgentData + readonly invocationState?: Record - constructor(data: { agent: AgentData }) { + constructor(data: { agent: AgentData; invocationState?: Record }) { super() this.agent = data.agent + if (data.invocationState !== undefined) { + this.invocationState = data.invocationState + } } } @@ -174,6 +192,7 @@ export class AfterModelCallEvent extends HookEvent { readonly agent: AgentData readonly stopData?: ModelStopData readonly error?: Error + readonly invocationState?: Record /** * Optional flag that can be set by hook callbacks to request a retry of the model call. @@ -182,7 +201,7 @@ export class AfterModelCallEvent extends HookEvent { */ retryModelCall?: boolean - constructor(data: { agent: AgentData; stopData?: ModelStopData; error?: Error }) { + constructor(data: { agent: AgentData; stopData?: ModelStopData; error?: Error; invocationState?: Record }) { super() this.agent = data.agent if (data.stopData !== undefined) { @@ -191,6 +210,9 @@ export class AfterModelCallEvent extends HookEvent { if (data.error !== undefined) { this.error = data.error } + if (data.invocationState !== undefined) { + this.invocationState = data.invocationState + } } override _shouldReverseCallbacks(): boolean { diff --git a/src/index.ts b/src/index.ts index bf01dcf5..ce698a8c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,7 +14,7 @@ export type { AgentState } from './agent/state.js' // Agent types export type { AgentData } from './types/agent.js' export { AgentResult } from './types/agent.js' -export type { AgentConfig, ToolList } from './agent/agent.js' +export type { AgentConfig, ToolList, InvokeOptions } from './agent/agent.js' // Error types export { diff --git a/src/tools/tool.ts b/src/tools/tool.ts index cb76edb3..74ae9464 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -21,6 +21,12 @@ export interface ToolContext { * Provides access to agent state and other agent-level information. */ agent: AgentData + + /** + * Per-invocation state passed from the invoke call. + * Not persisted in agent state between invocations. + */ + invocationState?: Record } /** diff --git a/test/integ/agent.test.ts b/test/integ/agent.test.ts index d5af2c25..67ad9643 100644 --- a/test/integ/agent.test.ts +++ b/test/integ/agent.test.ts @@ -1,8 +1,19 @@ import { describe, expect, it } from 'vitest' -import { Agent, DocumentBlock, ImageBlock, Message, TextBlock, tool } from '@strands-agents/sdk' +import { + Agent, + DocumentBlock, + ImageBlock, + Message, + TextBlock, + tool, + type ToolContext, + type HookProvider, +} from '@strands-agents/sdk' import { notebook } from '@strands-agents/sdk/vended_tools/notebook' import { httpRequest } from '@strands-agents/sdk/vended_tools/http_request' import { z } from 'zod' +import type { HookRegistry } from '../../src/hooks/index.js' +import { BeforeInvocationEvent, BeforeToolCallEvent } from '../../src/hooks/events.js' import { collectGenerator } from '$/sdk/__fixtures__/model-test-helpers.js' import { loadFixture } from './__fixtures__/test-helpers.js' @@ -242,5 +253,74 @@ describe.each(allProviders)('Agent with $name', ({ name, skip, createModel }) => ) ).toBe(true) }) + + describe('invocationState', () => { + it('flows invocationState from invoke through hooks to tools', async () => { + const events: string[] = [] + const invocationState = { requestId: 'req-123', user: 'testUser' } + + class TestHookProvider implements HookProvider { + registerCallbacks(registry: HookRegistry): void { + registry.addCallback(BeforeInvocationEvent, async (event: BeforeInvocationEvent) => { + events.push(`beforeInvoke:${JSON.stringify(event.invocationState)}`) + }) + registry.addCallback(BeforeToolCallEvent, async () => { + events.push('beforeToolCall') + }) + } + } + + let toolReceivedState: Record | undefined + + const echoTool = tool({ + name: 'echo', + description: 'Echoes input', + inputSchema: z.object({ + message: z.string(), + }), + callback: async (input, context?: ToolContext) => { + toolReceivedState = context?.invocationState + return `Echo: ${input.message}` + }, + }) + + const agent = new Agent({ + model: createModel(), + tools: [echoTool], + hooks: [new TestHookProvider()], + printer: false, + }) + + await agent.invoke('Use echo tool with message "test"', { invocationState }) + + expect(events).toContain(`beforeInvoke:${JSON.stringify(invocationState)}`) + expect(toolReceivedState).toEqual(invocationState) + }) + + it('invocationState is not persisted between invocations', async () => { + const capturedStates: Array | undefined> = [] + + class TestHookProvider implements HookProvider { + registerCallbacks(registry: HookRegistry): void { + registry.addCallback(BeforeInvocationEvent, async (event: BeforeInvocationEvent) => { + capturedStates.push(event.invocationState) + }) + } + } + + const agent = new Agent({ + model: createModel(), + hooks: [new TestHookProvider()], + printer: false, + }) + + await agent.invoke('First request', { invocationState: { call: 1 } }) + await agent.invoke('Second request') // No invocationState + + expect(capturedStates).toHaveLength(2) + expect(capturedStates[0]).toEqual({ call: 1 }) + expect(capturedStates[1]).toBeUndefined() + }) + }) }) })