From ac599aa47c49bde2d557a6f7317347c940b29b17 Mon Sep 17 00:00:00 2001
From: lloydzhou
Date: Fri, 2 Aug 2024 18:00:42 +0800
Subject: [PATCH 01/10] add dalle3 model

---
 app/client/platforms/openai.ts | 90 ++++++++++++++++++++++++----------
 app/components/chat.tsx        | 34 +++++++++++++
 app/constant.ts                |  7 ++-
 app/store/chat.ts              |  5 ++
 app/utils.ts                   |  4 ++
 5 files changed, 113 insertions(+), 27 deletions(-)

diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
index 680125fe6..28de30051 100644
--- a/app/client/platforms/openai.ts
+++ b/app/client/platforms/openai.ts
@@ -33,6 +33,7 @@ import {
   getMessageTextContent,
   getMessageImages,
   isVisionModel,
+  isDalle3 as _isDalle3,
 } from "@/app/utils";
 
 export interface OpenAIListModelResponse {
@@ -58,6 +59,13 @@ export interface RequestPayload {
   max_tokens?: number;
 }
 
+export interface DalleRequestPayload {
+  model: string;
+  prompt: string;
+  n: number;
+  size: "1024x1024" | "1792x1024" | "1024x1792";
+}
+
 export class ChatGPTApi implements LLMApi {
   private disableListModels = true;
 
@@ -101,19 +109,25 @@ export class ChatGPTApi implements LLMApi {
   }
 
   extractMessage(res: any) {
+    if (res.error) {
+      return "```\n" + JSON.stringify(res, null, 4) + "\n```";
+    }
+    // dalle3 model return url, just return
+    if (res.data) {
+      const url = res.data?.at(0)?.url ?? "";
+      return [
+        {
+          type: "image_url",
+          image_url: {
+            url,
+          },
+        },
+      ];
+    }
     return res.choices?.at(0)?.message?.content ?? "";
   }
 
   async chat(options: ChatOptions) {
-    const visionModel = isVisionModel(options.config.model);
-    const messages: ChatOptions["messages"] = [];
-    for (const v of options.messages) {
-      const content = visionModel
-        ? await preProcessImageContent(v.content)
-        : getMessageTextContent(v);
-      messages.push({ role: v.role, content });
-    }
-
     const modelConfig = {
       ...useAppConfig.getState().modelConfig,
       ...useChatStore.getState().currentSession().mask.modelConfig,
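
With this change, extractMessage returns a multimodal content array for image results instead of a plain string. A minimal sketch of the two response shapes the new branch distinguishes — the type aliases below are illustrative, not part of the patch; the field names follow the public chat-completions and image-generations response formats:

```ts
// Sketch: chat completions carry text under choices[0].message.content,
// while image generations carry results under data[0].url.
type ChatResponse = { choices?: { message?: { content?: string } }[] };
type ImageResponse = { data?: { url?: string }[] };

function sketchExtract(res: ChatResponse & ImageResponse) {
  if (res.data) {
    // image result: wrap the URL as a multimodal "image_url" content part
    const url = res.data.at(0)?.url ?? "";
    return [{ type: "image_url" as const, image_url: { url } }];
  }
  // chat result: plain text
  return res.choices?.at(0)?.message?.content ?? "";
}
```
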
@@ -123,26 +137,48 @@ export class ChatGPTApi implements LLMApi {
       },
     };
 
-    const requestPayload: RequestPayload = {
-      messages,
-      stream: options.config.stream,
-      model: modelConfig.model,
-      temperature: modelConfig.temperature,
-      presence_penalty: modelConfig.presence_penalty,
-      frequency_penalty: modelConfig.frequency_penalty,
-      top_p: modelConfig.top_p,
-      // max_tokens: Math.max(modelConfig.max_tokens, 1024),
-      // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
-    };
+    let requestPayload: RequestPayload | DalleRequestPayload;
 
-    // add max_tokens to vision model
-    if (visionModel && modelConfig.model.includes("preview")) {
-      requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
+    const isDalle3 = _isDalle3(options.config.model);
+    if (isDalle3) {
+      const prompt = getMessageTextContent(options.messages.slice(-1)?.pop());
+      requestPayload = {
+        model: options.config.model,
+        prompt,
+        n: 1,
+        size: options.config?.size ?? "1024x1024",
+      };
+    } else {
+      const visionModel = isVisionModel(options.config.model);
+      const messages: ChatOptions["messages"] = [];
+      for (const v of options.messages) {
+        const content = visionModel
+          ? await preProcessImageContent(v.content)
+          : getMessageTextContent(v);
+        messages.push({ role: v.role, content });
+      }
+
+      requestPayload = {
+        messages,
+        stream: options.config.stream,
+        model: modelConfig.model,
+        temperature: modelConfig.temperature,
+        presence_penalty: modelConfig.presence_penalty,
+        frequency_penalty: modelConfig.frequency_penalty,
+        top_p: modelConfig.top_p,
+        // max_tokens: Math.max(modelConfig.max_tokens, 1024),
+        // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
+      };
+
+      // add max_tokens to vision model
+      if (visionModel && modelConfig.model.includes("preview")) {
+        requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
+      }
     }
 
     console.log("[Request] openai payload: ", requestPayload);
 
-    const shouldStream = !!options.config.stream;
+    const shouldStream = !isDalle3 && !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
 
@@ -168,13 +204,15 @@ export class ChatGPTApi implements LLMApi {
           model?.provider?.providerName === ServiceProvider.Azure,
         );
         chatPath = this.path(
-          Azure.ChatPath(
+          (isDalle3 ? Azure.ImagePath : Azure.ChatPath)(
            (model?.displayName ?? model?.name) as string,
             useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
           ),
         );
       } else {
-        chatPath = this.path(OpenaiPath.ChatPath);
+        chatPath = this.path(
+          isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
+        );
       }
       const chatPayload = {
         method: "POST",
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index bb4b611ad..b95e85d45 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -37,6 +37,7 @@ import AutoIcon from "../icons/auto.svg";
 import BottomIcon from "../icons/bottom.svg";
 import StopIcon from "../icons/pause.svg";
 import RobotIcon from "../icons/robot.svg";
+import SizeIcon from "../icons/size.svg";
 import PluginIcon from "../icons/plugin.svg";
 
 import {
@@ -60,6 +61,7 @@ import {
   getMessageTextContent,
   getMessageImages,
   isVisionModel,
+  isDalle3,
 } from "../utils";
 import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
 
@@ -481,6 +483,11 @@ export function ChatActions(props: {
   const [showPluginSelector, setShowPluginSelector] = useState(false);
   const [showUploadImage, setShowUploadImage] = useState(false);
 
+  const [showSizeSelector, setShowSizeSelector] = useState(false);
+  const dalle3Sizes = ["1024x1024", "1792x1024", "1024x1792"];
+  const currentSize =
+    chatStore.currentSession().mask.modelConfig?.size || "1024x1024";
+
   useEffect(() => {
     const show = isVisionModel(currentModel);
     setShowUploadImage(show);
@@ -624,6 +631,33 @@ export function ChatActions(props: {
         />
       )}
 
+      {isDalle3(currentModel) && (
+        <ChatAction
+          onClick={() => setShowSizeSelector(true)}
+          text={currentSize}
+          icon={<SizeIcon />}
+        />
+      )}
+
+      {showSizeSelector && (
+        <Selector
+          defaultSelectedValue={currentSize}
+          items={dalle3Sizes.map((m) => ({
+            title: m,
+            value: m,
+          }))}
+          onClose={() => setShowSizeSelector(false)}
+          onSelection={(s) => {
+            if (s.length === 0) return;
+            const size = s[0];
+            chatStore.updateCurrentSession((session) => {
+              session.mask.modelConfig.size = size;
+            });
+            showToast(size);
+          }}
+        />
+      )}
+
       <ChatAction
         onClick={() => setShowPluginSelector(true)}
         text={Locale.Plugin.Name}
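
The constant.ts hunks below register the image-generation endpoints beside the chat ones. Put together, the DALL-E 3 branch above amounts to a request like this sketch (endpoint and field names per the public OpenAI Images API; the prompt value is illustrative):

```ts
// Sketch: the non-streaming image request the new branch produces.
async function generateImage(apiKey: string, prompt: string) {
  const res = await fetch("https://api.openai.com/v1/images/generations", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `Bearer ${apiKey}`,
    },
    // mirrors DalleRequestPayload: model, prompt, n, size
    body: JSON.stringify({
      model: "dall-e-3",
      prompt, // text of the latest user message
      n: 1,
      size: "1024x1024",
    }),
  });
  return res.json(); // => { created, data: [{ url: "https://..." }] }
}
```

Note that shouldStream is forced to false for DALL-E 3 above: the images endpoint has no streaming variant.
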
diff --git a/app/constant.ts b/app/constant.ts
index 5251b5b4f..b777872c8 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -146,6 +146,7 @@ export const Anthropic = {
 
 export const OpenaiPath = {
   ChatPath: "v1/chat/completions",
+  ImagePath: "v1/images/generations",
   UsagePath: "dashboard/billing/usage",
   SubsPath: "dashboard/billing/subscription",
   ListModelPath: "v1/models",
@@ -154,7 +155,10 @@ export const OpenaiPath = {
 export const Azure = {
   ChatPath: (deployName: string, apiVersion: string) =>
     `deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
-  ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
+  // https://{resource-url}.openai.azure.com/openai/deployments/{deploy-id}/images/generations?api-version={api-version}
+  ImagePath: (deployName: string, apiVersion: string) =>
+    `deployments/${deployName}/images/generations?api-version=${apiVersion}`,
+  ExampleEndpoint: "https://{resource-url}/openai",
 };
 
 export const Google = {
@@ -256,6 +260,7 @@ const openaiModels = [
   "gpt-4-vision-preview",
   "gpt-4-turbo-2024-04-09",
   "gpt-4-1106-preview",
+  "dall-e-3",
 ];
 
 const googleModels = [
diff --git a/app/store/chat.ts b/app/store/chat.ts
index 5892ef0c8..7b47f3ec6 100644
--- a/app/store/chat.ts
+++ b/app/store/chat.ts
@@ -26,6 +26,7 @@ import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
 import { collectModelsWithDefaultModel } from "../utils/model";
 import { useAccessStore } from "./access";
+import { isDalle3 } from "../utils";
 
 export type ChatMessage = RequestMessage & {
   date: string;
@@ -541,6 +542,10 @@ export const useChatStore = createPersistStore(
       const config = useAppConfig.getState();
       const session = get().currentSession();
       const modelConfig = session.mask.modelConfig;
+      // skip summarize when using dalle3?
+      if (isDalle3(modelConfig.model)) {
+        return;
+      }
 
       const api: ClientApi = getClientApi(modelConfig.providerName);
 
diff --git a/app/utils.ts b/app/utils.ts
index 2f2c8ae95..a3c329b82 100644
--- a/app/utils.ts
+++ b/app/utils.ts
@@ -265,3 +265,7 @@ export function isVisionModel(model: string) {
     visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo
   );
 }
+
+export function isDalle3(model: string) {
+  return "dall-e-3" === model;
+}

From 1c24ca58c784775fb0d2cf9daa07949d329bd36a Mon Sep 17 00:00:00 2001
From: lloydzhou
Date: Fri, 2 Aug 2024 18:03:19 +0800
Subject: [PATCH 02/10] add dalle3 model

---
 app/icons/size.svg | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 app/icons/size.svg

diff --git a/app/icons/size.svg b/app/icons/size.svg
new file mode 100644
index 000000000..3da4fadfe
--- /dev/null
+++ b/app/icons/size.svg
@@ -0,0 +1 @@
+

From 46cb48023e6b2ffa52a44775b58a83a97dcffac2 Mon Sep 17 00:00:00 2001
From: lloydzhou
Date: Fri, 2 Aug 2024 18:50:48 +0800
Subject: [PATCH 03/10] fix typescript error

---
 app/client/api.ts              | 3 ++-
 app/client/platforms/openai.ts | 7 +++++--
 app/components/chat.tsx        | 5 +++--
 app/store/config.ts            | 2 ++
 app/typing.ts                  | 2 ++
 5 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/app/client/api.ts b/app/client/api.ts
index f10e47618..88157e79c 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -6,7 +6,7 @@ import {
   ServiceProvider,
 } from "../constant";
 import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
-import { ChatGPTApi } from "./platforms/openai";
+import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai";
 import { GeminiProApi } from "./platforms/google";
 import { ClaudeApi } from "./platforms/anthropic";
 import { ErnieApi } from "./platforms/baidu";
@@ -42,6 +42,7 @@ export interface LLMConfig {
   stream?: boolean;
   presence_penalty?: number;
   frequency_penalty?: number;
+  size?: DalleRequestPayload["size"];
 }
 
 export interface ChatOptions {
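
The indexed-access type DalleRequestPayload["size"] ties the new config field to the payload definition, so the set of allowed sizes lives in one place. A compile-time sketch (types inlined here for self-containment; not part of the patch):

```ts
// Sketch: indexed-access typing keeps the config field and payload in sync.
interface DalleRequestPayload {
  size: "1024x1024" | "1792x1024" | "1024x1792";
}
interface LLMConfig {
  size?: DalleRequestPayload["size"]; // resolves to the same union
}

const ok: LLMConfig = { size: "1792x1024" }; // compiles
// const bad: LLMConfig = { size: "512x512" }; // rejected by tsc
```
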
diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
index 28de30051..54309e29f 100644
--- a/app/client/platforms/openai.ts
+++ b/app/client/platforms/openai.ts
@@ -13,6 +13,7 @@
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
 import { collectModelsWithDefaultModel } from "@/app/utils/model";
 import { preProcessImageContent } from "@/app/utils/chat";
 import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
+import { DalleSize } from "@/app/typing";
 
 import {
   ChatOptions,
@@ -63,7 +64,7 @@ export interface DalleRequestPayload {
   model: string;
   prompt: string;
   n: number;
-  size: "1024x1024" | "1792x1024" | "1024x1792";
+  size: DalleSize;
 }
 
 export class ChatGPTApi implements LLMApi {
@@ -141,7 +142,9 @@ export class ChatGPTApi implements LLMApi {
 
     const isDalle3 = _isDalle3(options.config.model);
     if (isDalle3) {
-      const prompt = getMessageTextContent(options.messages.slice(-1)?.pop());
+      const prompt = getMessageTextContent(
+        options.messages.slice(-1)?.pop() as any,
+      );
       requestPayload = {
         model: options.config.model,
         prompt,
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index b95e85d45..67ea80c4a 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -69,6 +69,7 @@ import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
 import dynamic from "next/dynamic";
 
 import { ChatControllerPool } from "../client/controller";
+import { DalleSize } from "../typing";
 import { Prompt, usePromptStore } from "../store/prompt";
 import Locale from "../locales";
 
@@ -484,9 +485,9 @@ export function ChatActions(props: {
   const [showUploadImage, setShowUploadImage] = useState(false);
 
   const [showSizeSelector, setShowSizeSelector] = useState(false);
-  const dalle3Sizes = ["1024x1024", "1792x1024", "1024x1792"];
+  const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
   const currentSize =
-    chatStore.currentSession().mask.modelConfig?.size || "1024x1024";
+    chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024";
 
   useEffect(() => {
     const show = isVisionModel(currentModel);
diff --git a/app/store/config.ts b/app/store/config.ts
index 1eaafe12b..705a9d87c 100644
--- a/app/store/config.ts
+++ b/app/store/config.ts
@@ -1,4 +1,5 @@
 import { LLMModel } from "../client/api";
+import { DalleSize } from "../typing";
 import { getClientConfig } from "../config/client";
 import {
   DEFAULT_INPUT_TEMPLATE,
@@ -60,6 +61,7 @@ export const DEFAULT_CONFIG = {
     compressMessageLengthThreshold: 1000,
     enableInjectSystemPrompts: true,
     template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
+    size: "1024x1024" as DalleSize,
   },
 };
 
diff --git a/app/typing.ts b/app/typing.ts
index b09722ab9..863203581 100644
--- a/app/typing.ts
+++ b/app/typing.ts
@@ -7,3 +7,5 @@ export interface RequestMessage {
   role: MessageRole;
   content: string;
 }
+
+export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";

From 8c83fe23a1661d37644626e8d71130d96ce413f9 Mon Sep 17 00:00:00 2001
From: lloydzhou
Date: Fri, 2 Aug 2024 20:58:21 +0800
Subject: [PATCH 04/10] using b64_json for dall-e-3

---
 app/client/platforms/openai.ts | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
index 54309e29f..ee9a70913 100644
--- a/app/client/platforms/openai.ts
+++ b/app/client/platforms/openai.ts
@@ -11,7 +11,11 @@ import {
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
 import { collectModelsWithDefaultModel } from "@/app/utils/model";
-import { preProcessImageContent } from "@/app/utils/chat";
+import {
+  preProcessImageContent,
+  uploadImage,
+  base64Image2Blob,
+} from "@/app/utils/chat";
 import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
 import { DalleSize } from "@/app/typing";
 
 import {
   ChatOptions,
@@ -63,6 +67,7 @@ export interface RequestPayload {
 export interface DalleRequestPayload {
   model: string;
   prompt: string;
+  response_format: "url" | "b64_json";
   n: number;
   size: DalleSize;
 }
@@ -109,13 +114,18 @@ export class ChatGPTApi implements LLMApi {
     return cloudflareAIGatewayUrl([baseUrl, path].join("/"));
   }
 
-  extractMessage(res: any) {
+  async extractMessage(res: any) {
     if (res.error) {
       return "```\n" + JSON.stringify(res, null, 4) + "\n```";
     }
-    // dalle3 model return url, just return
+    // dalle3 model return url, using url create image message
     if (res.data) {
-      const url = res.data?.at(0)?.url ?? "";
+      let url = res.data?.at(0)?.url ?? "";
+      const b64_json = res.data?.at(0)?.b64_json ?? "";
+      if (!url && b64_json) {
+        // uploadImage
+        url = await uploadImage(base64Image2Blob(b64_json, "image/png"));
+      }
       return [
         {
           type: "image_url",
@@ -148,6 +158,8 @@ export class ChatGPTApi implements LLMApi {
       requestPayload = {
         model: options.config.model,
         prompt,
+        // URLs are only valid for 60 minutes after the image has been generated.
+        response_format: "b64_json", // using b64_json, and save image in CacheStorage
         n: 1,
         size: options.config?.size ?? "1024x1024",
       };
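
Since OpenAI's generated-image URLs expire roughly 60 minutes after creation, the client now requests b64_json and persists the decoded image itself. A rough sketch of what the decode-and-store step can look like — base64Image2Blob and uploadImage are the helpers imported above from app/utils/chat, but their bodies here are assumptions, and the cache name and URL scheme are hypothetical:

```ts
// Sketch: persist a b64_json image locally instead of relying on the
// short-lived remote URL. Browser APIs only (atob, CacheStorage).
function base64Image2Blob(b64: string, contentType: string): Blob {
  const bytes = atob(b64);
  const buf = new Uint8Array(bytes.length);
  for (let i = 0; i < bytes.length; i++) buf[i] = bytes.charCodeAt(i);
  return new Blob([buf], { type: contentType });
}

async function uploadImage(blob: Blob): Promise<string> {
  // store the blob in CacheStorage and hand back a stable same-origin URL
  const url = `/api/cache/${crypto.randomUUID()}.png`; // hypothetical path
  const cache = await caches.open("generated-images"); // hypothetical cache name
  await cache.put(
    url,
    new Response(blob, { headers: { "Content-Type": blob.type } }),
  );
  return url;
}
```
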
@@ -227,7 +239,7 @@ export class ChatGPTApi implements LLMApi {
       // make a fetch request
       const requestTimeoutId = setTimeout(
         () => controller.abort(),
-        REQUEST_TIMEOUT_MS,
+        isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dalle3 using b64_json is slow.
       );
 
       if (shouldStream) {
@@ -358,7 +370,7 @@ export class ChatGPTApi implements LLMApi {
         clearTimeout(requestTimeoutId);
 
         const resJson = await res.json();
-        const message = this.extractMessage(resJson);
+        const message = await this.extractMessage(resJson);
         options.onFinish(message);
       }
     } catch (e) {

From 4a8e85c28a293c765ce73af6afb34aaa4840290e Mon Sep 17 00:00:00 2001
From: Dogtiti <499960698@qq.com>
Date: Fri, 2 Aug 2024 22:16:08 +0800
Subject: [PATCH 05/10] fix: empty response

---
 app/client/platforms/openai.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
index ee9a70913..8b03d1397 100644
--- a/app/client/platforms/openai.ts
+++ b/app/client/platforms/openai.ts
@@ -135,7 +135,7 @@ export class ChatGPTApi implements LLMApi {
         },
       ];
     }
-    return res.choices?.at(0)?.message?.content ?? "";
+    return res.choices?.at(0)?.message?.content ?? res;
   }
 
   async chat(options: ChatOptions) {

From 8a4b8a84d67bb7431c5ce88046d94963dceebad7 Mon Sep 17 00:00:00 2001
From: frostime
Date: Sat, 3 Aug 2024 17:16:05 +0800
Subject: [PATCH 06/10] ✨ feat: adjust the model list to display custom models first
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/utils/model.ts | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/app/utils/model.ts b/app/utils/model.ts
index 4de0eb8d9..6b1485e32 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -22,15 +22,6 @@ export function collectModelTable(
     }
   > = {};
 
-  // default models
-  models.forEach((m) => {
-    // using @ as fullName
-    modelTable[`${m.name}@${m?.provider?.id}`] = {
-      ...m,
-      displayName: m.name, // 'provider' is copied over if it exists
-    };
-  });
-
   // server custom models
   customModels
     .split(",")
@@ -89,6 +80,15 @@ export function collectModelTable(
       }
     });
 
+  // default models
+  models.forEach((m) => {
+    // using @ as fullName
+    modelTable[`${m.name}@${m?.provider?.id}`] = {
+      ...m,
+      displayName: m.name, // 'provider' is copied over if it exists
+    };
+  });
+
   return modelTable;
 }
 
@@ -99,13 +99,16 @@ export function collectModelTableWithDefaultModel(
 ) {
   let modelTable = collectModelTable(models, customModels);
   if (defaultModel && defaultModel !== "") {
-    if (defaultModel.includes('@')) {
+    if (defaultModel.includes("@")) {
       if (defaultModel in modelTable) {
         modelTable[defaultModel].isDefault = true;
       }
     } else {
       for (const key of Object.keys(modelTable)) {
-        if (modelTable[key].available && key.split('@').shift() == defaultModel) {
+        if (
+          modelTable[key].available &&
+          key.split("@").shift() == defaultModel
+        ) {
           modelTable[key].isDefault = true;
           break;
         }
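
The reorder works because collectModelTable keys entries as name@provider in a plain object, and the UI derives its list via Object.values, which follows string-key insertion order; processing server-defined custom models before the built-in defaults therefore floats them to the top of the picker. A toy illustration of that ordering guarantee:

```ts
// Sketch: Object.values() follows string-key insertion order, so whichever
// group is inserted into the table first is listed first in the UI.
const table: Record<string, { name: string }> = {};
table["my-custom@custom"] = { name: "my-custom" }; // custom inserted first
table["gpt-4o@openai"] = { name: "gpt-4o" };       // built-in inserted later
console.log(Object.values(table).map((m) => m.name));
// => ["my-custom", "gpt-4o"]
```
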
From b023a00445682fcb336fe231ffe7c667632c0d15 Mon Sep 17 00:00:00 2001
From: frostime
Date: Mon, 5 Aug 2024 16:37:22 +0800
Subject: [PATCH 07/10] 🔨 refactor(model): change the original implementation by adding an extra sort step after the collect-table functions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/utils/model.ts | 50 ++++++++++++++++++++++++++++++++++++----------
 1 file changed, 39 insertions(+), 11 deletions(-)

diff --git a/app/utils/model.ts b/app/utils/model.ts
index 6b1485e32..b117b5eb6 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -7,6 +7,29 @@ const customProvider = (providerName: string) => ({
   providerType: "custom",
 });
 
+const sortModelTable = (
+  models: ReturnType<typeof collectModels>,
+  rule: "custom-first" | "default-first",
+) =>
+  models.sort((a, b) => {
+    if (a.provider === undefined && b.provider === undefined) {
+      return 0;
+    }
+
+    let aIsCustom = a.provider?.providerType === "custom";
+    let bIsCustom = b.provider?.providerType === "custom";
+
+    if (aIsCustom === bIsCustom) {
+      return 0;
+    }
+
+    if (aIsCustom) {
+      return rule === "custom-first" ? -1 : 1;
+    } else {
+      return rule === "custom-first" ? 1 : -1;
+    }
+  });
+
 export function collectModelTable(
   models: readonly LLMModel[],
   customModels: string,
@@ -22,6 +45,15 @@ export function collectModelTable(
     }
   > = {};
 
+  // default models
+  models.forEach((m) => {
+    // using @ as fullName
+    modelTable[`${m.name}@${m?.provider?.id}`] = {
+      ...m,
+      displayName: m.name, // 'provider' is copied over if it exists
+    };
+  });
+
   // server custom models
   customModels
     .split(",")
@@ -80,15 +112,6 @@ export function collectModelTable(
       }
     });
 
-  // default models
-  models.forEach((m) => {
-    // using @ as fullName
-    modelTable[`${m.name}@${m?.provider?.id}`] = {
-      ...m,
-      displayName: m.name, // 'provider' is copied over if it exists
-    };
-  });
-
   return modelTable;
 }
 
@@ -126,7 +149,9 @@ export function collectModels(
   customModels: string,
 ) {
   const modelTable = collectModelTable(models, customModels);
-  const allModels = Object.values(modelTable);
+  let allModels = Object.values(modelTable);
+
+  allModels = sortModelTable(allModels, "custom-first");
 
   return allModels;
 }
@@ -141,7 +166,10 @@ export function collectModelsWithDefaultModel(
     customModels,
     defaultModel,
   );
-  const allModels = Object.values(modelTable);
+  let allModels = Object.values(modelTable);
+
+  allModels = sortModelTable(allModels, "custom-first");
+
   return allModels;
 }
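
Because the comparator returns 0 for every same-kind pair, the "custom-first" rule only partitions custom-provider entries ahead of the built-ins, relying on Array.prototype.sort being stable (guaranteed since ES2019) to preserve each group's original order. A worked example with simplified model objects:

```ts
// Sketch: the "custom-first" rule moves custom-provider models to the
// front; ties return 0, so the (stable) sort keeps in-group order.
type M = { name: string; provider?: { providerType: string } };
const models: M[] = [
  { name: "gpt-4o", provider: { providerType: "openai" } },
  { name: "my-llm", provider: { providerType: "custom" } },
  { name: "claude-3", provider: { providerType: "anthropic" } },
];
models.sort((a, b) => {
  const aC = a.provider?.providerType === "custom";
  const bC = b.provider?.providerType === "custom";
  return aC === bC ? 0 : aC ? -1 : 1;
});
console.log(models.map((m) => m.name)); // ["my-llm", "gpt-4o", "claude-3"]
```
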
From 150fc84b9b55fe07da2fefa73b2cbee255d9de14 Mon Sep 17 00:00:00 2001
From: frostime
Date: Mon, 5 Aug 2024 19:43:32 +0800
Subject: [PATCH 08/10] ✨ feat(model): add a sorted field and use it to order the model list
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Add a sorted field to the Model and Provider types (api.ts)
2. Built-in models are assigned a default sorted value at initialization,
   auto-incrementing from 1000 (constant.ts)
3. Custom models are assigned a sorted value automatically when they are
   updated (model.ts)
---
 app/client/api.ts  |  2 ++
 app/constant.ts    | 19 ++++++++++++++++++
 app/utils/model.ts | 49 +++++++++++++++++++++++++++-------------------
 3 files changed, 50 insertions(+), 20 deletions(-)

diff --git a/app/client/api.ts b/app/client/api.ts
index f10e47618..b13e0f8a4 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -64,12 +64,14 @@ export interface LLMModel {
   displayName?: string;
   available: boolean;
   provider: LLMModelProvider;
+  sorted: number;
 }
 
 export interface LLMModelProvider {
   id: string;
   providerName: string;
   providerType: string;
+  sorted: number;
 }
 
 export abstract class LLMApi {
diff --git a/app/constant.ts b/app/constant.ts
index 5251b5b4f..8ca17c4b3 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -320,86 +320,105 @@ const tencentModels = [
 
 const moonshotModes = ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"];
 
+let seq = 1000; // the sequence generator for built-in models starts from 1000
 export const DEFAULT_MODELS = [
   ...openaiModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++, // Global sequence sort(index)
     provider: {
       id: "openai",
       providerName: "OpenAI",
       providerType: "openai",
+      sorted: 1, // fixed here, to keep the order consistent with the earlier built-in version
     },
   })),
   ...openaiModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "azure",
       providerName: "Azure",
       providerType: "azure",
+      sorted: 2,
     },
   })),
   ...googleModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "google",
       providerName: "Google",
       providerType: "google",
+      sorted: 3,
     },
   })),
   ...anthropicModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "anthropic",
       providerName: "Anthropic",
       providerType: "anthropic",
+      sorted: 4,
     },
   })),
   ...baiduModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "baidu",
       providerName: "Baidu",
       providerType: "baidu",
+      sorted: 5,
     },
   })),
   ...bytedanceModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "bytedance",
       providerName: "ByteDance",
       providerType: "bytedance",
+      sorted: 6,
     },
   })),
   ...alibabaModes.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "alibaba",
       providerName: "Alibaba",
       providerType: "alibaba",
+      sorted: 7,
     },
   })),
   ...tencentModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "tencent",
       providerName: "Tencent",
       providerType: "tencent",
+      sorted: 8,
     },
   })),
   ...moonshotModes.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "moonshot",
       providerName: "Moonshot",
       providerType: "moonshot",
+      sorted: 9,
     },
   })),
 ] as const;
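
Numerically, the scheme keeps every custom entry ahead of every built-in: built-in models count up from sorted = 1000 here, custom sequence numbers count up from -1000 in model.ts below, and each provider carries a fixed sorted rank used as the primary sort key. A small worked example of the comparator introduced below (simplified shapes, illustrative values):

```ts
// Sketch: provider.sorted is the primary key, model.sorted breaks ties.
// Built-ins start at 1000; custom sequence numbers start at -1000.
type M = { name: string; sorted: number; provider: { sorted: number } };
const models: M[] = [
  { name: "gpt-4o", sorted: 1000, provider: { sorted: 1 } },     // OpenAI
  { name: "gemini-pro", sorted: 1050, provider: { sorted: 3 } }, // Google
  { name: "my-llm", sorted: -1000, provider: { sorted: -1000 } }, // custom
];
models.sort((a, b) => {
  const cmp = a.provider.sorted - b.provider.sorted;
  return cmp === 0 ? a.sorted - b.sorted : cmp;
});
console.log(models.map((m) => m.name));
// => ["my-llm", "gpt-4o", "gemini-pro"]
```
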
diff --git a/app/utils/model.ts b/app/utils/model.ts
index b117b5eb6..0b62b53be 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -1,32 +1,39 @@
 import { DEFAULT_MODELS } from "../constant";
 import { LLMModel } from "../client/api";
 
+const CustomSeq = {
+  val: -1000, // to ensure custom models are located at the front, start from -1000; refer to constant.ts
+  cache: new Map<string, number>(),
+  next: (id: string) => {
+    if (CustomSeq.cache.has(id)) {
+      return CustomSeq.cache.get(id) as number;
+    } else {
+      let seq = CustomSeq.val++;
+      CustomSeq.cache.set(id, seq);
+      return seq;
+    }
+  },
+};
+
 const customProvider = (providerName: string) => ({
   id: providerName.toLowerCase(),
   providerName: providerName,
   providerType: "custom",
+  sorted: CustomSeq.next(providerName),
 });
 
-const sortModelTable = (
-  models: ReturnType<typeof collectModels>,
-  rule: "custom-first" | "default-first",
-) =>
+/**
+ * Sorts an array of models based on specified rules.
+ *
+ * First, sorted by provider; if the same, sorted by model
+ */
+const sortModelTable = (models: ReturnType<typeof collectModels>) =>
   models.sort((a, b) => {
-    if (a.provider === undefined && b.provider === undefined) {
-      return 0;
-    }
-
-    let aIsCustom = a.provider?.providerType === "custom";
-    let bIsCustom = b.provider?.providerType === "custom";
-
-    if (aIsCustom === bIsCustom) {
-      return 0;
-    }
-
-    if (aIsCustom) {
-      return rule === "custom-first" ? -1 : 1;
+    if (a.provider && b.provider) {
+      let cmp = a.provider.sorted - b.provider.sorted;
+      return cmp === 0 ? a.sorted - b.sorted : cmp;
     } else {
-      return rule === "custom-first" ? 1 : -1;
+      return a.sorted - b.sorted;
     }
   });
 
 export function collectModelTable(
@@ -40,6 +47,7 @@ export function collectModelTable(
       available: boolean;
       name: string;
       displayName: string;
+      sorted: number;
       provider?: LLMModel["provider"]; // Marked as optional
       isDefault?: boolean;
     }
@@ -107,6 +115,7 @@ export function collectModelTable(
           displayName: displayName || customModelName,
           available,
           provider, // Use optional chaining
+          sorted: CustomSeq.next(`${customModelName}@${provider?.id}`),
         };
       }
     }
@@ -151,7 +160,7 @@ export function collectModels(
   const modelTable = collectModelTable(models, customModels);
   let allModels = Object.values(modelTable);
 
-  allModels = sortModelTable(allModels, "custom-first");
+  allModels = sortModelTable(allModels);
 
   return allModels;
 }
@@ -168,7 +177,7 @@ export function collectModelsWithDefaultModel(
   );
   let allModels = Object.values(modelTable);
 
-  allModels = sortModelTable(allModels, "custom-first");
+  allModels = sortModelTable(allModels);
 
   return allModels;
 }

From 3486954e073665b4bcaa4d41096b1341e4c497ff Mon Sep 17 00:00:00 2001
From: frostime
Date: Mon, 5 Aug 2024 20:26:48 +0800
Subject: [PATCH 09/10] 🐛 fix(openai): type-mismatch bug in openai.ts introduced by the previous commit
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/client/platforms/openai.ts | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
index 680125fe6..d95aebe87 100644
--- a/app/client/platforms/openai.ts
+++ b/app/client/platforms/openai.ts
@@ -411,13 +411,17 @@ export class ChatGPTApi implements LLMApi {
       return [];
     }
 
+    // since OpenAI's disableListModels currently defaults to true, this branch is never actually reached
+    let seq = 1000; // keep consistent with the ordering in constant.ts
     return chatModels.map((m) => ({
       name: m.id,
       available: true,
+      sorted: seq++,
       provider: {
         id: "openai",
         providerName: "OpenAI",
         providerType: "openai",
+        sorted: 1,
       },
     }));
   }
From 3da717d9fcb43134336d0105b8e794699edbf559 Mon Sep 17 00:00:00 2001
From: Dogtiti <499960698@qq.com>
Date: Tue, 6 Aug 2024 11:20:03 +0800
Subject: [PATCH 10/10] fix: azure summary

---
 app/store/chat.ts | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/app/store/chat.ts b/app/store/chat.ts
index 7b47f3ec6..653926d1b 100644
--- a/app/store/chat.ts
+++ b/app/store/chat.ts
@@ -547,7 +547,8 @@ export const useChatStore = createPersistStore(
         return;
       }
 
-      const api: ClientApi = getClientApi(modelConfig.providerName);
+      const providerName = modelConfig.providerName;
+      const api: ClientApi = getClientApi(providerName);
 
       // remove error messages if any
       const messages = session.messages;
@@ -570,6 +571,7 @@ export const useChatStore = createPersistStore(
         config: {
           model: getSummarizeModel(session.mask.modelConfig.model),
           stream: false,
+          providerName,
         },
         onFinish(message) {
           get().updateCurrentSession(