mirror of
https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web.git
synced 2025-10-01 23:56:39 +08:00
Merge remote-tracking branch 'upstream/main' into dev
# Conflicts: # app/client/platforms/openai.ts # app/constant.ts # app/utils/model.ts
This commit is contained in:
commit
e0799f8f48
@ -7,7 +7,7 @@ import {
|
|||||||
ServiceProvider,
|
ServiceProvider,
|
||||||
} from "../constant";
|
} from "../constant";
|
||||||
import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
|
import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
|
||||||
import { ChatGPTApi } from "./platforms/openai";
|
import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai";
|
||||||
import { GeminiProApi } from "./platforms/google";
|
import { GeminiProApi } from "./platforms/google";
|
||||||
import { ClaudeApi } from "./platforms/anthropic";
|
import { ClaudeApi } from "./platforms/anthropic";
|
||||||
import { ErnieApi } from "./platforms/baidu";
|
import { ErnieApi } from "./platforms/baidu";
|
||||||
@ -49,6 +49,7 @@ export interface LLMConfig {
|
|||||||
stream?: boolean;
|
stream?: boolean;
|
||||||
presence_penalty?: number;
|
presence_penalty?: number;
|
||||||
frequency_penalty?: number;
|
frequency_penalty?: number;
|
||||||
|
size?: DalleRequestPayload["size"];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatOptions {
|
export interface ChatOptions {
|
||||||
@ -72,12 +73,14 @@ export interface LLMModel {
|
|||||||
describe: string;
|
describe: string;
|
||||||
available: boolean;
|
available: boolean;
|
||||||
provider: LLMModelProvider;
|
provider: LLMModelProvider;
|
||||||
|
sorted: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface LLMModelProvider {
|
export interface LLMModelProvider {
|
||||||
id: string;
|
id: string;
|
||||||
providerName: string;
|
providerName: string;
|
||||||
providerType: string;
|
providerType: string;
|
||||||
|
sorted: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export abstract class LLMApi {
|
export abstract class LLMApi {
|
||||||
|
@ -12,8 +12,13 @@ import {
|
|||||||
} from "@/app/constant";
|
} from "@/app/constant";
|
||||||
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
|
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
|
||||||
import { collectModelsWithDefaultModel } from "@/app/utils/model";
|
import { collectModelsWithDefaultModel } from "@/app/utils/model";
|
||||||
import { preProcessImageContent } from "@/app/utils/chat";
|
import {
|
||||||
|
preProcessImageContent,
|
||||||
|
uploadImage,
|
||||||
|
base64Image2Blob,
|
||||||
|
} from "@/app/utils/chat";
|
||||||
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
|
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
|
||||||
|
import { DalleSize } from "@/app/typing";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
ChatOptions,
|
ChatOptions,
|
||||||
@ -34,6 +39,7 @@ import {
|
|||||||
getMessageTextContent,
|
getMessageTextContent,
|
||||||
getMessageImages,
|
getMessageImages,
|
||||||
isVisionModel,
|
isVisionModel,
|
||||||
|
isDalle3 as _isDalle3,
|
||||||
} from "@/app/utils";
|
} from "@/app/utils";
|
||||||
|
|
||||||
export interface OpenAIListModelResponse {
|
export interface OpenAIListModelResponse {
|
||||||
@ -59,6 +65,14 @@ export interface RequestPayload {
|
|||||||
max_tokens?: number;
|
max_tokens?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface DalleRequestPayload {
|
||||||
|
model: string;
|
||||||
|
prompt: string;
|
||||||
|
response_format: "url" | "b64_json";
|
||||||
|
n: number;
|
||||||
|
size: DalleSize;
|
||||||
|
}
|
||||||
|
|
||||||
export class ChatGPTApi implements LLMApi {
|
export class ChatGPTApi implements LLMApi {
|
||||||
private disableListModels = true;
|
private disableListModels = true;
|
||||||
|
|
||||||
@ -101,20 +115,31 @@ export class ChatGPTApi implements LLMApi {
|
|||||||
return cloudflareAIGatewayUrl([baseUrl, path].join("/"));
|
return cloudflareAIGatewayUrl([baseUrl, path].join("/"));
|
||||||
}
|
}
|
||||||
|
|
||||||
extractMessage(res: any) {
|
async extractMessage(res: any) {
|
||||||
return res.choices?.at(0)?.message?.content ?? "";
|
if (res.error) {
|
||||||
|
return "```\n" + JSON.stringify(res, null, 4) + "\n```";
|
||||||
|
}
|
||||||
|
// dall-e-3 responses carry the image in res.data (as a url or as b64_json); build an image_url message from it
|
||||||
|
if (res.data) {
|
||||||
|
let url = res.data?.at(0)?.url ?? "";
|
||||||
|
const b64_json = res.data?.at(0)?.b64_json ?? "";
|
||||||
|
if (!url && b64_json) {
|
||||||
|
// uploadImage
|
||||||
|
url = await uploadImage(base64Image2Blob(b64_json, "image/png"));
|
||||||
|
}
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
type: "image_url",
|
||||||
|
image_url: {
|
||||||
|
url,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
];
|
||||||
|
}
|
||||||
|
return res.choices?.at(0)?.message?.content ?? res;
|
||||||
}
|
}
|
||||||
|
|
||||||
async chat(options: ChatOptions) {
|
async chat(options: ChatOptions) {
|
||||||
const visionModel = isVisionModel(options.config.model);
|
|
||||||
const messages: ChatOptions["messages"] = [];
|
|
||||||
for (const v of options.messages) {
|
|
||||||
const content = visionModel
|
|
||||||
? await preProcessImageContent(v.content)
|
|
||||||
: getMessageTextContent(v);
|
|
||||||
messages.push({ role: v.role, content });
|
|
||||||
}
|
|
||||||
|
|
||||||
const modelConfig = {
|
const modelConfig = {
|
||||||
...useAppConfig.getState().modelConfig,
|
...useAppConfig.getState().modelConfig,
|
||||||
...useChatStore.getState().currentSession().mask.modelConfig,
|
...useChatStore.getState().currentSession().mask.modelConfig,
|
||||||
@ -123,28 +148,53 @@ export class ChatGPTApi implements LLMApi {
|
|||||||
providerName: options.config.providerName,
|
providerName: options.config.providerName,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
console.log('-------', modelConfig, options)
|
|
||||||
const requestPayload: RequestPayload = {
|
|
||||||
messages,
|
|
||||||
stream: options.config.stream,
|
|
||||||
model: modelConfig.model,
|
|
||||||
temperature: modelConfig.temperature,
|
|
||||||
presence_penalty: modelConfig.presence_penalty,
|
|
||||||
frequency_penalty: modelConfig.frequency_penalty,
|
|
||||||
top_p: modelConfig.top_p,
|
|
||||||
// max_tokens: Math.max(modelConfig.max_tokens, 1024),
|
|
||||||
// Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
|
|
||||||
};
|
|
||||||
|
|
||||||
// console.log("[Request] openai payload: ", requestPayload);
|
let requestPayload: RequestPayload | DalleRequestPayload;
|
||||||
// add max_tokens to vision model
|
|
||||||
if (visionModel && modelConfig.model.includes("preview")) {
|
const isDalle3 = _isDalle3(options.config.model);
|
||||||
requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
|
if (isDalle3) {
|
||||||
|
const prompt = getMessageTextContent(
|
||||||
|
options.messages.slice(-1)?.pop() as any,
|
||||||
|
);
|
||||||
|
requestPayload = {
|
||||||
|
model: options.config.model,
|
||||||
|
prompt,
|
||||||
|
// URLs are only valid for 60 minutes after the image has been generated.
|
||||||
|
response_format: "b64_json", // using b64_json, and save image in CacheStorage
|
||||||
|
n: 1,
|
||||||
|
size: options.config?.size ?? "1024x1024",
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
const visionModel = isVisionModel(options.config.model);
|
||||||
|
const messages: ChatOptions["messages"] = [];
|
||||||
|
for (const v of options.messages) {
|
||||||
|
const content = visionModel
|
||||||
|
? await preProcessImageContent(v.content)
|
||||||
|
: getMessageTextContent(v);
|
||||||
|
messages.push({ role: v.role, content });
|
||||||
|
}
|
||||||
|
|
||||||
|
requestPayload = {
|
||||||
|
messages,
|
||||||
|
stream: options.config.stream,
|
||||||
|
model: modelConfig.model,
|
||||||
|
temperature: modelConfig.temperature,
|
||||||
|
presence_penalty: modelConfig.presence_penalty,
|
||||||
|
frequency_penalty: modelConfig.frequency_penalty,
|
||||||
|
top_p: modelConfig.top_p,
|
||||||
|
// max_tokens: Math.max(modelConfig.max_tokens, 1024),
|
||||||
|
// Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
|
||||||
|
};
|
||||||
|
|
||||||
|
// add max_tokens to vision model
|
||||||
|
if (visionModel && modelConfig.model.includes("preview")) {
|
||||||
|
requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
console.log("[Request] openai payload: ", requestPayload);
|
console.log("[Request] openai payload: ", requestPayload);
|
||||||
|
|
||||||
const shouldStream = !!options.config.stream;
|
const shouldStream = !isDalle3 && !!options.config.stream;
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
options.onController?.(controller);
|
options.onController?.(controller);
|
||||||
|
|
||||||
@ -170,13 +220,15 @@ export class ChatGPTApi implements LLMApi {
|
|||||||
model?.provider?.providerName === ServiceProvider.Azure,
|
model?.provider?.providerName === ServiceProvider.Azure,
|
||||||
);
|
);
|
||||||
chatPath = this.path(
|
chatPath = this.path(
|
||||||
Azure.ChatPath(
|
(isDalle3 ? Azure.ImagePath : Azure.ChatPath)(
|
||||||
(model?.displayName ?? model?.name) as string,
|
(model?.displayName ?? model?.name) as string,
|
||||||
useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
|
useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
chatPath = this.path(OpenaiPath.ChatPath);
|
chatPath = this.path(
|
||||||
|
isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
// console.log('333333', chatPath)
|
// console.log('333333', chatPath)
|
||||||
const chatPayload = {
|
const chatPayload = {
|
||||||
@ -188,7 +240,7 @@ export class ChatGPTApi implements LLMApi {
|
|||||||
// make a fetch request
|
// make a fetch request
|
||||||
const requestTimeoutId = setTimeout(
|
const requestTimeoutId = setTimeout(
|
||||||
() => controller.abort(),
|
() => controller.abort(),
|
||||||
REQUEST_TIMEOUT_MS,
|
isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dalle3 using b64_json is slow.
|
||||||
);
|
);
|
||||||
|
|
||||||
if (shouldStream) {
|
if (shouldStream) {
|
||||||
@ -325,7 +377,7 @@ export class ChatGPTApi implements LLMApi {
|
|||||||
clearTimeout(requestTimeoutId);
|
clearTimeout(requestTimeoutId);
|
||||||
|
|
||||||
const resJson = await res.json();
|
const resJson = await res.json();
|
||||||
const message = this.extractMessage(resJson);
|
const message = await this.extractMessage(resJson);
|
||||||
options.onFinish(message);
|
options.onFinish(message);
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
@ -419,13 +471,17 @@ export class ChatGPTApi implements LLMApi {
|
|||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenAI's disableListModels currently defaults to true, so this code is not actually reached at present
|
||||||
|
let seq = 1000; // keep consistent with the ordering used in constant.ts
|
||||||
return chatModels.map((m) => ({
|
return chatModels.map((m) => ({
|
||||||
name: m.id,
|
name: m.id,
|
||||||
available: true,
|
available: true,
|
||||||
|
sorted: seq++,
|
||||||
provider: {
|
provider: {
|
||||||
id: "openai",
|
id: "openai",
|
||||||
providerName: "OpenAI",
|
providerName: "OpenAI",
|
||||||
providerType: "openai",
|
providerType: "openai",
|
||||||
|
sorted: 1,
|
||||||
},
|
},
|
||||||
describe: "",
|
describe: "",
|
||||||
}));
|
}));
|
||||||
|
@ -37,6 +37,7 @@ import AutoIcon from "../icons/auto.svg";
|
|||||||
import BottomIcon from "../icons/bottom.svg";
|
import BottomIcon from "../icons/bottom.svg";
|
||||||
import StopIcon from "../icons/pause.svg";
|
import StopIcon from "../icons/pause.svg";
|
||||||
import RobotIcon from "../icons/robot.svg";
|
import RobotIcon from "../icons/robot.svg";
|
||||||
|
import SizeIcon from "../icons/size.svg";
|
||||||
import PluginIcon from "../icons/plugin.svg";
|
import PluginIcon from "../icons/plugin.svg";
|
||||||
// import UploadIcon from "../icons/upload.svg";
|
// import UploadIcon from "../icons/upload.svg";
|
||||||
|
|
||||||
@ -63,6 +64,7 @@ import {
|
|||||||
getMessageTextContent,
|
getMessageTextContent,
|
||||||
getMessageImages,
|
getMessageImages,
|
||||||
isVisionModel,
|
isVisionModel,
|
||||||
|
isDalle3,
|
||||||
} from "../utils";
|
} from "../utils";
|
||||||
|
|
||||||
import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
|
import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
|
||||||
@ -70,6 +72,7 @@ import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
|
|||||||
import dynamic from "next/dynamic";
|
import dynamic from "next/dynamic";
|
||||||
|
|
||||||
import { ChatControllerPool } from "../client/controller";
|
import { ChatControllerPool } from "../client/controller";
|
||||||
|
import { DalleSize } from "../typing";
|
||||||
import { Prompt, usePromptStore } from "../store/prompt";
|
import { Prompt, usePromptStore } from "../store/prompt";
|
||||||
import Locale from "../locales";
|
import Locale from "../locales";
|
||||||
|
|
||||||
@ -505,6 +508,11 @@ export function ChatActions(props: {
|
|||||||
const [showUploadImage, setShowUploadImage] = useState(false);
|
const [showUploadImage, setShowUploadImage] = useState(false);
|
||||||
const current_day_token = localStorage.getItem("current_day_token") ?? "";
|
const current_day_token = localStorage.getItem("current_day_token") ?? "";
|
||||||
|
|
||||||
|
const [showSizeSelector, setShowSizeSelector] = useState(false);
|
||||||
|
const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
|
||||||
|
const currentSize =
|
||||||
|
chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024";
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const show = isVisionModel(currentModel);
|
const show = isVisionModel(currentModel);
|
||||||
setShowUploadImage(show);
|
setShowUploadImage(show);
|
||||||
@ -651,6 +659,33 @@ export function ChatActions(props: {
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{isDalle3(currentModel) && (
|
||||||
|
<ChatAction
|
||||||
|
onClick={() => setShowSizeSelector(true)}
|
||||||
|
text={currentSize}
|
||||||
|
icon={<SizeIcon />}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{showSizeSelector && (
|
||||||
|
<Selector
|
||||||
|
defaultSelectedValue={currentSize}
|
||||||
|
items={dalle3Sizes.map((m) => ({
|
||||||
|
title: m,
|
||||||
|
value: m,
|
||||||
|
}))}
|
||||||
|
onClose={() => setShowSizeSelector(false)}
|
||||||
|
onSelection={(s) => {
|
||||||
|
if (s.length === 0) return;
|
||||||
|
const size = s[0];
|
||||||
|
chatStore.updateCurrentSession((session) => {
|
||||||
|
session.mask.modelConfig.size = size;
|
||||||
|
});
|
||||||
|
showToast(size);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
<ChatAction
|
<ChatAction
|
||||||
onClick={() => setShowPluginSelector(true)}
|
onClick={() => setShowPluginSelector(true)}
|
||||||
text={Locale.Plugin.Name}
|
text={Locale.Plugin.Name}
|
||||||
|
@ -147,9 +147,7 @@ export const Anthropic = {
|
|||||||
|
|
||||||
export const OpenaiPath = {
|
export const OpenaiPath = {
|
||||||
ChatPath: "v1/chat/completions",
|
ChatPath: "v1/chat/completions",
|
||||||
// Azure32kPath:
|
ImagePath: "v1/images/generations",
|
||||||
// "openai/deployments/gpt-4-32k/chat/completions?api-version=2023-05-15",
|
|
||||||
// Azure32kPathCheck: "openai/deployments/gpt-4-32k/chat/completions",
|
|
||||||
UsagePath: "dashboard/billing/usage",
|
UsagePath: "dashboard/billing/usage",
|
||||||
SubsPath: "dashboard/billing/subscription",
|
SubsPath: "dashboard/billing/subscription",
|
||||||
ListModelPath: "v1/models",
|
ListModelPath: "v1/models",
|
||||||
@ -158,7 +156,10 @@ export const OpenaiPath = {
|
|||||||
export const Azure = {
|
export const Azure = {
|
||||||
ChatPath: (deployName: string, apiVersion: string) =>
|
ChatPath: (deployName: string, apiVersion: string) =>
|
||||||
`deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
|
`deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
|
||||||
ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
|
// https://<your_resource_name>.openai.azure.com/openai/deployments/<your_deployment_name>/images/generations?api-version=<api_version>
|
||||||
|
ImagePath: (deployName: string, apiVersion: string) =>
|
||||||
|
`deployments/${deployName}/images/generations?api-version=${apiVersion}`,
|
||||||
|
ExampleEndpoint: "https://{resource-url}/openai",
|
||||||
};
|
};
|
||||||
|
|
||||||
export const Google = {
|
export const Google = {
|
||||||
@ -261,6 +262,7 @@ const openaiModels = [
|
|||||||
"gpt-4-vision-preview",
|
"gpt-4-vision-preview",
|
||||||
"gpt-4-turbo-2024-04-09",
|
"gpt-4-turbo-2024-04-09",
|
||||||
"gpt-4-1106-preview",
|
"gpt-4-1106-preview",
|
||||||
|
"dall-e-3",
|
||||||
];
|
];
|
||||||
|
|
||||||
const googleModels = [
|
const googleModels = [
|
||||||
@ -325,6 +327,7 @@ const tencentModels = [
|
|||||||
|
|
||||||
const moonshotModes = ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"];
|
const moonshotModes = ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"];
|
||||||
|
|
||||||
|
let seq = 1000; // the sequence generator for built-in models starts at 1000
|
||||||
export const DEFAULT_MODELS = [
|
export const DEFAULT_MODELS = [
|
||||||
{
|
{
|
||||||
name: "gpt-3.5-turbo",
|
name: "gpt-3.5-turbo",
|
||||||
@ -406,24 +409,6 @@ export const DEFAULT_MODELS = [
|
|||||||
providerType: "openai",
|
providerType: "openai",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
// ...tencentModels.map((name) => ({
|
|
||||||
// name,
|
|
||||||
// available: true,
|
|
||||||
// provider: {
|
|
||||||
// id: "tencent",
|
|
||||||
// providerName: "Tencent",
|
|
||||||
// providerType: "tencent",
|
|
||||||
// },
|
|
||||||
// })),
|
|
||||||
// ...moonshotModes.map((name) => ({
|
|
||||||
// name,
|
|
||||||
// available: true,
|
|
||||||
// provider: {
|
|
||||||
// id: "moonshot",
|
|
||||||
// providerName: "Moonshot",
|
|
||||||
// providerType: "moonshot",
|
|
||||||
// },
|
|
||||||
// })),
|
|
||||||
] as const;
|
] as const;
|
||||||
|
|
||||||
// export const AZURE_MODELS: string[] = [
|
// export const AZURE_MODELS: string[] = [
|
||||||
|
1
app/icons/size.svg
Normal file
1
app/icons/size.svg
Normal file
@ -0,0 +1 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?><svg width="16" height="16" viewBox="0 0 48 48" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M42 7H6C4.89543 7 4 7.89543 4 9V39C4 40.1046 4.89543 41 6 41H42C43.1046 41 44 40.1046 44 39V9C44 7.89543 43.1046 7 42 7Z" fill="none" stroke="#333" stroke-width="4"/><path d="M30 30V18L38 30V18" stroke="#333" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/><path d="M10 30V18L18 30V18" stroke="#333" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/><path d="M24 20V21" stroke="#333" stroke-width="4" stroke-linecap="round"/><path d="M24 27V28" stroke="#333" stroke-width="4" stroke-linecap="round"/></svg>
|
After Width: | Height: | Size: 681 B |
@ -31,6 +31,7 @@ import { nanoid } from "nanoid";
|
|||||||
import { createPersistStore } from "../utils/store";
|
import { createPersistStore } from "../utils/store";
|
||||||
import { collectModelsWithDefaultModel } from "../utils/model";
|
import { collectModelsWithDefaultModel } from "../utils/model";
|
||||||
import { useAccessStore } from "./access";
|
import { useAccessStore } from "./access";
|
||||||
|
import { isDalle3 } from "../utils";
|
||||||
|
|
||||||
export type ChatMessage = RequestMessage & {
|
export type ChatMessage = RequestMessage & {
|
||||||
date: string;
|
date: string;
|
||||||
@ -95,12 +96,12 @@ function createEmptySession(): ChatSession {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// if it is using gpt-* models, force to use 4o-mini to summarize
|
// if it is using gpt-* models, force to use 4o-mini to summarize
|
||||||
const ChatFetchTaskPool: Record<string, any> = {};
|
const ChatFetchTaskPool: Record<string, any> = {};
|
||||||
|
|
||||||
function getSummarizeModel(currentModel: string): {
|
function getSummarizeModel(currentModel: string): {
|
||||||
name: string,
|
name: string;
|
||||||
providerName: string | undefined,
|
providerName: string | undefined;
|
||||||
} {
|
} {
|
||||||
// if it is using gpt-* models, force to use 3.5 to summarize
|
// if it is using gpt-* models, force to use 3.5 to summarize
|
||||||
if (currentModel.startsWith("gpt")) {
|
if (currentModel.startsWith("gpt")) {
|
||||||
@ -117,18 +118,18 @@ function getSummarizeModel(currentModel: string): {
|
|||||||
return {
|
return {
|
||||||
name: summarizeModel?.name ?? currentModel,
|
name: summarizeModel?.name ?? currentModel,
|
||||||
providerName: summarizeModel?.provider?.providerName,
|
providerName: summarizeModel?.provider?.providerName,
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
if (currentModel.startsWith("gemini")) {
|
if (currentModel.startsWith("gemini")) {
|
||||||
return {
|
return {
|
||||||
name: GEMINI_SUMMARIZE_MODEL,
|
name: GEMINI_SUMMARIZE_MODEL,
|
||||||
providerName: ServiceProvider.Google,
|
providerName: ServiceProvider.Google,
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
name: currentModel,
|
name: currentModel,
|
||||||
providerName: undefined,
|
providerName: undefined,
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
function countMessages(msgs: ChatMessage[]) {
|
function countMessages(msgs: ChatMessage[]) {
|
||||||
@ -718,7 +719,7 @@ export const useChatStore = createPersistStore(
|
|||||||
set(() => ({}));
|
set(() => ({}));
|
||||||
extAttr?.setAutoScroll(true);
|
extAttr?.setAutoScroll(true);
|
||||||
} else {
|
} else {
|
||||||
const api: ClientApi = getClientApi(modelConfig.providerName)
|
const api: ClientApi = getClientApi(modelConfig.providerName);
|
||||||
// console.log('-------', modelConfig, '-----', api)
|
// console.log('-------', modelConfig, '-----', api)
|
||||||
|
|
||||||
// make request
|
// make request
|
||||||
@ -896,8 +897,13 @@ export const useChatStore = createPersistStore(
|
|||||||
const config = useAppConfig.getState();
|
const config = useAppConfig.getState();
|
||||||
const session = get().currentSession();
|
const session = get().currentSession();
|
||||||
const modelConfig = session.mask.modelConfig;
|
const modelConfig = session.mask.modelConfig;
|
||||||
|
// skip summarization when the session model is dall-e-3
|
||||||
|
if (isDalle3(modelConfig.model)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const api: ClientApi = getClientApi(modelConfig.providerName);
|
const providerName = modelConfig.providerName;
|
||||||
|
const api: ClientApi = getClientApi(providerName);
|
||||||
|
|
||||||
// remove error messages if any
|
// remove error messages if any
|
||||||
const messages = session.messages;
|
const messages = session.messages;
|
||||||
@ -919,8 +925,10 @@ export const useChatStore = createPersistStore(
|
|||||||
messages: topicMessages,
|
messages: topicMessages,
|
||||||
config: {
|
config: {
|
||||||
model: getSummarizeModel(session.mask.modelConfig.model).name,
|
model: getSummarizeModel(session.mask.modelConfig.model).name,
|
||||||
providerName: getSummarizeModel(session.mask.modelConfig.model).providerName,
|
providerName: getSummarizeModel(session.mask.modelConfig.model)
|
||||||
|
.providerName,
|
||||||
stream: false,
|
stream: false,
|
||||||
|
providerName,
|
||||||
},
|
},
|
||||||
onFinish(message) {
|
onFinish(message) {
|
||||||
get().updateCurrentSession(
|
get().updateCurrentSession(
|
||||||
@ -982,7 +990,8 @@ export const useChatStore = createPersistStore(
|
|||||||
...modelcfg,
|
...modelcfg,
|
||||||
stream: true,
|
stream: true,
|
||||||
model: getSummarizeModel(session.mask.modelConfig.model).name,
|
model: getSummarizeModel(session.mask.modelConfig.model).name,
|
||||||
providerName: getSummarizeModel(session.mask.modelConfig.model).providerName,
|
providerName: getSummarizeModel(session.mask.modelConfig.model)
|
||||||
|
.providerName,
|
||||||
},
|
},
|
||||||
onUpdate(message) {
|
onUpdate(message) {
|
||||||
session.memoryPrompt = message;
|
session.memoryPrompt = message;
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { LLMModel } from "../client/api";
|
import { LLMModel } from "../client/api";
|
||||||
|
import { DalleSize } from "../typing";
|
||||||
import { getClientConfig } from "../config/client";
|
import { getClientConfig } from "../config/client";
|
||||||
import {
|
import {
|
||||||
DEFAULT_INPUT_TEMPLATE,
|
DEFAULT_INPUT_TEMPLATE,
|
||||||
@ -66,6 +67,7 @@ export const DEFAULT_CONFIG = {
|
|||||||
compressMessageLengthThreshold: 4000,
|
compressMessageLengthThreshold: 4000,
|
||||||
enableInjectSystemPrompts: true,
|
enableInjectSystemPrompts: true,
|
||||||
template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
|
template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
|
||||||
|
size: "1024x1024" as DalleSize,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -7,3 +7,5 @@ export interface RequestMessage {
|
|||||||
role: MessageRole;
|
role: MessageRole;
|
||||||
content: string;
|
content: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";
|
||||||
|
@ -266,3 +266,7 @@ export function isVisionModel(model: string) {
|
|||||||
visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo
|
visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function isDalle3(model: string) {
|
||||||
|
return "dall-e-3" === model;
|
||||||
|
}
|
||||||
|
@ -1,12 +1,42 @@
|
|||||||
import { DEFAULT_MODELS } from "../constant";
|
import { DEFAULT_MODELS } from "../constant";
|
||||||
import { LLMModel } from "../client/api";
|
import { LLMModel } from "../client/api";
|
||||||
|
|
||||||
|
const CustomSeq = {
|
||||||
|
val: -1000, // start at -1000 so custom models sort ahead of the built-in ones (which start at 1000, see constant.ts)
|
||||||
|
cache: new Map<string, number>(),
|
||||||
|
next: (id: string) => {
|
||||||
|
if (CustomSeq.cache.has(id)) {
|
||||||
|
return CustomSeq.cache.get(id) as number;
|
||||||
|
} else {
|
||||||
|
let seq = CustomSeq.val++;
|
||||||
|
CustomSeq.cache.set(id, seq);
|
||||||
|
return seq;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
const customProvider = (providerName: string) => ({
|
const customProvider = (providerName: string) => ({
|
||||||
id: providerName.toLowerCase(),
|
id: providerName.toLowerCase(),
|
||||||
providerName: providerName,
|
providerName: providerName,
|
||||||
providerType: "custom",
|
providerType: "custom",
|
||||||
|
sorted: CustomSeq.next(providerName),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sorts an array of models based on specified rules.
|
||||||
|
*
|
||||||
|
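* Models are ordered by their provider's `sorted` value first; ties, and comparisons where either model has no provider, fall back to the model's own `sorted` value.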
|
||||||
|
*/
|
||||||
|
const sortModelTable = (models: ReturnType<typeof collectModels>) =>
|
||||||
|
models.sort((a, b) => {
|
||||||
|
if (a.provider && b.provider) {
|
||||||
|
let cmp = a.provider.sorted - b.provider.sorted;
|
||||||
|
return cmp === 0 ? a.sorted - b.sorted : cmp;
|
||||||
|
} else {
|
||||||
|
return a.sorted - b.sorted;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
export function collectModelTable(
|
export function collectModelTable(
|
||||||
models: readonly LLMModel[],
|
models: readonly LLMModel[],
|
||||||
customModels: string,
|
customModels: string,
|
||||||
@ -17,6 +47,7 @@ export function collectModelTable(
|
|||||||
available: boolean;
|
available: boolean;
|
||||||
name: string;
|
name: string;
|
||||||
displayName: string;
|
displayName: string;
|
||||||
|
sorted: number;
|
||||||
describe: string;
|
describe: string;
|
||||||
provider?: LLMModel["provider"]; // Marked as optional
|
provider?: LLMModel["provider"]; // Marked as optional
|
||||||
isDefault?: boolean;
|
isDefault?: boolean;
|
||||||
@ -86,6 +117,7 @@ export function collectModelTable(
|
|||||||
available,
|
available,
|
||||||
describe: "",
|
describe: "",
|
||||||
provider, // Use optional chaining
|
provider, // Use optional chaining
|
||||||
|
sorted: CustomSeq.next(`${customModelName}@${provider?.id}`),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -140,7 +172,9 @@ export function collectModels(
|
|||||||
customModels: string,
|
customModels: string,
|
||||||
) {
|
) {
|
||||||
const modelTable = collectModelTable(models, customModels);
|
const modelTable = collectModelTable(models, customModels);
|
||||||
const allModels = Object.values(modelTable);
|
let allModels = Object.values(modelTable);
|
||||||
|
|
||||||
|
allModels = sortModelTable(allModels);
|
||||||
|
|
||||||
return allModels;
|
return allModels;
|
||||||
}
|
}
|
||||||
@ -155,7 +189,10 @@ export function collectModelsWithDefaultModel(
|
|||||||
customModels,
|
customModels,
|
||||||
defaultModel,
|
defaultModel,
|
||||||
);
|
);
|
||||||
const allModels = Object.values(modelTable);
|
let allModels = Object.values(modelTable);
|
||||||
|
|
||||||
|
allModels = sortModelTable(allModels);
|
||||||
|
|
||||||
return allModels;
|
return allModels;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user