diff --git a/app/client/api.ts b/app/client/api.ts
index aab19a630..b10e66495 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -7,7 +7,7 @@ import {
   ServiceProvider,
 } from "../constant";
 import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
-import { ChatGPTApi } from "./platforms/openai";
+import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai";
 import { GeminiProApi } from "./platforms/google";
 import { ClaudeApi } from "./platforms/anthropic";
 import { ErnieApi } from "./platforms/baidu";
@@ -49,6 +49,7 @@ export interface LLMConfig {
   stream?: boolean;
   presence_penalty?: number;
   frequency_penalty?: number;
+  size?: DalleRequestPayload["size"];
 }
 
 export interface ChatOptions {
@@ -72,12 +73,14 @@ export interface LLMModel {
   describe: string;
   available: boolean;
   provider: LLMModelProvider;
+  sorted: number;
 }
 
 export interface LLMModelProvider {
   id: string;
   providerName: string;
   providerType: string;
+  sorted: number;
 }
 
 export abstract class LLMApi {
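The api.ts hunk only widens the shared types: `LLMConfig` gains an optional `size` keyed off `DalleRequestPayload["size"]`, and both `LLMModel` and `LLMModelProvider` gain a numeric `sorted` key. A minimal sketch of a model entry carrying the new fields (the sort values here are illustrative, not taken from the actual registry):

```ts
import type { LLMModel } from "@/app/client/api";

// Illustrative only: the shape of a model entry after this change.
const dalle3Entry: LLMModel = {
  name: "dall-e-3",
  describe: "",
  available: true,
  sorted: 1000, // model-level sort key, assigned sequentially
  provider: {
    id: "openai",
    providerName: "OpenAI",
    providerType: "openai",
    sorted: 1, // provider-level sort key, compared before the model key
  },
};
```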
diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
index 973304e64..404f8c0f9 100644
--- a/app/client/platforms/openai.ts
+++ b/app/client/platforms/openai.ts
@@ -12,8 +12,13 @@
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
 import { collectModelsWithDefaultModel } from "@/app/utils/model";
-import { preProcessImageContent } from "@/app/utils/chat";
+import {
+  preProcessImageContent,
+  uploadImage,
+  base64Image2Blob,
+} from "@/app/utils/chat";
 import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
+import { DalleSize } from "@/app/typing";
 
 import {
   ChatOptions,
@@ -34,6 +39,7 @@
   getMessageTextContent,
   getMessageImages,
   isVisionModel,
+  isDalle3 as _isDalle3,
 } from "@/app/utils";
 
 export interface OpenAIListModelResponse {
@@ -59,6 +65,14 @@
   max_tokens?: number;
 }
 
+export interface DalleRequestPayload {
+  model: string;
+  prompt: string;
+  response_format: "url" | "b64_json";
+  n: number;
+  size: DalleSize;
+}
+
 export class ChatGPTApi implements LLMApi {
   private disableListModels = true;
 
@@ -101,20 +115,31 @@ export class ChatGPTApi implements LLMApi {
     return cloudflareAIGatewayUrl([baseUrl, path].join("/"));
   }
 
-  extractMessage(res: any) {
-    return res.choices?.at(0)?.message?.content ?? "";
+  async extractMessage(res: any) {
+    if (res.error) {
+      return "```\n" + JSON.stringify(res, null, 4) + "\n```";
+    }
+    // the dall-e-3 model returns a url (or b64_json); build an image message from it
+    if (res.data) {
+      let url = res.data?.at(0)?.url ?? "";
+      const b64_json = res.data?.at(0)?.b64_json ?? "";
+      if (!url && b64_json) {
+        // uploadImage
+        url = await uploadImage(base64Image2Blob(b64_json, "image/png"));
+      }
+      return [
+        {
+          type: "image_url",
+          image_url: {
+            url,
+          },
+        },
+      ];
+    }
+    return res.choices?.at(0)?.message?.content ?? res;
   }
 
   async chat(options: ChatOptions) {
-    const visionModel = isVisionModel(options.config.model);
-    const messages: ChatOptions["messages"] = [];
-    for (const v of options.messages) {
-      const content = visionModel
-        ? await preProcessImageContent(v.content)
-        : getMessageTextContent(v);
-      messages.push({ role: v.role, content });
-    }
-
     const modelConfig = {
       ...useAppConfig.getState().modelConfig,
       ...useChatStore.getState().currentSession().mask.modelConfig,
@@ -123,28 +148,53 @@
         providerName: options.config.providerName,
       },
     };
-    console.log('-------', modelConfig, options)
-    const requestPayload: RequestPayload = {
-      messages,
-      stream: options.config.stream,
-      model: modelConfig.model,
-      temperature: modelConfig.temperature,
-      presence_penalty: modelConfig.presence_penalty,
-      frequency_penalty: modelConfig.frequency_penalty,
-      top_p: modelConfig.top_p,
-      // max_tokens: Math.max(modelConfig.max_tokens, 1024),
-      // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
-    };
-
-    // console.log("[Request] openai payload: ", requestPayload);
-
-    // add max_tokens to vision model
-    if (visionModel && modelConfig.model.includes("preview")) {
-      requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
+    let requestPayload: RequestPayload | DalleRequestPayload;
+
+    const isDalle3 = _isDalle3(options.config.model);
+    if (isDalle3) {
+      const prompt = getMessageTextContent(
+        options.messages.slice(-1)?.pop() as any,
+      );
+      requestPayload = {
+        model: options.config.model,
+        prompt,
+        // URLs are only valid for 60 minutes after the image has been generated.
+        response_format: "b64_json", // request b64_json and save the image in CacheStorage
+        n: 1,
+        size: options.config?.size ?? "1024x1024",
+      };
+    } else {
+      const visionModel = isVisionModel(options.config.model);
+      const messages: ChatOptions["messages"] = [];
+      for (const v of options.messages) {
+        const content = visionModel
+          ? await preProcessImageContent(v.content)
+          : getMessageTextContent(v);
+        messages.push({ role: v.role, content });
+      }
+
+      requestPayload = {
+        messages,
+        stream: options.config.stream,
+        model: modelConfig.model,
+        temperature: modelConfig.temperature,
+        presence_penalty: modelConfig.presence_penalty,
+        frequency_penalty: modelConfig.frequency_penalty,
+        top_p: modelConfig.top_p,
+        // max_tokens: Math.max(modelConfig.max_tokens, 1024),
+        // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
+      };
+
+      // add max_tokens to vision model
+      if (visionModel && modelConfig.model.includes("preview")) {
+        requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
+      }
     }
 
     console.log("[Request] openai payload: ", requestPayload);
 
-    const shouldStream = !!options.config.stream;
+    const shouldStream = !isDalle3 && !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
 
@@ -170,13 +220,15 @@
           model?.provider?.providerName === ServiceProvider.Azure,
       );
       chatPath = this.path(
-        Azure.ChatPath(
+        (isDalle3 ? Azure.ImagePath : Azure.ChatPath)(
          (model?.displayName ?? model?.name) as string,
          useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
        ),
      );
    } else {
-      chatPath = this.path(OpenaiPath.ChatPath);
+      chatPath = this.path(
+        isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
+      );
    }
    // console.log('333333', chatPath)
    const chatPayload = {
@@ -188,7 +240,7 @@
      // make a fetch request
      const requestTimeoutId = setTimeout(
        () => controller.abort(),
-        REQUEST_TIMEOUT_MS,
+        isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dall-e-3 with b64_json is slow.
      );
 
      if (shouldStream) {
@@ -325,7 +377,7 @@
        clearTimeout(requestTimeoutId);
 
        const resJson = await res.json();
-        const message = this.extractMessage(resJson);
+        const message = await this.extractMessage(resJson);
        options.onFinish(message);
      }
    } catch (e) {
@@ -419,13 +471,17 @@
      return [];
    }
 
+    // Since OpenAI's disableListModels currently defaults to true, this branch is not actually reached yet.
+    let seq = 1000; // keep the ordering consistent with constant.ts
    return chatModels.map((m) => ({
      name: m.id,
      available: true,
+      sorted: seq++,
      provider: {
        id: "openai",
        providerName: "OpenAI",
        providerType: "openai",
+        sorted: 1,
      },
      describe: "",
    }));
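Taken together, the openai.ts changes split `chat()` into two payload shapes. For a dall-e-3 model the body posted to the image endpoint looks like the sketch below (the prompt and size are made-up example values), and `extractMessage()` turns the response into a multimodal `image_url` part instead of plain text:

```ts
import type { DalleRequestPayload } from "@/app/client/platforms/openai";

// Illustrative payload for POST v1/images/generations.
const payload: DalleRequestPayload = {
  model: "dall-e-3",
  prompt: "a watercolor lighthouse at dawn",
  // b64_json is requested so the image can be re-uploaded to CacheStorage,
  // since the URLs OpenAI returns expire after about 60 minutes.
  response_format: "b64_json",
  n: 1,
  size: "1024x1024",
};

// Shape produced by extractMessage() for an image response:
// [{ type: "image_url", image_url: { url: "<cached image url>" } }]
```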
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index d0d1cbc8a..df24d550a 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -37,6 +37,7 @@
 import AutoIcon from "../icons/auto.svg";
 import BottomIcon from "../icons/bottom.svg";
 import StopIcon from "../icons/pause.svg";
 import RobotIcon from "../icons/robot.svg";
+import SizeIcon from "../icons/size.svg";
 import PluginIcon from "../icons/plugin.svg";
 // import UploadIcon from "../icons/upload.svg";
@@ -63,6 +64,7 @@
   getMessageTextContent,
   getMessageImages,
   isVisionModel,
+  isDalle3,
 } from "../utils";
 
 import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
@@ -70,6 +72,7 @@
 import dynamic from "next/dynamic";
 
 import { ChatControllerPool } from "../client/controller";
+import { DalleSize } from "../typing";
 import { Prompt, usePromptStore } from "../store/prompt";
 import Locale from "../locales";
 
@@ -505,6 +508,11 @@ export function ChatActions(props: {
   const [showUploadImage, setShowUploadImage] = useState(false);
   const current_day_token = localStorage.getItem("current_day_token") ?? "";
 
+  const [showSizeSelector, setShowSizeSelector] = useState(false);
+  const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
+  const currentSize =
+    chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024";
+
   useEffect(() => {
     const show = isVisionModel(currentModel);
     setShowUploadImage(show);
@@ -651,6 +659,33 @@
           />
         )}
 
+        {isDalle3(currentModel) && (
+          <ChatAction
+            onClick={() => setShowSizeSelector(true)}
+            text={currentSize}
+            icon={<SizeIcon />}
+          />
+        )}
+
+        {showSizeSelector && (
+          <Selector
+            defaultSelectedValue={currentSize}
+            items={dalle3Sizes.map((m) => ({
+              title: m,
+              value: m,
+            }))}
+            onClose={() => setShowSizeSelector(false)}
+            onSelection={(s) => {
+              if (s.length === 0) return;
+              const size = s[0];
+              chatStore.updateCurrentSession((session) => {
+                session.mask.modelConfig.size = size;
+              });
+              showToast(size);
+            }}
+          />
+        )}
+
         <ChatAction
           onClick={() => setShowPluginSelector(true)}
           text={Locale.Plugin.Name}
"1024x1024"; + useEffect(() => { const show = isVisionModel(currentModel); setShowUploadImage(show); @@ -651,6 +659,33 @@ export function ChatActions(props: { /> )} + {isDalle3(currentModel) && ( + setShowSizeSelector(true)} + text={currentSize} + icon={} + /> + )} + + {showSizeSelector && ( + ({ + title: m, + value: m, + }))} + onClose={() => setShowSizeSelector(false)} + onSelection={(s) => { + if (s.length === 0) return; + const size = s[0]; + chatStore.updateCurrentSession((session) => { + session.mask.modelConfig.size = size; + }); + showToast(size); + }} + /> + )} + setShowPluginSelector(true)} text={Locale.Plugin.Name} diff --git a/app/constant.ts b/app/constant.ts index 7e15a1def..2cdc2434d 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -147,9 +147,7 @@ export const Anthropic = { export const OpenaiPath = { ChatPath: "v1/chat/completions", - // Azure32kPath: - // "openai/deployments/gpt-4-32k/chat/completions?api-version=2023-05-15", - // Azure32kPathCheck: "openai/deployments/gpt-4-32k/chat/completions", + ImagePath: "v1/images/generations", UsagePath: "dashboard/billing/usage", SubsPath: "dashboard/billing/subscription", ListModelPath: "v1/models", @@ -158,7 +156,10 @@ export const OpenaiPath = { export const Azure = { ChatPath: (deployName: string, apiVersion: string) => `deployments/${deployName}/chat/completions?api-version=${apiVersion}`, - ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}", + // https://.openai.azure.com/openai/deployments//images/generations?api-version= + ImagePath: (deployName: string, apiVersion: string) => + `deployments/${deployName}/images/generations?api-version=${apiVersion}`, + ExampleEndpoint: "https://{resource-url}/openai", }; export const Google = { @@ -261,6 +262,7 @@ const openaiModels = [ "gpt-4-vision-preview", "gpt-4-turbo-2024-04-09", "gpt-4-1106-preview", + "dall-e-3", ]; const googleModels = [ @@ -325,6 +327,7 @@ const tencentModels = [ const moonshotModes = ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]; +let seq = 1000; // 内置的模型序号生成器从1000开始 export const DEFAULT_MODELS = [ { name: "gpt-3.5-turbo", @@ -406,24 +409,6 @@ export const DEFAULT_MODELS = [ providerType: "openai", }, }, - // ...tencentModels.map((name) => ({ - // name, - // available: true, - // provider: { - // id: "tencent", - // providerName: "Tencent", - // providerType: "tencent", - // }, - // })), - // ...moonshotModes.map((name) => ({ - // name, - // available: true, - // provider: { - // id: "moonshot", - // providerName: "Moonshot", - // providerType: "moonshot", - // }, - // })), ] as const; // export const AZURE_MODELS: string[] = [ diff --git a/app/icons/size.svg b/app/icons/size.svg new file mode 100644 index 000000000..3da4fadfe --- /dev/null +++ b/app/icons/size.svg @@ -0,0 +1 @@ + diff --git a/app/store/chat.ts b/app/store/chat.ts index 9ead183ca..5d402c4aa 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -31,6 +31,7 @@ import { nanoid } from "nanoid"; import { createPersistStore } from "../utils/store"; import { collectModelsWithDefaultModel } from "../utils/model"; import { useAccessStore } from "./access"; +import { isDalle3 } from "../utils"; export type ChatMessage = RequestMessage & { date: string; @@ -95,12 +96,12 @@ function createEmptySession(): ChatSession { }; } - // if it is using gpt-* models, force to use 4o-mini to summarize +// if it is using gpt-* models, force to use 4o-mini to summarize const ChatFetchTaskPool: Record = {}; function getSummarizeModel(currentModel: string): { - name: string, - 
diff --git a/app/store/chat.ts b/app/store/chat.ts
index 9ead183ca..5d402c4aa 100644
--- a/app/store/chat.ts
+++ b/app/store/chat.ts
@@ -31,6 +31,7 @@ import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
 import { collectModelsWithDefaultModel } from "../utils/model";
 import { useAccessStore } from "./access";
+import { isDalle3 } from "../utils";
 
 export type ChatMessage = RequestMessage & {
   date: string;
@@ -95,12 +96,12 @@ function createEmptySession(): ChatSession {
   };
 }
 
-  // if it is using gpt-* models, force to use 4o-mini to summarize
+// if it is using gpt-* models, force to use 4o-mini to summarize
 
 const ChatFetchTaskPool: Record<string, any> = {};
 
 function getSummarizeModel(currentModel: string): {
-  name: string,
-  providerName: string | undefined,
+  name: string;
+  providerName: string | undefined;
 } {
   // if it is using gpt-* models, force to use 3.5 to summarize
   if (currentModel.startsWith("gpt")) {
@@ -117,18 +118,18 @@
     return {
       name: summarizeModel?.name ?? currentModel,
       providerName: summarizeModel?.provider?.providerName,
-    }
+    };
   }
   if (currentModel.startsWith("gemini")) {
     return {
       name: GEMINI_SUMMARIZE_MODEL,
       providerName: ServiceProvider.Google,
-    }
+    };
   }
   return {
     name: currentModel,
     providerName: undefined,
-  }
+  };
 }
 
 function countMessages(msgs: ChatMessage[]) {
@@ -718,7 +719,7 @@
           set(() => ({}));
           extAttr?.setAutoScroll(true);
         } else {
-          const api: ClientApi = getClientApi(modelConfig.providerName)
+          const api: ClientApi = getClientApi(modelConfig.providerName);
           // console.log('-------', modelConfig, '-----', api)
 
           // make request
@@ -896,8 +897,13 @@
       const config = useAppConfig.getState();
       const session = get().currentSession();
       const modelConfig = session.mask.modelConfig;
+      // skip summarizing when using dall-e-3
+      if (isDalle3(modelConfig.model)) {
+        return;
+      }
 
-      const api: ClientApi = getClientApi(modelConfig.providerName);
+      const providerName = modelConfig.providerName;
+      const api: ClientApi = getClientApi(providerName);
 
       // remove error messages if any
       const messages = session.messages;
@@ -919,8 +925,10 @@
           messages: topicMessages,
           config: {
             model: getSummarizeModel(session.mask.modelConfig.model).name,
-            providerName: getSummarizeModel(session.mask.modelConfig.model).providerName,
+            providerName: getSummarizeModel(session.mask.modelConfig.model)
+              .providerName,
             stream: false,
+            providerName,
           },
           onFinish(message) {
             get().updateCurrentSession(
@@ -982,7 +990,8 @@
             ...modelcfg,
             stream: true,
             model: getSummarizeModel(session.mask.modelConfig.model).name,
-            providerName: getSummarizeModel(session.mask.modelConfig.model).providerName,
+            providerName: getSummarizeModel(session.mask.modelConfig.model)
+              .providerName,
           },
           onUpdate(message) {
             session.memoryPrompt = message;
diff --git a/app/store/config.ts b/app/store/config.ts
index e8ca50502..db9d30e0b 100644
--- a/app/store/config.ts
+++ b/app/store/config.ts
@@ -1,4 +1,5 @@
 import { LLMModel } from "../client/api";
+import { DalleSize } from "../typing";
 import { getClientConfig } from "../config/client";
 import {
   DEFAULT_INPUT_TEMPLATE,
@@ -66,6 +67,7 @@ export const DEFAULT_CONFIG = {
     compressMessageLengthThreshold: 4000,
     enableInjectSystemPrompts: true,
     template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
+    size: "1024x1024" as DalleSize,
   },
 };
 
diff --git a/app/typing.ts b/app/typing.ts
index b09722ab9..863203581 100644
--- a/app/typing.ts
+++ b/app/typing.ts
@@ -7,3 +7,5 @@ export interface RequestMessage {
   role: MessageRole;
   content: string;
 }
+
+export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";
diff --git a/app/utils.ts b/app/utils.ts
index 68be8d1f7..2a2922907 100644
--- a/app/utils.ts
+++ b/app/utils.ts
@@ -266,3 +266,7 @@ export function isVisionModel(model: string) {
     visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo
   );
 }
+
+export function isDalle3(model: string) {
+  return "dall-e-3" === model;
+}
diff --git a/app/utils/model.ts b/app/utils/model.ts
index 773d4b80c..fc8be378c 100644
--- a/app/utils/model.ts
+++ b/app/utils/model.ts
@@ -1,12 +1,42 @@
 import { DEFAULT_MODELS } from "../constant";
 import { LLMModel } from "../client/api";
 
+const CustomSeq = {
+  val: -1000, // To ensure custom models are listed at the front, start from -1000; refer to constant.ts
+  cache: new Map<string, number>(),
+  next: (id: string) => {
+    if (CustomSeq.cache.has(id)) {
+      return CustomSeq.cache.get(id) as number;
+    } else {
+      let seq = CustomSeq.val++;
+      CustomSeq.cache.set(id, seq);
+      return seq;
+    }
+  },
+};
+
 const customProvider = (providerName: string) => ({
   id: providerName.toLowerCase(),
   providerName: providerName,
   providerType: "custom",
+  sorted: CustomSeq.next(providerName),
 });
 
+/**
+ * Sorts an array of models based on specified rules.
+ *
+ * Sorted first by provider; if the provider is the same, sorted by model.
+ */
+const sortModelTable = (models: ReturnType<typeof collectModels>) =>
+  models.sort((a, b) => {
+    if (a.provider && b.provider) {
+      let cmp = a.provider.sorted - b.provider.sorted;
+      return cmp === 0 ? a.sorted - b.sorted : cmp;
+    } else {
+      return a.sorted - b.sorted;
+    }
+  });
+
 export function collectModelTable(
   models: readonly LLMModel[],
   customModels: string,
@@ -17,6 +47,7 @@
       available: boolean;
       name: string;
       displayName: string;
+      sorted: number;
       describe: string;
       provider?: LLMModel["provider"]; // Marked as optional
       isDefault?: boolean;
@@ -86,6 +117,7 @@
         available,
         describe: "",
         provider, // Use optional chaining
+        sorted: CustomSeq.next(`${customModelName}@${provider?.id}`),
       };
     }
   }
@@ -140,7 +172,9 @@
   customModels: string,
 ) {
   const modelTable = collectModelTable(models, customModels);
-  const allModels = Object.values(modelTable);
+  let allModels = Object.values(modelTable);
+
+  allModels = sortModelTable(allModels);
 
   return allModels;
 }
@@ -155,7 +189,10 @@
     customModels,
     defaultModel,
   );
-  const allModels = Object.values(modelTable);
+  let allModels = Object.values(modelTable);
+
+  allModels = sortModelTable(allModels);
+
   return allModels;
 }
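One way to see the new ordering rules in utils/model.ts at work: custom models are numbered from -1000 by `CustomSeq`, while the built-in table counts up from 1000, so after `sortModelTable` runs inside `collectModels` the custom entries are expected to come out ahead of the built-in ones. A small sketch (the custom model name and the cast mirror how the config store passes `DEFAULT_MODELS`; both are illustrative):

```ts
import { DEFAULT_MODELS } from "@/app/constant";
import { LLMModel } from "@/app/client/api";
import { collectModels } from "@/app/utils/model";

// "my-local-model" is an illustrative entry added through the customModels string.
const models = collectModels(DEFAULT_MODELS as any as LLMModel[], "my-local-model");

// Custom entries carry negative `sorted` values (counting up from -1000),
// so the custom model is expected to appear before the built-in models.
console.log(models.map((m) => `${m.sorted}  ${m.provider?.providerName}  ${m.name}`));
```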