Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions packages/components/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,27 @@
}
]
},
{
"name": "chatBaiduWenxin",
"models": [
{
"label": "ernie-4.5-8k-preview",
"name": "ernie-4.5-8k-preview"
},
{
"label": "ernie-4.0-8k",
"name": "ernie-4.0-8k"
},
{
"label": "ernie-3.5-8k-preview",
"name": "ernie-3.5-8k-preview"
},
{
"label": "ernie-speed-128k",
"name": "ernie-speed-128k"
}
]
},
{
"name": "chatAlibabaTongyi",
"models": [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Replace the Qianfan chat-model SDK so the test can capture the exact
// constructor arguments: the fake class simply echoes them back on `fields`.
jest.mock('@langchain/baidu-qianfan', () => {
    return {
        ChatBaiduQianfan: jest.fn().mockImplementation((fields) => ({ fields }))
    }
})

// Stub the shared utils so credential lookup is fully controlled per test.
jest.mock('../../../src/utils', () => {
    return {
        getBaseClasses: jest.fn().mockReturnValue(['BaseChatModel']),
        getCredentialData: jest.fn(),
        getCredentialParam: jest.fn()
    }
})

// Stub the model loader so `listModels` output can be driven by each test.
jest.mock('../../../src/modelLoader', () => {
    return {
        MODEL_TYPE: { CHAT: 'chat' },
        getModels: jest.fn()
    }
})

import { getCredentialData, getCredentialParam } from '../../../src/utils'
import { getModels } from '../../../src/modelLoader'

const { nodeClass: ChatBaiduWenxin } = require('./ChatBaiduWenxin')

describe('ChatBaiduWenxin', () => {
    // Reset all mock call history and stubbed behavior between tests.
    beforeEach(() => {
        jest.clearAllMocks()
    })

    it('loads model options from the shared model loader', async () => {
        const expectedOptions = [{ label: 'ernie-4.5-8k-preview', name: 'ernie-4.5-8k-preview' }]
        ;(getModels as jest.Mock).mockResolvedValue(expectedOptions)

        const node = new ChatBaiduWenxin()
        const models = await node.loadMethods.listModels()

        // The node must delegate to the shared loader with its own node name.
        expect(getModels).toHaveBeenCalledWith('chat', 'chatBaiduWenxin')
        expect(models).toEqual(expectedOptions)
    })

    it('passes advanced settings and custom model names to ChatBaiduQianfan', async () => {
        const credentialRecord: Record<string, string> = {
            qianfanAccessKey: 'access-key',
            qianfanSecretKey: 'secret-key'
        }
        ;(getCredentialData as jest.Mock).mockResolvedValue(credentialRecord)
        // Resolve each credential key straight from the stubbed record.
        ;(getCredentialParam as jest.Mock).mockImplementation((key, credentialData) => credentialData[key])

        const nodeData = {
            credential: 'cred-1',
            inputs: {
                modelName: 'ernie-4.0-8k',
                customModelName: 'ernie-speed-128k',
                temperature: '0.2',
                streaming: false,
                topP: '0.8',
                penaltyScore: '1.4',
                userId: 'user-123'
            }
        }

        const node = new ChatBaiduWenxin()
        const model = await node.init(nodeData, '', {})

        // The mocked ChatBaiduQianfan echoes its constructor args on `fields`;
        // the custom model name must win over the dropdown selection, and the
        // string-typed numeric inputs must arrive parsed as numbers.
        expect(model.fields).toMatchObject({
            qianfanAccessKey: 'access-key',
            qianfanSecretKey: 'secret-key',
            modelName: 'ernie-speed-128k',
            temperature: 0.2,
            streaming: false,
            topP: 0.8,
            penaltyScore: 1.4,
            userId: 'user-123'
        })
    })
})
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { BaseCache } from '@langchain/core/caches'
import { ChatBaiduQianfan } from '@langchain/baidu-qianfan'
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
import { ICommonObject, INode, INodeData, INodeOptionsValue, INodeParams } from '../../../src/Interface'
import { MODEL_TYPE, getModels } from '../../../src/modelLoader'
import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'

class ChatBaiduWenxin_ChatModels implements INode {
Expand All @@ -18,7 +19,7 @@ class ChatBaiduWenxin_ChatModels implements INode {
constructor() {
this.label = 'Baidu Wenxin'
this.name = 'chatBaiduWenxin'
this.version = 2.0
this.version = 3.0
this.type = 'ChatBaiduWenxin'
this.icon = 'baiduwenxin.svg'
this.category = 'Chat Models'
Expand All @@ -38,10 +39,20 @@ class ChatBaiduWenxin_ChatModels implements INode {
optional: true
},
{
label: 'Model',
label: 'Model Name',
name: 'modelName',
type: 'asyncOptions',
loadMethod: 'listModels',
default: 'ernie-4.5-8k-preview'
},
{
label: 'Custom Model Name',
name: 'customModelName',
type: 'string',
placeholder: 'ERNIE-Bot-turbo'
placeholder: 'ernie-speed-128k',
description: 'Custom model name to use. If provided, it will override the selected model.',
additionalParams: true,
optional: true
},
{
label: 'Temperature',
Expand All @@ -57,15 +68,52 @@ class ChatBaiduWenxin_ChatModels implements INode {
type: 'boolean',
default: true,
optional: true
},
{
label: 'Top Probability',
name: 'topP',
type: 'number',
description: 'Nucleus sampling. The model considers tokens whose cumulative probability mass reaches this value.',
step: 0.1,
optional: true,
additionalParams: true
},
{
label: 'Penalty Score',
name: 'penaltyScore',
type: 'number',
description: 'Penalizes repeated tokens according to frequency. Baidu Qianfan accepts values from 1.0 to 2.0.',
step: 0.1,
optional: true,
additionalParams: true
},
{
label: 'User ID',
name: 'userId',
type: 'string',
description: 'Optional unique identifier for the end user making the request.',
optional: true,
additionalParams: true
}
]
}

//@ts-ignore
loadMethods = {
async listModels(): Promise<INodeOptionsValue[]> {
return await getModels(MODEL_TYPE.CHAT, 'chatBaiduWenxin')
}
}

async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
const cache = nodeData.inputs?.cache as BaseCache
const temperature = nodeData.inputs?.temperature as string
const modelName = nodeData.inputs?.modelName as string
const customModelName = nodeData.inputs?.customModelName as string
const streaming = nodeData.inputs?.streaming as boolean
const topP = nodeData.inputs?.topP as string
const penaltyScore = nodeData.inputs?.penaltyScore as string
const userId = nodeData.inputs?.userId as string

const credentialData = await getCredentialData(nodeData.credential ?? '', options)
const qianfanAccessKey = getCredentialParam('qianfanAccessKey', credentialData, nodeData)
Expand All @@ -75,9 +123,12 @@ class ChatBaiduWenxin_ChatModels implements INode {
streaming: streaming ?? true,
qianfanAccessKey,
qianfanSecretKey,
modelName,
modelName: customModelName || modelName,
temperature: temperature ? parseFloat(temperature) : undefined
}
if (topP) obj.topP = parseFloat(topP)
if (penaltyScore) obj.penaltyScore = parseFloat(penaltyScore)
if (userId) obj.userId = userId
if (cache) obj.cache = cache

const model = new ChatBaiduQianfan(obj)
Expand Down