diff --git a/packages/components/models.json b/packages/components/models.json index 6be19c3157f..d71c15b60bf 100644 --- a/packages/components/models.json +++ b/packages/components/models.json @@ -815,6 +815,27 @@ } ] }, + { + "name": "chatBaiduWenxin", + "models": [ + { + "label": "ernie-4.5-8k-preview", + "name": "ernie-4.5-8k-preview" + }, + { + "label": "ernie-4.0-8k", + "name": "ernie-4.0-8k" + }, + { + "label": "ernie-3.5-8k-preview", + "name": "ernie-3.5-8k-preview" + }, + { + "label": "ernie-speed-128k", + "name": "ernie-speed-128k" + } + ] + }, { "name": "chatAlibabaTongyi", "models": [ diff --git a/packages/components/nodes/chatmodels/ChatBaiduWenxin/ChatBaiduWenxin.test.ts b/packages/components/nodes/chatmodels/ChatBaiduWenxin/ChatBaiduWenxin.test.ts new file mode 100644 index 00000000000..eb8ce6913ea --- /dev/null +++ b/packages/components/nodes/chatmodels/ChatBaiduWenxin/ChatBaiduWenxin.test.ts @@ -0,0 +1,72 @@ +jest.mock('@langchain/baidu-qianfan', () => ({ + ChatBaiduQianfan: jest.fn().mockImplementation((fields) => ({ fields })) +})) + +jest.mock('../../../src/utils', () => ({ + getBaseClasses: jest.fn().mockReturnValue(['BaseChatModel']), + getCredentialData: jest.fn(), + getCredentialParam: jest.fn() +})) + +jest.mock('../../../src/modelLoader', () => ({ + MODEL_TYPE: { CHAT: 'chat' }, + getModels: jest.fn() +})) + +import { getCredentialData, getCredentialParam } from '../../../src/utils' +import { getModels } from '../../../src/modelLoader' + +const { nodeClass: ChatBaiduWenxin } = require('./ChatBaiduWenxin') + +describe('ChatBaiduWenxin', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it('loads model options from the shared model loader', async () => { + ;(getModels as jest.Mock).mockResolvedValue([{ label: 'ernie-4.5-8k-preview', name: 'ernie-4.5-8k-preview' }]) + + const node = new ChatBaiduWenxin() + const models = await node.loadMethods.listModels() + + expect(getModels).toHaveBeenCalledWith('chat', 'chatBaiduWenxin') + expect(models).toEqual([{ label: 'ernie-4.5-8k-preview', name: 'ernie-4.5-8k-preview' }]) + }) + + it('passes advanced settings and custom model names to ChatBaiduQianfan', async () => { + ;(getCredentialData as jest.Mock).mockResolvedValue({ + qianfanAccessKey: 'access-key', + qianfanSecretKey: 'secret-key' + }) + ;(getCredentialParam as jest.Mock).mockImplementation((key, credentialData) => credentialData[key]) + + const node = new ChatBaiduWenxin() + const model = await node.init( + { + credential: 'cred-1', + inputs: { + modelName: 'ernie-4.0-8k', + customModelName: 'ernie-speed-128k', + temperature: '0.2', + streaming: false, + topP: '0.8', + penaltyScore: '1.4', + userId: 'user-123' + } + }, + '', + {} + ) + + expect(model.fields).toMatchObject({ + qianfanAccessKey: 'access-key', + qianfanSecretKey: 'secret-key', + modelName: 'ernie-speed-128k', + temperature: 0.2, + streaming: false, + topP: 0.8, + penaltyScore: 1.4, + userId: 'user-123' + }) + }) +}) diff --git a/packages/components/nodes/chatmodels/ChatBaiduWenxin/ChatBaiduWenxin.ts b/packages/components/nodes/chatmodels/ChatBaiduWenxin/ChatBaiduWenxin.ts index 01517144d0d..66cf17efff6 100644 --- a/packages/components/nodes/chatmodels/ChatBaiduWenxin/ChatBaiduWenxin.ts +++ b/packages/components/nodes/chatmodels/ChatBaiduWenxin/ChatBaiduWenxin.ts @@ -1,6 +1,7 @@ import { BaseCache } from '@langchain/core/caches' import { ChatBaiduQianfan } from '@langchain/baidu-qianfan' -import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeOptionsValue, INodeParams } from '../../../src/Interface' +import { MODEL_TYPE, getModels } from '../../../src/modelLoader' import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' class ChatBaiduWenxin_ChatModels implements INode { @@ -18,7 +19,7 @@ class ChatBaiduWenxin_ChatModels implements INode { constructor() { this.label = 'Baidu Wenxin' this.name = 'chatBaiduWenxin' - this.version = 2.0 + this.version = 3.0 this.type = 'ChatBaiduWenxin' this.icon = 'baiduwenxin.svg' this.category = 'Chat Models' @@ -38,10 +39,20 @@ class ChatBaiduWenxin_ChatModels implements INode { optional: true }, { - label: 'Model', + label: 'Model Name', name: 'modelName', + type: 'asyncOptions', + loadMethod: 'listModels', + default: 'ernie-4.5-8k-preview' + }, + { + label: 'Custom Model Name', + name: 'customModelName', type: 'string', - placeholder: 'ERNIE-Bot-turbo' + placeholder: 'ernie-speed-128k', + description: 'Custom model name to use. If provided, it will override the selected model.', + additionalParams: true, + optional: true }, { label: 'Temperature', @@ -57,15 +68,52 @@ class ChatBaiduWenxin_ChatModels implements INode { type: 'boolean', default: true, optional: true + }, + { + label: 'Top Probability', + name: 'topP', + type: 'number', + description: 'Nucleus sampling. The model considers tokens whose cumulative probability mass reaches this value.', + step: 0.1, + optional: true, + additionalParams: true + }, + { + label: 'Penalty Score', + name: 'penaltyScore', + type: 'number', + description: 'Penalizes repeated tokens according to frequency. Baidu Qianfan accepts values from 1.0 to 2.0.', + step: 0.1, + optional: true, + additionalParams: true + }, + { + label: 'User ID', + name: 'userId', + type: 'string', + description: 'Optional unique identifier for the end user making the request.', + optional: true, + additionalParams: true } ] } + //@ts-ignore + loadMethods = { + async listModels(): Promise { + return await getModels(MODEL_TYPE.CHAT, 'chatBaiduWenxin') + } + } + async init(nodeData: INodeData, _: string, options: ICommonObject): Promise { const cache = nodeData.inputs?.cache as BaseCache const temperature = nodeData.inputs?.temperature as string const modelName = nodeData.inputs?.modelName as string + const customModelName = nodeData.inputs?.customModelName as string const streaming = nodeData.inputs?.streaming as boolean + const topP = nodeData.inputs?.topP as string + const penaltyScore = nodeData.inputs?.penaltyScore as string + const userId = nodeData.inputs?.userId as string const credentialData = await getCredentialData(nodeData.credential ?? '', options) const qianfanAccessKey = getCredentialParam('qianfanAccessKey', credentialData, nodeData) @@ -75,9 +123,12 @@ class ChatBaiduWenxin_ChatModels implements INode { streaming: streaming ?? true, qianfanAccessKey, qianfanSecretKey, - modelName, + modelName: customModelName || modelName, temperature: temperature ? parseFloat(temperature) : undefined } + if (topP) obj.topP = parseFloat(topP) + if (penaltyScore) obj.penaltyScore = parseFloat(penaltyScore) + if (userId) obj.userId = userId if (cache) obj.cache = cache const model = new ChatBaiduQianfan(obj)