diff --git a/.changeset/flat-garlics-knock.md b/.changeset/flat-garlics-knock.md
new file mode 100644
index 0000000..5d3239f
--- /dev/null
+++ b/.changeset/flat-garlics-knock.md
@@ -0,0 +1,5 @@
+---
+"ai-connector": patch
+---
+
+Implement new start-of-stream newline trimming
diff --git a/packages/core/src/ai-stream.ts b/packages/core/src/ai-stream.ts
index ffeaab7..7fe89dc 100644
--- a/packages/core/src/ai-stream.ts
+++ b/packages/core/src/ai-stream.ts
@@ -11,23 +11,17 @@ export interface AIStreamCallbacks {
   onToken?: (token: string) => Promise<void>
 }
 
-export interface AIStreamParserOptions {
-  data: any
-  counter: number
-}
-
 export interface AIStreamParser {
-  (opts: AIStreamParserOptions): string | void
+  (data: string): string | void
 }
 
 export function createEventStreamTransformer(customParser: AIStreamParser) {
   const decoder = new TextDecoder()
-  let counter = 0
   let parser: EventSourceParser
 
   return new TransformStream({
     async start(controller): Promise<void> {
-      function onParse(event: ParsedEvent | ReconnectInterval): void {
+      function onParse(event: ParsedEvent | ReconnectInterval) {
         if (event.type === 'event') {
           const data = event.data
           if (data === '[DONE]') {
@@ -35,9 +29,7 @@ export function createEventStreamTransformer(customParser: AIStreamParser) {
             return
           }
 
-          const message = customParser({ data, counter })
-          counter++
-
+          const message = customParser(data)
           if (message) controller.enqueue(message)
         }
       }
@@ -85,6 +77,18 @@ export function createCallbacksTransformer(
   })
 }
 
+// If we're still at the start of the stream, we want to trim the leading
+// `\n\n`. But, after we've seen some text, we no longer want to trim out
+// whitespace.
+export function trimStartOfStreamHelper() {
+  let start = true
+  return (text: string) => {
+    if (start) text = text.trimStart()
+    if (text) start = false
+    return text
+  }
+}
+
 export function AIStream(
   res: Response,
   customParser: AIStreamParser,
diff --git a/packages/core/src/anthropic-stream.ts b/packages/core/src/anthropic-stream.ts
index 8d5aeb2..3fc1471 100644
--- a/packages/core/src/anthropic-stream.ts
+++ b/packages/core/src/anthropic-stream.ts
@@ -1,15 +1,9 @@
-import {
-  AIStream,
-  type AIStreamCallbacks,
-  type AIStreamParserOptions
-} from './ai-stream'
+import { AIStream, type AIStreamCallbacks } from './ai-stream'
 
-function parseAnthropicStream(): ({
-  data
-}: AIStreamParserOptions) => string | void {
+function parseAnthropicStream(): (data: string) => string | void {
   let previous = ''
 
-  return ({ data }) => {
+  return data => {
     const json = JSON.parse(data as string) as {
       completion: string
       stop: string | null
diff --git a/packages/core/src/huggingface-stream.ts b/packages/core/src/huggingface-stream.ts
index d4fcf9f..3456f1b 100644
--- a/packages/core/src/huggingface-stream.ts
+++ b/packages/core/src/huggingface-stream.ts
@@ -1,7 +1,11 @@
-import { type AIStreamCallbacks, createCallbacksTransformer } from './ai-stream'
+import {
+  type AIStreamCallbacks,
+  createCallbacksTransformer,
+  trimStartOfStreamHelper
+} from './ai-stream'
 
 function createParser(res: AsyncGenerator) {
-  let counter = 0
+  const trimStartOfStream = trimStartOfStreamHelper()
   return new ReadableStream({
     async pull(controller): Promise<void> {
       const { value, done } = await res.next()
@@ -10,7 +14,8 @@ function createParser(res: AsyncGenerator) {
         return
       }
 
-      const text: string = value.token?.text ?? ''
+      const text = trimStartOfStream(value.token?.text ?? '')
+      if (!text) return
 
       // some HF models return generated_text instead of a real ending token
       if (value.generated_text != null && value.generated_text.length > 0) {
@@ -20,16 +25,10 @@ function createParser(res: AsyncGenerator) {
 
       // <|endoftext|> is for https://huggingface.co/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5
       // is also often last token in the stream depending on the model
-      if (text !== '' && text !== '<|endoftext|>') {
-        // TODO: Is this needed?
-        if (counter < 2 && text.includes('\n')) {
-          return
-        }
-
-        controller.enqueue(text)
-        counter++
-      } else {
+      if (text === '' || text === '<|endoftext|>') {
         controller.close()
+      } else {
+        controller.enqueue(text)
       }
     }
   })
diff --git a/packages/core/src/openai-stream.ts b/packages/core/src/openai-stream.ts
index 7499278..3c117c8 100644
--- a/packages/core/src/openai-stream.ts
+++ b/packages/core/src/openai-stream.ts
@@ -1,30 +1,27 @@
 import {
   AIStream,
-  type AIStreamCallbacks,
-  type AIStreamParserOptions
+  trimStartOfStreamHelper,
+  type AIStreamCallbacks
 } from './ai-stream'
 
-function parseOpenAIStream({
-  data,
-  counter
-}: AIStreamParserOptions): string | void {
-  // TODO: Needs a type
-  const json = JSON.parse(data)
+function parseOpenAIStream(): (data: string) => string | void {
+  const trimStartOfStream = trimStartOfStreamHelper()
+  return data => {
+    // TODO: Needs a type
+    const json = JSON.parse(data)
 
-  // this can be used for either chat or completion models
-  const text = json.choices[0]?.delta?.content ?? json.choices[0]?.text ?? ''
+    // this can be used for either chat or completion models
+    const text = trimStartOfStream(
+      json.choices[0]?.delta?.content ?? json.choices[0]?.text ?? ''
+    )
 
-  // TODO: I don't understand the `counter && has newline`. Should this be `counter < 2 || !has newline?`?
-  if (counter < 2 && text.includes('\n')) {
-    return
+    return text
   }
-
-  return text
 }
 
 export function OpenAIStream(
   res: Response,
   cb?: AIStreamCallbacks
 ): ReadableStream {
-  return AIStream(res, parseOpenAIStream, cb)
+  return AIStream(res, parseOpenAIStream(), cb)
 }
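Note: for reference, a minimal usage sketch of the new trimStartOfStreamHelper closure. The token chunks below are hypothetical, not taken from this diff; the import path assumes code living next to ai-stream.ts.

  // The helper returns a stateful closure: it strips leading whitespace only
  // until the first non-empty chunk has been seen, then passes later chunks
  // (including their whitespace) through untouched.
  import { trimStartOfStreamHelper } from './ai-stream'

  const trimStartOfStream = trimStartOfStreamHelper()
  // hypothetical token chunks as a provider might stream them
  const chunks = ['\n\n', '\nHello', ' world', '\n']
  const trimmed = chunks.map(chunk => trimStartOfStream(chunk))
  // trimmed is ['', 'Hello', ' world', '\n']: the leading newlines are
  // dropped, whitespace after the first real text is preserved.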