diff --git a/README.md b/README.md index da85043c4..181e45eaa 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,7 @@ For more information: [chatboxai.app](https://chatboxai.app/) - Google Gemini Pro - Ollama (enable access to local models like llama2, Mistral, Mixtral, codellama, vicuna, yi, and solar) - ChatGLM-6B + - Volcengine Ark - **Image Generation with Dall-E-3** :art: Create the images of your imagination with Dall-E-3. diff --git a/src/renderer/components/ArkModelSelect.tsx b/src/renderer/components/ArkModelSelect.tsx new file mode 100644 index 000000000..fc34a65ff --- /dev/null +++ b/src/renderer/components/ArkModelSelect.tsx @@ -0,0 +1,46 @@ +import { Select, MenuItem, FormControl, InputLabel, TextField } from '@mui/material' +import { ModelSettings } from '../../shared/types' +import { useTranslation } from 'react-i18next' +import { models } from '../packages/models/ark' + +export interface Props { + arkModel: ModelSettings['arkModel'] + arkEndpointId: ModelSettings['arkEndpointId'] + onChange(arkModel: ModelSettings['arkModel'], arkEndpointId: ModelSettings['arkEndpointId']): void + className?: string +} + +export default function ArkModelSelect(props: Props) { + const { t } = useTranslation() + return ( + + {t('model')} + + + props.onChange(props.arkModel, e.target.value.trim()) + } + /> + + ) +} diff --git a/src/renderer/i18n/locales/zh-Hans/translation.json b/src/renderer/i18n/locales/zh-Hans/translation.json index b2ef515aa..2b26de52f 100644 --- a/src/renderer/i18n/locales/zh-Hans/translation.json +++ b/src/renderer/i18n/locales/zh-Hans/translation.json @@ -142,6 +142,7 @@ "View More Plans": "查看更多方案", "Custom Model": "自定义模型", "Custom Model Name": "自定义模型名", + "Endpoint": "接入点", "advanced": "其他", "Network Proxy": "Network Proxy", "Proxy Address": "Proxy Address", diff --git a/src/renderer/packages/models/ark.ts b/src/renderer/packages/models/ark.ts new file mode 100644 index 000000000..87a475e19 --- /dev/null +++ b/src/renderer/packages/models/ark.ts @@ -0,0 +1,172 @@ +import { Message } from 'src/shared/types' +import { ApiError, ChatboxAIAPIError } from './errors' +import Base, { onResultChange } from './base' + +interface Options { + arkApiKey: string + arkBaseURL: string + arkModel: ArkModel | 'custom-model' + arkEndpointId: string + temperature: number + topP: number +} + +export default class VolcArk extends Base { + public name = 'VolcengineArk' + + public options: Options + constructor(options: Options) { + super() + this.options = options + this.options.arkBaseURL = this.options.arkBaseURL || 'https://ark.cn-beijing.volces.com/api/v3' + } + + async callChatCompletion( + rawMessages: Message[], + signal?: AbortSignal, + onResultChange?: onResultChange + ): Promise { + try { + return await this._callChatCompletion(rawMessages, signal, onResultChange) + } catch (e) { + if ( + e instanceof ApiError && + e.message.includes('Invalid content type. image_url is only supported by certain models.') + ) { + throw ChatboxAIAPIError.fromCodeName('model_not_support_image', 'model_not_support_image') + } + throw e + } + } + + async _callChatCompletion( + rawMessages: Message[], + signal?: AbortSignal, + onResultChange?: onResultChange + ): Promise { + const model = this.options.arkEndpointId + + rawMessages = injectModelSystemPrompt(rawMessages) + + const messages = await populateGPTMessage(rawMessages) + return this.requestChatCompletionsStream( + { + messages, + model, + max_tokens: undefined, + temperature: this.options.temperature, + top_p: this.options.topP, + stream: true, + }, + signal, + onResultChange + ) + } + + async requestChatCompletionsStream( + requestBody: Record, + signal?: AbortSignal, + onResultChange?: onResultChange + ): Promise { + const apiPath = '/chat/completions' + const response = await this.post(`${this.options.arkBaseURL}${apiPath}`, this.getHeaders(), requestBody, signal) + let result = '' + await this.handleSSE(response, (message) => { + if (message === '[DONE]') { + return + } + const data = JSON.parse(message) + if (data.error) { + throw new ApiError(`Error from OpenAI: ${JSON.stringify(data)}`) + } + const text = data.choices[0]?.delta?.content + if (text !== undefined) { + result += text + if (onResultChange) { + onResultChange(result) + } + } + }) + return result + } + + async requestChatCompletionsNotStream( + requestBody: Record, + signal?: AbortSignal, + onResultChange?: onResultChange + ): Promise { + const apiPath = '/chat/completions' + const response = await this.post(`${this.options.arkBaseURL}${apiPath}`, this.getHeaders(), requestBody, signal) + const json = await response.json() + if (json.error) { + throw new ApiError(`Error from OpenAI: ${JSON.stringify(json)}`) + } + if (onResultChange) { + onResultChange(json.choices[0].message.content) + } + return json.choices[0].message.content + } + + getHeaders() { + const headers: Record = { + Authorization: `Bearer ${this.options.arkApiKey}`, + 'Content-Type': 'application/json', + } + return headers + } +} + +export const arkModelConfigs = { + 'doubao-1.5-pro-256k': { + maxTokens: 12_288, + maxContextTokens: 131_072, + }, + 'doubao-1.5-pro-32k': { + maxTokens: 12_288, + maxContextTokens: 32_768, + }, + 'doubao-1.5-vision-pro-32k': { + maxTokens: 12_288, + maxContextTokens: 32_768, + }, + 'deepseek-r1': { + maxTokens: 8192, + maxContextTokens: 64_000, + }, + 'deepseek-v3': { + maxTokens: 8192, + maxContextTokens: 64_000, + }, +} +export type ArkModel = keyof typeof arkModelConfigs +export const models = Array.from(Object.keys(arkModelConfigs)).sort() as ArkModel[] + +export async function populateGPTMessage(rawMessages: Message[]): Promise { + const messages: OpenAIMessage[] = rawMessages.map((m) => ({ + role: m.role, + content: m.content, + })) + return messages +} + +export function injectModelSystemPrompt(messages: Message[]) { + const metadataPrompt = ` +Current date: ${new Date().toISOString()} + +` + let hasInjected = false + return messages.map((m) => { + if (m.role === 'system' && !hasInjected) { + m = { ...m } + m.content = metadataPrompt + m.content + hasInjected = true + } + return m + }) +} + +export interface OpenAIMessage { + role: 'system' | 'user' | 'assistant' + content: string + name?: string +} diff --git a/src/renderer/packages/models/index.ts b/src/renderer/packages/models/index.ts index 6645ad9cc..9cf498624 100644 --- a/src/renderer/packages/models/index.ts +++ b/src/renderer/packages/models/index.ts @@ -6,7 +6,7 @@ import SiliconFlow from './siliconflow' import LMStudio from './lmstudio' import Claude from './claude' import PPIO from './ppio' - +import VolcArk from './ark' export function getModel(setting: Settings, config: Config) { switch (setting.aiProvider) { @@ -24,6 +24,8 @@ export function getModel(setting: Settings, config: Config) { return new SiliconFlow(setting) case ModelProvider.PPIO: return new PPIO(setting) + case ModelProvider.VolcengineArk: + return new VolcArk(setting) default: throw new Error('Cannot find model with provider: ' + setting.aiProvider) } @@ -37,6 +39,7 @@ export const aiProviderNameHash = { [ModelProvider.Ollama]: 'Ollama', [ModelProvider.SiliconFlow]: 'SiliconCloud API', [ModelProvider.PPIO]: 'PPIO', + [ModelProvider.VolcengineArk]: 'Volcengine Ark', } export const AIModelProviderMenuOptionList = [ @@ -76,6 +79,11 @@ export const AIModelProviderMenuOptionList = [ label: aiProviderNameHash[ModelProvider.PPIO], disabled: false, }, + { + value: ModelProvider.VolcengineArk, + label: aiProviderNameHash[ModelProvider.VolcengineArk], + disabled: false, + }, ] export function getModelDisplayName(settings: Settings, sessionType: SessionType): string { @@ -105,6 +113,8 @@ export function getModelDisplayName(settings: Settings, sessionType: SessionType return `SiliconCloud (${settings.siliconCloudModel})` case ModelProvider.PPIO: return `PPIO (${settings.ppioModel})` + case ModelProvider.VolcengineArk: + return `Ark (${settings.arkModel})` default: return 'unknown' } diff --git a/src/renderer/pages/SettingDialog/ArkSetting.tsx b/src/renderer/pages/SettingDialog/ArkSetting.tsx new file mode 100644 index 000000000..cf8bcf0a6 --- /dev/null +++ b/src/renderer/pages/SettingDialog/ArkSetting.tsx @@ -0,0 +1,79 @@ +import { Typography, Box } from '@mui/material' +import { ModelSettings } from '../../../shared/types' +import { useTranslation } from 'react-i18next' +import { Accordion, AccordionSummary, AccordionDetails } from '../../components/Accordion' +import TemperatureSlider from '../../components/TemperatureSlider' +import TopPSlider from '../../components/TopPSlider' +import PasswordTextField from '../../components/PasswordTextField' +import MaxContextMessageCountSlider from '../../components/MaxContextMessageCountSlider' +import ArkModelSelect from '../../components/ArkModelSelect' +import TextFieldReset from '@/components/TextFieldReset' + +interface ModelConfigProps { + settingsEdit: ModelSettings + setSettingsEdit: (settings: ModelSettings) => void +} + +export default function VolcArkSetting(props: ModelConfigProps) { + const { settingsEdit, setSettingsEdit } = props + const { t } = useTranslation() + return ( + + { + setSettingsEdit({ ...settingsEdit, arkApiKey: value }) + }} + placeholder="xxxxxxxxxxxxxxxxxxxxxxxx" + /> + <> + { + value = value.trim() + if (value.length > 4 && !value.startsWith('http')) { + value = 'https://' + value + } + setSettingsEdit({ ...settingsEdit, arkBaseURL: value }) + }} + /> + + + setSettingsEdit({ ...settingsEdit, arkModel, arkEndpointId }) + } + /> + + + + {t('token')}{' '} + + + + setSettingsEdit({ ...settingsEdit, temperature: value })} + /> + setSettingsEdit({ ...settingsEdit, topP: v })} + /> + setSettingsEdit({ ...settingsEdit, openaiMaxContextMessageCount: v })} + /> + + + + ) +} diff --git a/src/renderer/pages/SettingDialog/ModelSettingTab.tsx b/src/renderer/pages/SettingDialog/ModelSettingTab.tsx index 34e5f5c80..9f426bc1a 100644 --- a/src/renderer/pages/SettingDialog/ModelSettingTab.tsx +++ b/src/renderer/pages/SettingDialog/ModelSettingTab.tsx @@ -10,6 +10,7 @@ import MaxContextMessageCountSlider from '@/components/MaxContextMessageCountSli import TemperatureSlider from '@/components/TemperatureSlider' import ClaudeSetting from './ClaudeSetting' import PPIOSetting from './PPIOSetting' +import VolcArkSetting from './ArkSetting' interface ModelConfigProps { settingsEdit: ModelSettings @@ -76,7 +77,7 @@ export default function ModelSettingTab(props: ModelConfigProps) { )} - {settingsEdit.aiProvider === ModelProvider.SiliconFlow && ( + {settingsEdit.aiProvider === ModelProvider.SiliconFlow && ( )} {settingsEdit.aiProvider === ModelProvider.Claude && ( @@ -85,6 +86,9 @@ export default function ModelSettingTab(props: ModelConfigProps) { {settingsEdit.aiProvider === ModelProvider.PPIO && ( )} + {settingsEdit.aiProvider === ModelProvider.VolcengineArk && ( + + )} ) } diff --git a/src/shared/defaults.ts b/src/shared/defaults.ts index b4496f75f..06cd15669 100644 --- a/src/shared/defaults.ts +++ b/src/shared/defaults.ts @@ -56,6 +56,11 @@ export function settings(): Settings { ppioKey: '', ppioModel: 'deepseek/deepseek-r1/community', + arkApiKey: '', + arkBaseURL: 'https://ark.cn-beijing.volces.com/api/v3', + arkModel: 'doubao-1.5-pro-32k', + arkEndpointId: '', + autoGenerateTitle: true, } } diff --git a/src/shared/types.ts b/src/shared/types.ts index bd7b3b0ad..3a7ba630b 100644 --- a/src/shared/types.ts +++ b/src/shared/types.ts @@ -2,6 +2,7 @@ import { v4 as uuidv4 } from 'uuid' import { Model } from '../renderer/packages/models/openai' import * as siliconflow from '../renderer/packages/models/siliconflow' import { ClaudeModel } from '../renderer/packages/models/claude' +import { ArkModel } from '../renderer/packages/models/ark' export const MessageRoleEnum = { System: 'system', @@ -70,6 +71,7 @@ export enum ModelProvider { SiliconFlow = 'silicon-flow', LMStudio = 'lm-studio', PPIO = 'ppio', + VolcengineArk = 'volc-ark', } export interface ModelSettings { @@ -121,6 +123,12 @@ export interface ModelSettings { ppioKey: string ppioModel: string + // ark + arkApiKey: string + arkBaseURL: string + arkModel: ArkModel | 'custom-model' + arkEndpointId: string + temperature: number topP: number openaiMaxContextMessageCount: number