diff --git a/.changeset/tasty-bobcats-check.md b/.changeset/tasty-bobcats-check.md new file mode 100644 index 0000000..7953393 --- /dev/null +++ b/.changeset/tasty-bobcats-check.md @@ -0,0 +1,5 @@ +--- +'ai': patch +--- + +StreamData: add `annotations` and `appendMessageAnnotation` support diff --git a/packages/core/shared/parse-complex-response.test.ts b/packages/core/shared/parse-complex-response.test.ts index 9a7d91b..3f8181c 100644 --- a/packages/core/shared/parse-complex-response.test.ts +++ b/packages/core/shared/parse-complex-response.test.ts @@ -231,8 +231,13 @@ describe('parseComplexResponse function', () => { // Execute the parser function const result = await parseComplexResponse({ reader: createTestReader([ + + '0:"Sample text message."\n', + '8:[{"key":"value"}, 2]\n', + '8:[{"key":"value"}, 2]\n', '0:"Sample text message."\n', + ]), abortControllerRef: { current: new AbortController() }, update: mockUpdate, @@ -243,8 +248,14 @@ describe('parseComplexResponse function', () => { // check the mockUpdate call: expect(mockUpdate).toHaveBeenCalledTimes(2); + + expect(mockUpdate.mock.calls[0][0]).toEqual([ + assistantTextMessage('Sample text message.'), + ]); + expect(mockUpdate.mock.calls[0][0]).toEqual([]); + expect(mockUpdate.mock.calls[1][0]).toEqual([ { ...assistantTextMessage('Sample text message.'), @@ -264,6 +275,8 @@ describe('parseComplexResponse function', () => { }); }); + + it('should parse a combination of a function_call and message annotations', async () => { const mockUpdate = vi.fn(); @@ -350,4 +363,5 @@ describe('parseComplexResponse function', () => { data: [], }); }); + }); diff --git a/packages/core/shared/parse-complex-response.ts b/packages/core/shared/parse-complex-response.ts index db3de19..2fd00c2 100644 --- a/packages/core/shared/parse-complex-response.ts +++ b/packages/core/shared/parse-complex-response.ts @@ -15,12 +15,29 @@ type PrefixMap = { data: JSONValue[]; }; + +function initializeMessage({ + generateId, + ...rest +}: { + generateId: () => string; + content: string; + createdAt: Date; + annotations?: JSONValue[]; +}): Message { + return { + id: generateId(), + role: 'assistant', + ...rest + }; + function assignAnnotationsToMessage( message: T, annotations: JSONValue[] | undefined, ): T { if (!message || !annotations || !annotations.length) return message; return { ...message, annotations: [...annotations] } as T; + } export async function parseComplexResponse({ @@ -63,7 +80,24 @@ export async function parseComplexResponse({ id: generateId(), role: 'assistant', content: value, - createdAt, + createdAt + }; + } + } + + if (type == 'message_annotations') { + if (prefixMap['text']) { + prefixMap['text'] = { + ...prefixMap['text'], + annotations: [...prefixMap['text'].annotations || [], ...value], + }; + } else { + prefixMap['text'] = { + id: generateId(), + role: 'assistant', + content: '', + annotations: [...value], + createdAt }; } } diff --git a/packages/core/shared/types.ts b/packages/core/shared/types.ts index eb7efd0..8a0fc96 100644 --- a/packages/core/shared/types.ts +++ b/packages/core/shared/types.ts @@ -112,7 +112,10 @@ export interface Message { tool_calls?: string | ToolCall[]; /** - * Additional message-specific information added on the server via StreamData + + * Additional message-specific information added on the server via StreamData + * Additional message-specific information added on the server via StreamData + */ annotations?: JSONValue[] | undefined; } @@ -315,7 +318,7 @@ export type UseCompletionOptions = { body?: object } -======= + body?: object; }; @@ -328,7 +331,7 @@ export type JSONValue = | { [x: string]: JSONValue } | Array -======= + | Array; export type AssistantMessage = { diff --git a/packages/core/streams/ai-stream.ts b/packages/core/streams/ai-stream.ts index 36e3615..29d5ff0 100644 --- a/packages/core/streams/ai-stream.ts +++ b/packages/core/streams/ai-stream.ts @@ -42,7 +42,7 @@ export interface AIStreamCallbacks { export interface AIStreamCallbacksAndOptions extends AIStreamCallbacks { streamData?: Data -======= + export interface AIStreamCallbacksAndOptions { /** `onStart`: Called once when the stream is initialized. */ onStart?: () => Promise | void; @@ -81,10 +81,14 @@ export interface AIStreamParserOptions { * @returns {string | void} The parsed data or void. */ export interface AIStreamParser { + + (data: string, options: AIStreamParserOptions): string | void; + (data: string, options: AIStreamParserOptions): | string | void | { isText: false; content: string }; + } /** @@ -124,8 +128,13 @@ export function createEventStreamTransformer( const parsedMessage = customParser ? customParser(event.data, { + + event: event.event + }) + event: event.event, }) + : event.data; if (parsedMessage) controller.enqueue(parsedMessage); diff --git a/packages/core/streams/stream-data.ts b/packages/core/streams/stream-data.ts index 785db17..62145bc 100644 --- a/packages/core/streams/stream-data.ts +++ b/packages/core/streams/stream-data.ts @@ -42,11 +42,18 @@ export class experimental_StreamData { } if (self.messageAnnotations.length) { + + const encodedmessageAnnotations = self.encoder.encode( + formatStreamPart('message_annotations', self.messageAnnotations), + ); + controller.enqueue(encodedmessageAnnotations); + const encodedMessageAnnotations = self.encoder.encode( formatStreamPart('message_annotations', self.messageAnnotations), ); self.messageAnnotations = []; controller.enqueue(encodedMessageAnnotations); + } controller.enqueue(chunk);