diff --git a/packages/core/shared/types.ts b/packages/core/shared/types.ts
index be7be26..eb7efd0 100644
--- a/packages/core/shared/types.ts
+++ b/packages/core/shared/types.ts
@@ -312,15 +312,23 @@ export type UseCompletionOptions = {
    * })
    * ```
    */
   body?: object;
 };
 
+export type JSONValue =
+  | null
+  | string
+  | number
+  | boolean
+  | { [x: string]: JSONValue }
+  | Array<JSONValue>;
+
 export type AssistantMessage = {
@@ -346,3 +354,4 @@ export type DataMessage = {
   role: 'data';
   data: JSONValue; // application-specific data
 };
+
diff --git a/packages/core/shared/utils.ts b/packages/core/shared/utils.ts
index 8cca263..0355419 100644
--- a/packages/core/shared/utils.ts
+++ b/packages/core/shared/utils.ts
@@ -1,3 +1,7 @@
+import { JSONValue } from './types';
+
 import { customAlphabet } from 'nanoid/non-secure';
 import {
   StreamPartType,
@@ -5,6 +9,7 @@ import {
   parseStreamPart,
 } from './stream-parts';
+
 // 7-character random string
 export const nanoid = customAlphabet(
   '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz',
   7,
 );
@@ -44,6 +49,47 @@ function createChunkDecoder(complex?: boolean) {
   };
 }
 
+/**
+ * The map of prefixes for data in the stream
+ *
+ * - 0: Text from the LLM response
+ * - 1: (OpenAI) function_call responses
+ * - 2: custom JSON added by the user using `Data`
+ *
+ * Example:
+ * ```
+ * 0:Vercel
+ * 0:'s
+ * 0: AI
+ * 0: SDK
+ * 0: is great
+ * 0:!
+ * 2: { "someJson": "value" }
+ * 1: {"function_call": {"name": "get_current_weather", "arguments": "{\\n\\"location\\": \\"Charlottesville, Virginia\\",\\n\\"format\\": \\"celsius\\"\\n}"}}
+ * ```
+ */
+export const StreamStringPrefixes = {
+  text: 0,
+  function_call: 1,
+  data: 2,
+} as const;
+
+/**
+ * Prepends a string with a prefix from `StreamStringPrefixes`, JSON-ifies it, and appends a new line.
+ */
+export const getStreamString = (
+  type: keyof typeof StreamStringPrefixes,
+  value: JSONValue,
+): StreamString =>
+  `${StreamStringPrefixes[type]}:${
+    typeof value === 'string' ? value : JSON.stringify(value)
+  }\n`;
+
+export type StreamString =
+  `${(typeof StreamStringPrefixes)[keyof typeof StreamStringPrefixes]}:${string}\n`;
+
 export { createChunkDecoder };
 
 export const isStreamStringEqualToType = (
@@ -59,3 +105,4 @@ export type StreamString =
  * A header sent to the client so it knows how to handle parsing the stream (as a deprecated text response or using the new prefixed protocol)
  */
 export const COMPLEX_HEADER = 'X-Experimental-Stream-Data';
+
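Concretely, the protocol these helpers emit is line-oriented: each chunk is `<prefix>:<payload>` followed by a newline. A minimal sketch of what `getStreamString` produces, assuming the package layout in this diff for the import path:

```ts
import { getStreamString } from './packages/core/shared/utils';

// Plain strings are passed through unescaped; everything else is JSON-stringified.
process.stdout.write(getStreamString('text', 'Hello'));
// 0:Hello

process.stdout.write(getStreamString('data', { someJson: 'value' }));
// 2:{"someJson":"value"}

process.stdout.write(
  getStreamString('function_call', {
    function_call: { name: 'get_current_weather', arguments: '{}' },
  }),
);
// 1:{"function_call":{"name":"get_current_weather","arguments":"{}"}}
```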
diff --git a/packages/core/streams/ai-stream.ts b/packages/core/streams/ai-stream.ts
index 754d2c8..36e3615 100644
--- a/packages/core/streams/ai-stream.ts
+++ b/packages/core/streams/ai-stream.ts
@@ -2,10 +2,17 @@ import {
   createParser,
   type EventSourceParser,
   type ParsedEvent,
   type ReconnectInterval,
 } from 'eventsource-parser';
-import { OpenAIStreamCallbacks } from './openai-stream';
+import { Data } from './data-stream';
+import { getStreamString } from '../shared/utils';
 
 export interface FunctionCallPayload {
   name: string;
   arguments: Record<string, unknown>;
 }
@@ -25,6 +32,17 @@
 /**
  * Configuration options and helper callback methods for AIStream stream lifecycle events.
  * @interface
  */
-export interface AIStreamCallbacksAndOptions {
-  /** `onStart`: Called once when the stream is initialized. */
-  onStart?: () => Promise<void> | void;
+export interface AIStreamCallbacks {
+  onStart?: () => Promise<void> | void;
+  onCompletion?: (completion: string) => Promise<void> | void;
+  onToken?: (token: string) => Promise<void> | void;
+}
+
+export interface AIStreamCallbacksAndOptions extends AIStreamCallbacks {
+  streamData?: Data;
+}
@@ -96,12 +115,20 @@ export function createEventStreamTransformer(
       }
 
       if ('data' in event) {
         const parsedMessage = customParser
           ? customParser(event.data, {
               event: event.event,
             })
           : event.data;
-        if (parsedMessage) controller.enqueue(parsedMessage);
+        if (parsedMessage)
+          controller.enqueue(getStreamString('text', parsedMessage));
       }
     },
   );
@@ -135,6 +166,14 @@ export function createEventStreamTransformer(
  * };
- * const transformer = createCallbacksTransformer(callbacks);
+ * const transformer = createCallbacksAndOptionsTransformer(callbacks);
  */
-export function createCallbacksTransformer(
-  cb: AIStreamCallbacksAndOptions | OpenAIStreamCallbacks | undefined,
-): TransformStream<string, Uint8Array> {
+export function createCallbacksAndOptionsTransformer(
+  callbacks: AIStreamCallbacksAndOptions | undefined,
+): TransformStream<string, Uint8Array> {
@@ -142,6 +181,7 @@
   let aggregatedResponse = '';
-  const callbacks = cb || {};
+  const { onStart, onToken, onCompletion } = callbacks || {};
 
   return new TransformStream({
     async start(): Promise<void> {
-      if (callbacks.onStart) await callbacks.onStart();
+      if (onStart) await onStart();
@@ -268,7 +308,11 @@ export function AIStream(
   return responseBodyStream
     .pipeThrough(createEventStreamTransformer(customParser))
-    .pipeThrough(createCallbacksTransformer(callbacks));
+    .pipeThrough(createCallbacksAndOptionsTransformer(callbacks));
 }
 
 // outputs lines like
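To make the new options shape concrete, here is a sketch of how a caller could wire `streamData` and the lifecycle callbacks together. The imports assume the modules added in this diff; the handler body is illustrative and not part of the PR:

```ts
import { Data } from './packages/core/streams/data-stream';
import type { AIStreamCallbacksAndOptions } from './packages/core/streams/ai-stream';

const data = new Data();

const callbacks: AIStreamCallbacksAndOptions = {
  onStart: () => console.log('stream started'),
  onToken: token => console.log('token:', token),
  onCompletion: completion => {
    // Ship a final JSON payload on the same response, then release the stream.
    data.append({ completedAt: Date.now(), length: completion.length });
    data.close();
  },
  streamData: data,
};
```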
diff --git a/packages/core/streams/cohere-stream.ts b/packages/core/streams/cohere-stream.ts
index 262b365..6fc21a0 100644
--- a/packages/core/streams/cohere-stream.ts
+++ b/packages/core/streams/cohere-stream.ts
@@ -1,3 +1,6 @@
 import {
   type AIStreamCallbacksAndOptions,
-  createCallbacksTransformer,
+  createCallbacksAndOptionsTransformer,
@@ -5,6 +8,7 @@ import {
 } from './ai-stream';
-import { createStreamDataTransformer } from './stream-data';
 
 const utf8Decoder = new TextDecoder('utf-8');
 
 // Full types
@@ -90,6 +94,9 @@ export function CohereStream(
   reader: Response | AsyncIterable<StreamChunk>,
   callbacks?: AIStreamCallbacksAndOptions,
 ): ReadableStream {
   if (Symbol.asyncIterator in reader) {
     return readableFromAsyncIterable(streamable(reader))
-      .pipeThrough(createCallbacksTransformer(callbacks))
-      .pipeThrough(
-        createStreamDataTransformer(callbacks?.experimental_streamData),
-      );
+      .pipeThrough(createCallbacksAndOptionsTransformer(callbacks));
   }
@@ -103,4 +110,5 @@ export function CohereStream(
   return createParser(reader)
-    .pipeThrough(createCallbacksTransformer(callbacks))
-    .pipeThrough(
-      createStreamDataTransformer(callbacks?.experimental_streamData),
-    );
+    .pipeThrough(createCallbacksAndOptionsTransformer(callbacks));
 }
diff --git a/packages/core/streams/data-stream.ts b/packages/core/streams/data-stream.ts
new file mode 100644
index 0000000..76f2a68
--- /dev/null
+++ b/packages/core/streams/data-stream.ts
@@ -0,0 +1,37 @@
+import { JSONValue } from '../shared/types';
+import { getStreamString } from '../shared/utils';
+
+/**
+ * A stream wrapper to send custom JSON-encoded data back to the client.
+ */
+export class Data {
+  private encoder = new TextEncoder();
+  private controller: ReadableStreamDefaultController<Uint8Array> | null =
+    null;
+  private stream: ReadableStream<Uint8Array>;
+
+  constructor() {
+    this.stream = new ReadableStream({
+      start: controller => {
+        this.controller = controller;
+      },
+    });
+  }
+
+  append(value: JSONValue): void {
+    if (!this.controller) {
+      throw new Error('Stream controller is not initialized.');
+    }
+
+    // Custom JSON is sent with the `2:` (data) prefix; getStreamString
+    // handles the JSON.stringify, so the value is passed through as-is.
+    this.controller.enqueue(
+      this.encoder.encode(getStreamString('data', value)),
+    );
+  }
+
+  close() {
+    if (!this.controller) return;
+
+    this.controller.close();
+    this.controller = null;
+  }
+}
diff --git a/packages/core/streams/huggingface-stream.ts b/packages/core/streams/huggingface-stream.ts
index 464d098..73dfc42 100644
--- a/packages/core/streams/huggingface-stream.ts
+++ b/packages/core/streams/huggingface-stream.ts
@@ -1,10 +1,17 @@
 import {
   type AIStreamCallbacksAndOptions,
-  createCallbacksTransformer,
+  createCallbacksAndOptionsTransformer,
   trimStartOfStreamHelper,
 } from './ai-stream';
-import { createStreamDataTransformer } from './stream-data';
 
 function createParser(res: AsyncGenerator<any>) {
   const trimStartOfStream = trimStartOfStreamHelper();
   return new ReadableStream({
@@ -40,9 +47,13 @@ export function HuggingFaceStream(
   res: AsyncGenerator<any>,
   callbacks?: AIStreamCallbacksAndOptions,
 ): ReadableStream {
-  return createParser(res)
-    .pipeThrough(createCallbacksTransformer(callbacks))
-    .pipeThrough(
-      createStreamDataTransformer(callbacks?.experimental_streamData),
-    );
+  return createParser(res).pipeThrough(
+    createCallbacksAndOptionsTransformer(callbacks),
+  );
 }
diff --git a/packages/core/streams/langchain-stream.ts b/packages/core/streams/langchain-stream.ts
index 2256dc7..4d9386e 100644
--- a/packages/core/streams/langchain-stream.ts
+++ b/packages/core/streams/langchain-stream.ts
@@ -1,9 +1,13 @@
 import {
   type AIStreamCallbacksAndOptions,
-  createCallbacksTransformer,
+  createCallbacksAndOptionsTransformer,
 } from './ai-stream';
-import { createStreamDataTransformer } from './stream-data';
 
 export function LangChainStream(callbacks?: AIStreamCallbacksAndOptions) {
   const stream = new TransformStream();
   const writer = stream.writable.getWriter();
@@ -30,12 +34,16 @@ export function LangChainStream(callbacks?: AIStreamCallbacksAndOptions) {
   };
 
   return {
-    stream: stream.readable
-      .pipeThrough(createCallbacksTransformer(callbacks))
-      .pipeThrough(
-        createStreamDataTransformer(callbacks?.experimental_streamData),
-      ),
+    stream: stream.readable.pipeThrough(
+      createCallbacksAndOptionsTransformer(callbacks),
+    ),
     writer,
     handlers: {
       handleLLMNewToken: async (token: string) => {
         await writer.ready;
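On the receiving end, a client has to split the body into lines and dispatch on the prefix. A simplified reader sketch, assuming chunks arrive line-aligned (the SDK's own parsing lives in `shared/stream-parts.ts` and handles partial lines):

```ts
async function readPrefixedStream(res: Response): Promise<void> {
  const reader = res.body!.pipeThrough(new TextDecoderStream()).getReader();
  for (;;) {
    const { value, done } = await reader.read();
    if (done) break;
    for (const line of value.split('\n').filter(Boolean)) {
      const colon = line.indexOf(':');
      const prefix = line.slice(0, colon); // '0' text, '1' function_call, '2' data
      const payload = line.slice(colon + 1);
      if (prefix === '0') {
        console.log('text:', payload); // text payloads are raw strings
      } else {
        console.log('json:', JSON.parse(payload)); // everything else is JSON
      }
    }
  }
}
```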
diff --git a/packages/core/streams/openai-stream.ts b/packages/core/streams/openai-stream.ts
index 4dd7c46..7b68c7b 100644
--- a/packages/core/streams/openai-stream.ts
+++ b/packages/core/streams/openai-stream.ts
@@ -1,3 +1,7 @@
 import { formatStreamPart } from '../shared/stream-parts';
 import {
   CreateMessage,
+  JSONValue,
@@ -7,9 +11,17 @@ import {
 } from '../shared/types';
-import { createChunkDecoder } from '../shared/utils';
+import { createChunkDecoder, getStreamString } from '../shared/utils';
 
 import {
   AIStream,
   trimStartOfStreamHelper,
+  type AIStreamCallbacks,
   type AIStreamCallbacksAndOptions,
   FunctionCallPayload,
   readableFromAsyncIterable,
 } from './ai-stream';
@@ -20,6 +32,7 @@ import { AzureChatCompletions } from './azure-openai-types';
 import { createStreamDataTransformer } from './stream-data';
 
-export type OpenAIStreamCallbacks = AIStreamCallbacksAndOptions & {
+export type OpenAIStreamCallbacks = AIStreamCallbacks & {
   /**
    * @example
    * ```js
@@ -661,6 +674,43 @@ function createFunctionCallTransformer(
             },
           );
 
+          if (!functionResponse) {
+            // The user didn't do anything with the function call on the server and wants
+            // to either do nothing or run it on the client,
+            // so we just return the function call as a message.
+            controller.enqueue(
+              textEncoder.encode(
+                getStreamString('function_call', aggregatedResponse),
+              ),
+            );
+            return;
+          } else if (typeof functionResponse === 'string') {
+            // The user returned a string, so we just return it as a message.
+            controller.enqueue(
+              textEncoder.encode(getStreamString('text', functionResponse)),
+            );
+            return;
+          }
+
+          // Recursive case: the user returned a new completion request, so we
+          // run it through OpenAIStream again. We don't want to trigger
+          // onStart or onCompletion recursively, so we remove them from the
+          // callbacks; see https://github.com/vercel-labs/ai/issues/351
+          const filteredCallbacks: OpenAIStreamCallbacks = {
+            ...callbacks,
+            onStart: undefined,
+            onCompletion: undefined,
+          };
+
+          const openAIStream = OpenAIStream(functionResponse, {
+            ...filteredCallbacks,
+            [__internal__OpenAIFnMessagesSymbol]: newFunctionCallMessages,
+          } as AIStreamCallbacks);
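For context, the three branches above are driven by what the user returns from `experimental_onFunctionCall`: `undefined` forwards the `function_call` to the client as a `1:` line, a string is streamed as text, and a new chat-completion request recurses through `OpenAIStream`. A sketch of the caller side; the model name, function schema, and message plumbing are illustrative only:

```ts
import OpenAI from 'openai';
import { OpenAIStream } from 'ai';

const openai = new OpenAI();

export async function answer(messages: any[]) {
  const response = await openai.chat.completions.create({
    model: 'gpt-3.5-turbo-0613',
    stream: true,
    messages,
    functions: [
      {
        name: 'get_current_weather',
        parameters: { type: 'object', properties: {} },
      },
    ],
  });

  return OpenAIStream(response, {
    experimental_onFunctionCall: async (payload, createFunctionCallMessages) => {
      if (payload.name !== 'get_current_weather') return; // -> `1:` function_call line
      // Pretend we ran the function; feed the result back to the model,
      // which triggers the recursive OpenAIStream call shown in the diff:
      const newMessages = createFunctionCallMessages({ temperature: 21 });
      return openai.chat.completions.create({
        model: 'gpt-3.5-turbo-0613',
        stream: true,
        messages: [...messages, ...newMessages],
      });
    },
  });
}
```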