diff --git a/app/api/config/route.ts b/app/api/config/route.ts index db84fba17..b0d9da031 100644 --- a/app/api/config/route.ts +++ b/app/api/config/route.ts @@ -13,6 +13,7 @@ const DANGER_CONFIG = { hideBalanceQuery: serverConfig.hideBalanceQuery, disableFastLink: serverConfig.disableFastLink, customModels: serverConfig.customModels, + defaultModel: serverConfig.defaultModel, }; declare global { diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts index 1ab36db25..b6eb8d3df 100644 --- a/app/client/platforms/google.ts +++ b/app/client/platforms/google.ts @@ -21,11 +21,10 @@ export class GeminiProApi implements LLMApi { } async chat(options: ChatOptions): Promise { // const apiClient = this; - const visionModel = isVisionModel(options.config.model); let multimodal = false; const messages = options.messages.map((v) => { let parts: any[] = [{ text: getMessageTextContent(v) }]; - if (visionModel) { + if (isVisionModel(options.config.model)) { const images = getMessageImages(v); if (images.length > 0) { multimodal = true; @@ -117,17 +116,14 @@ export class GeminiProApi implements LLMApi { const controller = new AbortController(); options.onController?.(controller); try { - let googleChatPath = visionModel - ? Google.VisionChatPath(modelConfig.model) - : Google.ChatPath(modelConfig.model); - let chatPath = this.path(googleChatPath); - // let baseUrl = accessStore.googleUrl; if (!baseUrl) { baseUrl = isApp - ? DEFAULT_API_HOST + "/api/proxy/google/" + googleChatPath - : chatPath; + ? DEFAULT_API_HOST + + "/api/proxy/google/" + + Google.ChatPath(modelConfig.model) + : this.path(Google.ChatPath(modelConfig.model)); } if (isApp) { @@ -145,6 +141,7 @@ export class GeminiProApi implements LLMApi { () => controller.abort(), REQUEST_TIMEOUT_MS, ); + if (shouldStream) { let responseText = ""; let remainText = ""; diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 7862e9caa..76f2aeacf 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -135,7 +135,7 @@ export class ChatGPTApi implements LLMApi { // console.log("[Request] openai payload: ", requestPayload); // add max_tokens to vision model - if (visionModel) { + if (visionModel && modelConfig.model.includes("preview")) { requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000); } diff --git a/app/components/chat.tsx b/app/components/chat.tsx index f890c6dae..0b227c7a3 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -468,10 +468,20 @@ export function ChatActions(props: { // switch model const currentModel = chatStore.currentSession().mask.modelConfig.model; const allModels = useAllModels(); - const models = useMemo( - () => allModels.filter((m) => m.available), - [allModels], - ); + const models = useMemo(() => { + const filteredModels = allModels.filter((m) => m.available); + const defaultModel = filteredModels.find((m) => m.isDefault); + + if (defaultModel) { + const arr = [ + defaultModel, + ...filteredModels.filter((m) => m !== defaultModel), + ]; + return arr; + } else { + return filteredModels; + } + }, [allModels]); const [showModelSelector, setShowModelSelector] = useState(false); const [showUploadImage, setShowUploadImage] = useState(false); const current_day_token = localStorage.getItem("current_day_token") ?? ""; @@ -486,9 +496,12 @@ export function ChatActions(props: { // if current model is not available // switch to first available model - const isUnavailableModel = !models.some((m) => m.name === currentModel); - if (isUnavailableModel && models.length > 0) { - const nextModel = models[0].name as ModelType; + const isUnavaliableModel = !models.some((m) => m.name === currentModel); + if (isUnavaliableModel && models.length > 0) { + // show next model to default model if exist + let nextModel: ModelType = ( + models.find((model) => model.isDefault) || models[0] + ).name; chatStore.updateCurrentSession( (session) => (session.mask.modelConfig.model = nextModel), ); diff --git a/app/config/server.ts b/app/config/server.ts index 37705ae2c..2f7d15bae 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -22,6 +22,7 @@ declare global { ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not DISABLE_FAST_LINK?: string; // disallow parse settings from url or not CUSTOM_MODELS?: string; // to control custom models + DEFAULT_MODEL?: string; // to cnntrol default model in every new chat window // azure only AZURE_URL?: string; // https://{azure-url}/openai/deployments/{deploy-name} @@ -61,12 +62,14 @@ export const getServerSideConfig = () => { const disableGPT4 = !!process.env.DISABLE_GPT4; let customModels = process.env.CUSTOM_MODELS ?? ""; + let defaultModel = process.env.DEFAULT_MODEL ?? ""; if (disableGPT4) { if (customModels) customModels += ","; customModels += DEFAULT_MODELS.filter((m) => m.name.startsWith("gpt-4")) .map((m) => "-" + m.name) .join(","); + if (defaultModel.startsWith("gpt-4")) defaultModel = ""; } // const isAzure = !!process.env.AZURE_URL; @@ -124,6 +127,7 @@ export const getServerSideConfig = () => { hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY, disableFastLink: !!process.env.DISABLE_FAST_LINK, customModels, + defaultModel, whiteWebDevEndpoints, }; }; diff --git a/app/constant.ts b/app/constant.ts index a85bc8432..932382679 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -103,8 +103,8 @@ export const Azure = { export const Google = { ExampleEndpoint: "https://generativelanguage.googleapis.com/", ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`, - VisionChatPath: (modelName: string) => - `v1beta/models/${modelName}:generateContent`, + // VisionChatPath: (modelName: string) => + // `v1beta/models/${modelName}:generateContent`, }; export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang @@ -133,8 +133,6 @@ export const KnowledgeCutOffDate: Record = { "gpt-4-turbo": "2023-12", "gpt-4-turbo-2024-04-09": "2023-12", "gpt-4-turbo-preview": "2023-12", - "gpt-4-1106-preview": "2023-04", - "gpt-4-0125-preview": "2023-12", "gpt-4-vision-preview": "2023-04", // After improvements, // it's now easier to add "KnowledgeCutOffDate" instead of stupid hardcoding it, as was done previously. @@ -144,19 +142,11 @@ export const KnowledgeCutOffDate: Record = { const openaiModels = [ "gpt-3.5-turbo", - "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-16k-0613", "gpt-4", - "gpt-4-0314", "gpt-4-0613", - "gpt-4-1106-preview", - "gpt-4-0125-preview", "gpt-4-32k", - "gpt-4-32k-0314", "gpt-4-32k-0613", "gpt-4-turbo", "gpt-4-turbo-preview", diff --git a/app/store/access.ts b/app/store/access.ts index b738a2201..1bdb502e5 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -8,6 +8,7 @@ import { getHeaders } from "../client/api"; import { getClientConfig } from "../config/client"; import { createPersistStore } from "../utils/store"; import { ensure } from "../utils/clone"; +import { DEFAULT_CONFIG } from "./config"; let fetchState = 0; // 0 not fetch, 1 fetching, 2 done @@ -51,6 +52,7 @@ const DEFAULT_ACCESS_STATE = { useMjImgSelfProxy: false, disableFastLink: false, customModels: "", + defaultModel: "", }; export const useAccessStore = createPersistStore( @@ -105,6 +107,31 @@ export const useAccessStore = createPersistStore( }; set(() => ({ ...res })); fetchState = 2; // 设置 fetchState 值为 "获取已完成" + // fetch("/api/config", { + // method: "post", + // body: null, + // headers: { + // ...getHeaders(), + // }, + // }) + // .then((res) => res.json()) + // .then((res) => { + // // Set default model from env request + // let defaultModel = res.defaultModel ?? ""; + // DEFAULT_CONFIG.modelConfig.model = + // defaultModel !== "" ? defaultModel : "gpt-3.5-turbo"; + // return res; + // }) + // .then((res: DangerConfig) => { + // console.log("[Config] got config from server", res); + // set(() => ({ ...res })); + // }) + // .catch(() => { + // console.error("[Config] failed to fetch config"); + // }) + // .finally(() => { + // fetchState = 2; + // }); }, }), { diff --git a/app/styles/globals.scss b/app/styles/globals.scss index 539df3acb..1bdaf28cf 100644 --- a/app/styles/globals.scss +++ b/app/styles/globals.scss @@ -86,6 +86,7 @@ @include dark; } } + html { height: var(--full-height); @@ -110,6 +111,10 @@ body { @media only screen and (max-width: 600px) { background-color: var(--second); } + + *:focus-visible { + outline: none; + } } ::-webkit-scrollbar { diff --git a/app/utils/hooks.ts b/app/utils/hooks.ts index 684dfc4e8..a1f67096f 100644 --- a/app/utils/hooks.ts +++ b/app/utils/hooks.ts @@ -1,14 +1,15 @@ import { useMemo } from "react"; import { useAccessStore, useAppConfig } from "../store"; -import { collectModels } from "./model"; +import { collectModels, collectModelsWithDefaultModel } from "./model"; export function useAllModels() { const accessStore = useAccessStore(); const configStore = useAppConfig(); const models = useMemo(() => { - return collectModels( + return collectModelsWithDefaultModel( configStore.models, [configStore.customModels, accessStore.customModels].join(","), + accessStore.defaultModel, ).filter((m) => !configStore.dontUseModel.includes(m.name as any)); }, [ accessStore.customModels, diff --git a/app/utils/model.ts b/app/utils/model.ts index 7dc0e23da..5b700410d 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -1,5 +1,11 @@ import { LLMModel } from "../client/api"; +const customProvider = (modelName: string) => ({ + id: modelName, + providerName: "", + providerType: "custom", +}); + export function collectModelTable( models: readonly LLMModel[], customModels: string, @@ -12,6 +18,7 @@ export function collectModelTable( displayName: string; describe: string; provider?: LLMModel["provider"]; // Marked as optional + isDefault?: boolean; } > = {}; @@ -23,12 +30,6 @@ export function collectModelTable( }; }); - const customProvider = (modelName: string) => ({ - id: modelName, - providerName: "", - providerType: "custom", - }); - // server custom models customModels .split(",") @@ -54,6 +55,27 @@ export function collectModelTable( }; } }); + + return modelTable; +} + +export function collectModelTableWithDefaultModel( + models: readonly LLMModel[], + customModels: string, + defaultModel: string, +) { + let modelTable = collectModelTable(models, customModels); + if (defaultModel && defaultModel !== "") { + delete modelTable[defaultModel]; + modelTable[defaultModel] = { + name: defaultModel, + displayName: defaultModel, + available: true, + provider: + modelTable[defaultModel]?.provider ?? customProvider(defaultModel), + isDefault: true, + }; + } return modelTable; } @@ -69,3 +91,17 @@ export function collectModels( return allModels; } + +export function collectModelsWithDefaultModel( + models: readonly LLMModel[], + customModels: string, + defaultModel: string, +) { + const modelTable = collectModelTableWithDefaultModel( + models, + customModels, + defaultModel, + ); + const allModels = Object.values(modelTable); + return allModels; +}