Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions src/agent/__tests__/agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { Agent, type ToolList } from '../agent.js'
import { MockMessageModel } from '../../__fixtures__/mock-message-model.js'
import { collectGenerator } from '../../__fixtures__/model-test-helpers.js'
import { createMockTool, createRandomTool } from '../../__fixtures__/tool-helpers.js'
import { MockHookProvider } from '../../__fixtures__/mock-hook-provider.js'
import { ConcurrentInvocationError } from '../../errors.js'
import {
MaxTokensError,
Expand All @@ -17,6 +18,8 @@ import {
ImageBlock,
VideoBlock,
DocumentBlock,
type ToolContext,
type Tool,
} from '../../index.js'
import { AgentPrinter } from '../printer.js'
import { BeforeInvocationEvent, BeforeToolsEvent } from '../../hooks/events.js'
Expand Down Expand Up @@ -878,4 +881,116 @@ describe('Agent', () => {
})
})
})

describe('invocationState', () => {
it('passes invocationState to BeforeInvocationEvent hook', async () => {
const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' })

const hookProvider = new MockHookProvider({ includeModelEvents: false })

const agent = new Agent({
model,
printer: false,
hooks: [hookProvider],
})

const testState = { userId: '123', sessionId: 'abc' }
await agent.invoke('Hello', { invocationState: testState })

const beforeEvent = hookProvider.invocations.find((e) => e instanceof BeforeInvocationEvent)
expect(beforeEvent).toBeDefined()
expect((beforeEvent as BeforeInvocationEvent)?.invocationState).toEqual(testState)
})

it('passes invocationState to tool context', async () => {
let capturedContext: ToolContext | undefined

const mockTool: Tool = {
name: 'testTool',
description: 'A test tool',
toolSpec: {
name: 'testTool',
description: 'A test tool',
inputSchema: { type: 'object', properties: {} },
},
stream: async function* (context: ToolContext) {
capturedContext = context
yield { type: 'toolStreamEvent' as const, data: 'processing' }
return new ToolResultBlock({
toolUseId: context.toolUse.toolUseId,
status: 'success',
content: [new TextBlock('done')],
})
},
}

const model = new MockMessageModel()
.addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} })
.addTurn({ type: 'textBlock', text: 'Done!' })

const agent = new Agent({
model,
tools: [mockTool],
printer: false,
})

const testState = { customData: 'value' }
await agent.invoke('Use tool', { invocationState: testState })

expect(capturedContext).toBeDefined()
expect(capturedContext?.invocationState).toEqual(testState)
})

it('handles undefined invocationState', async () => {
const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' })
const agent = new Agent({
model,
printer: false,
})

const result = await agent.invoke('Hello')
expect(result.stopReason).toBe('endTurn')
})

it('invocationState is not persisted between invocations', async () => {
const model = new MockMessageModel()
.addTurn({ type: 'textBlock', text: 'First response' })
.addTurn({ type: 'textBlock', text: 'Second response' })

const hookProvider = new MockHookProvider({ includeModelEvents: false })

const agent = new Agent({
model,
printer: false,
hooks: [hookProvider],
})

await agent.invoke('First', { invocationState: { request: 1 } })
await agent.invoke('Second') // No invocationState

const events = hookProvider.invocations.filter((e) => e instanceof BeforeInvocationEvent) as BeforeInvocationEvent[]
expect(events).toHaveLength(2)
expect(events[0]?.invocationState).toEqual({ request: 1 })
expect(events[1]?.invocationState).toBeUndefined()
})

it('stream method accepts invocationState option', async () => {
const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' })

const hookProvider = new MockHookProvider({ includeModelEvents: false })

const agent = new Agent({
model,
printer: false,
hooks: [hookProvider],
})

const testState = { streamTest: true }
const { result } = await collectGenerator(agent.stream('Hello', { invocationState: testState }))

const beforeEvent = hookProvider.invocations.find((e) => e instanceof BeforeInvocationEvent) as BeforeInvocationEvent | undefined
expect(beforeEvent?.invocationState).toEqual(testState)
expect(result.stopReason).toBe('endTurn')
})
})
})
Loading
Loading