diff --git a/app/api/alibaba/[...path]/route.ts b/app/api/alibaba/[...path]/route.ts index b2c42ac78..c97ce5934 100644 --- a/app/api/alibaba/[...path]/route.ts +++ b/app/api/alibaba/[...path]/route.ts @@ -91,34 +91,14 @@ async function request(req: NextRequest) { ); const fetchUrl = `${baseUrl}${path}`; - - const clonedBody = await req.text(); - - const { messages, model, stream, top_p, ...rest } = JSON.parse( - clonedBody, - ) as RequestPayload; - - const requestBody = { - model, - input: { - messages, - }, - parameters: { - ...rest, - top_p: top_p === 1 ? 0.99 : top_p, // qwen top_p is should be < 1 - result_format: "message", - incremental_output: true, - }, - }; - const fetchOptions: RequestInit = { headers: { "Content-Type": "application/json", Authorization: req.headers.get("Authorization") ?? "", - "X-DashScope-SSE": stream ? "enable" : "disable", + "X-DashScope-SSE": req.headers.get("X-DashScope-SSE") ?? "disable", }, method: req.method, - body: JSON.stringify(requestBody), + body: req.body, redirect: "manual", // @ts-ignore duplex: "half", @@ -128,18 +108,23 @@ async function request(req: NextRequest) { // #1815 try to refuse some request to some models if (serverConfig.customModels && req.body) { try { + const clonedBody = await req.text(); + fetchOptions.body = clonedBody; + + const jsonBody = JSON.parse(clonedBody) as { model?: string }; + // not undefined and is false if ( isModelAvailableInServer( serverConfig.customModels, - model as string, + jsonBody?.model as string, ServiceProvider.Alibaba as string, ) ) { return NextResponse.json( { error: true, - message: `you are not allowed to use ${model} model`, + message: `you are not allowed to use ${jsonBody?.model} model`, }, { status: 403, diff --git a/app/api/azure/[...path]/route.ts b/app/api/azure/[...path]/route.ts index 474ee761d..8cdaf2157 100644 --- a/app/api/azure/[...path]/route.ts +++ b/app/api/azure/[...path]/route.ts @@ -2,7 +2,6 @@ import { getServerSideConfig } from "@/app/config/server"; 
import { ModelProvider } from "@/app/constant"; import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; -import { NextApiResponse, NextApiRequest } from "next"; import { auth } from "../../auth"; import { requestOpenai } from "../../common"; diff --git a/app/client/platforms/alibaba.ts b/app/client/platforms/alibaba.ts index 72126d728..723ba774b 100644 --- a/app/client/platforms/alibaba.ts +++ b/app/client/platforms/alibaba.ts @@ -32,19 +32,25 @@ export interface OpenAIListModelResponse { }>; } -interface RequestPayload { +interface RequestInput { messages: { role: "system" | "user" | "assistant"; content: string | MultimodalContent[]; }[]; - stream?: boolean; - model: string; +} +interface RequestParam { + result_format: string; + incremental_output?: boolean; temperature: number; - presence_penalty: number; - frequency_penalty: number; + repetition_penalty?: number; top_p: number; max_tokens?: number; } +interface RequestPayload { + model: string; + input: RequestInput; + parameters: RequestParam; +} export class QwenApi implements LLMApi { path(path: string): string { @@ -91,17 +97,21 @@ export class QwenApi implements LLMApi { }, }; + const shouldStream = !!options.config.stream; const requestPayload: RequestPayload = { - messages, - stream: options.config.stream, model: modelConfig.model, - temperature: modelConfig.temperature, - presence_penalty: modelConfig.presence_penalty, - frequency_penalty: modelConfig.frequency_penalty, - top_p: modelConfig.top_p, + input: { + messages, + }, + parameters: { + result_format: "message", + incremental_output: shouldStream, + temperature: modelConfig.temperature, + // max_tokens: modelConfig.max_tokens, + top_p: modelConfig.top_p === 1 ? 
0.99 : modelConfig.top_p, // qwen top_p should be < 1 + }, }; - const shouldStream = !!options.config.stream; const controller = new AbortController(); options.onController?.(controller); @@ -111,7 +121,10 @@ export class QwenApi implements LLMApi { method: "POST", body: JSON.stringify(requestPayload), signal: controller.signal, - headers: getHeaders(), + headers: { + ...getHeaders(), + "X-DashScope-SSE": shouldStream ? "enable" : "disable", + }, }; // make a fetch request diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 1c5809831..a114ee5fe 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -609,7 +609,7 @@ export function ChatActions(props: { setShowModelSelector(true)} - text={currentModel} + text={currentModelName} icon={} /> @@ -627,7 +627,7 @@ export function ChatActions(props: { {/*/>*/} {showModelSelector && ( - ({ title: `${m.displayName}${ diff --git a/app/components/sidebar.tsx b/app/components/sidebar.tsx index 509943bb9..aa527bc4f 100644 --- a/app/components/sidebar.tsx +++ b/app/components/sidebar.tsx @@ -23,6 +23,7 @@ import { NARROW_SIDEBAR_WIDTH, Path, REPO_URL, + ServiceProvider, } from "../constant"; import { Link, useNavigate } from "react-router-dom"; @@ -131,6 +132,10 @@ export function SideBar(props: { className?: string }) { const chatStore = useChatStore(); const currentModel = chatStore.currentSession().mask.modelConfig.model; + const currentProviderName = + chatStore.currentSession().mask.modelConfig?.providerName || + ServiceProvider.OpenAI; + // drag side bar const { onDragStart, shouldNarrow } = useDragSideBar(); const navigate = useNavigate(); @@ -249,7 +254,11 @@ export function SideBar(props: { className?: string }) { text={shouldNarrow ?
undefined : Locale.Home.NewChat} onClick={() => { if (config.dontShowMaskSplashScreen) { - chatStore.newSession(undefined, currentModel); + chatStore.newSession( + undefined, + currentModel, + currentProviderName, + ); navigate(Path.Chat); } else { navigate(Path.NewChat); diff --git a/app/components/ui-lib.tsx b/app/components/ui-lib.tsx index 8c55f5620..19b93a7c9 100644 --- a/app/components/ui-lib.tsx +++ b/app/components/ui-lib.tsx @@ -514,7 +514,7 @@ export function ModalSelector(props: { onClose?: () => void; multiple?: boolean; }) { - // console.log("-----", props); + console.log("-----", props); const getCheckCardAvatar = (value: string): React.ReactNode => { if (value.startsWith("gpt")) { diff --git a/app/store/chat.ts b/app/store/chat.ts index 837f434e1..7d19d7e8d 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -224,14 +224,22 @@ export const useChatStore = createPersistStore( }); }, - newSession(mask?: Mask, currentModel?: Mask["modelConfig"]["model"]) { + newSession( + mask?: Mask, + currentModel?: Mask["modelConfig"]["model"], + currentProviderName?: ServiceProvider, + ) { const session = createEmptySession(); const config = useAppConfig.getState(); // console.log("------", session, "2222", config); - // 继承当前会话的模型 + // 继承当前会话的模型, + // 新增继承模型提供者 if (currentModel) { session.mask.modelConfig.model = currentModel; } + if (currentProviderName) { + session.mask.modelConfig.providerName = currentProviderName; + } if (mask) { const config = useAppConfig.getState(); const globalModelConfig = config.modelConfig; diff --git a/app/store/config.ts b/app/store/config.ts index 568fef76a..dbb845e5e 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -140,7 +140,7 @@ export const useAppConfig = createPersistStore( }), { name: StoreKey.Config, - version: 3.96, + version: 3.97, migrate(persistedState, version) { const state = persistedState as ChatConfig; @@ -176,7 +176,7 @@ export const useAppConfig = createPersistStore( // return { ...DEFAULT_CONFIG 
}; // } - if (version < 3.96) { + if (version < 3.97) { state.modelConfig = DEFAULT_CONFIG.modelConfig; // state.modelConfig.template = // state.modelConfig.template !== DEFAULT_INPUT_TEMPLATE diff --git a/app/utils/model.ts b/app/utils/model.ts index 70c53b811..a0b6f6630 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -62,7 +62,7 @@ export function collectModelTable( modelTable[fullName]["available"] = available; // swap name and displayName for bytedance if (providerName === "bytedance") { - [name, displayName] = [displayName, name]; + [name, displayName] = [displayName, modelName]; modelTable[fullName]["name"] = name; } if (displayName) {