Merge pull request #118 from sijinhui/dev

Dev
This commit is contained in:
sijinhui 2024-07-12 11:38:21 +08:00 committed by GitHub
commit c0f7bc50fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 61 additions and 47 deletions

View File

@ -91,34 +91,14 @@ async function request(req: NextRequest) {
); );
const fetchUrl = `${baseUrl}${path}`; const fetchUrl = `${baseUrl}${path}`;
const clonedBody = await req.text();
const { messages, model, stream, top_p, ...rest } = JSON.parse(
clonedBody,
) as RequestPayload;
const requestBody = {
model,
input: {
messages,
},
parameters: {
...rest,
top_p: top_p === 1 ? 0.99 : top_p, // qwen top_p is should be < 1
result_format: "message",
incremental_output: true,
},
};
const fetchOptions: RequestInit = { const fetchOptions: RequestInit = {
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
Authorization: req.headers.get("Authorization") ?? "", Authorization: req.headers.get("Authorization") ?? "",
"X-DashScope-SSE": stream ? "enable" : "disable", "X-DashScope-SSE": req.headers.get("X-DashScope-SSE") ?? "disable",
}, },
method: req.method, method: req.method,
body: JSON.stringify(requestBody), body: req.body,
redirect: "manual", redirect: "manual",
// @ts-ignore // @ts-ignore
duplex: "half", duplex: "half",
@ -128,18 +108,23 @@ async function request(req: NextRequest) {
// #1815 try to refuse some request to some models // #1815 try to refuse some request to some models
if (serverConfig.customModels && req.body) { if (serverConfig.customModels && req.body) {
try { try {
const clonedBody = await req.text();
fetchOptions.body = clonedBody;
const jsonBody = JSON.parse(clonedBody) as { model?: string };
// not undefined and is false // not undefined and is false
if ( if (
isModelAvailableInServer( isModelAvailableInServer(
serverConfig.customModels, serverConfig.customModels,
model as string, jsonBody?.model as string,
ServiceProvider.Alibaba as string, ServiceProvider.Alibaba as string,
) )
) { ) {
return NextResponse.json( return NextResponse.json(
{ {
error: true, error: true,
message: `you are not allowed to use ${model} model`, message: `you are not allowed to use ${jsonBody?.model} model`,
}, },
{ {
status: 403, status: 403,

View File

@ -2,7 +2,6 @@ import { getServerSideConfig } from "@/app/config/server";
import { ModelProvider } from "@/app/constant"; import { ModelProvider } from "@/app/constant";
import { prettyObject } from "@/app/utils/format"; import { prettyObject } from "@/app/utils/format";
import { NextRequest, NextResponse } from "next/server"; import { NextRequest, NextResponse } from "next/server";
import { NextApiResponse, NextApiRequest } from "next";
import { auth } from "../../auth"; import { auth } from "../../auth";
import { requestOpenai } from "../../common"; import { requestOpenai } from "../../common";

View File

@ -32,19 +32,25 @@ export interface OpenAIListModelResponse {
}>; }>;
} }
interface RequestPayload { interface RequestInput {
messages: { messages: {
role: "system" | "user" | "assistant"; role: "system" | "user" | "assistant";
content: string | MultimodalContent[]; content: string | MultimodalContent[];
}[]; }[];
stream?: boolean; }
model: string; interface RequestParam {
result_format: string;
incremental_output?: boolean;
temperature: number; temperature: number;
presence_penalty: number; repetition_penalty?: number;
frequency_penalty: number;
top_p: number; top_p: number;
max_tokens?: number; max_tokens?: number;
} }
interface RequestPayload {
model: string;
input: RequestInput;
parameters: RequestParam;
}
export class QwenApi implements LLMApi { export class QwenApi implements LLMApi {
path(path: string): string { path(path: string): string {
@ -91,17 +97,21 @@ export class QwenApi implements LLMApi {
}, },
}; };
const shouldStream = !!options.config.stream;
const requestPayload: RequestPayload = { const requestPayload: RequestPayload = {
messages,
stream: options.config.stream,
model: modelConfig.model, model: modelConfig.model,
temperature: modelConfig.temperature, input: {
presence_penalty: modelConfig.presence_penalty, messages,
frequency_penalty: modelConfig.frequency_penalty, },
top_p: modelConfig.top_p, parameters: {
result_format: "message",
incremental_output: shouldStream,
temperature: modelConfig.temperature,
// max_tokens: modelConfig.max_tokens,
top_p: modelConfig.top_p === 1 ? 0.99 : modelConfig.top_p, // qwen top_p is should be < 1
},
}; };
const shouldStream = !!options.config.stream;
const controller = new AbortController(); const controller = new AbortController();
options.onController?.(controller); options.onController?.(controller);
@ -111,7 +121,10 @@ export class QwenApi implements LLMApi {
method: "POST", method: "POST",
body: JSON.stringify(requestPayload), body: JSON.stringify(requestPayload),
signal: controller.signal, signal: controller.signal,
headers: getHeaders(), headers: {
...getHeaders(),
"X-DashScope-SSE": shouldStream ? "enable" : "disable",
},
}; };
// make a fetch request // make a fetch request

View File

@ -609,7 +609,7 @@ export function ChatActions(props: {
<ChatAction <ChatAction
onClick={() => setShowModelSelector(true)} onClick={() => setShowModelSelector(true)}
text={currentModel} text={currentModelName}
icon={<RobotIcon />} icon={<RobotIcon />}
/> />
@ -627,7 +627,7 @@ export function ChatActions(props: {
{/*/>*/} {/*/>*/}
{showModelSelector && ( {showModelSelector && (
<Selector <ModalSelector
defaultSelectedValue={`${currentModel}@${currentProviderName}`} defaultSelectedValue={`${currentModel}@${currentProviderName}`}
items={models.map((m) => ({ items={models.map((m) => ({
title: `${m.displayName}${ title: `${m.displayName}${

View File

@ -23,6 +23,7 @@ import {
NARROW_SIDEBAR_WIDTH, NARROW_SIDEBAR_WIDTH,
Path, Path,
REPO_URL, REPO_URL,
ServiceProvider,
} from "../constant"; } from "../constant";
import { Link, useNavigate } from "react-router-dom"; import { Link, useNavigate } from "react-router-dom";
@ -131,6 +132,10 @@ export function SideBar(props: { className?: string }) {
const chatStore = useChatStore(); const chatStore = useChatStore();
const currentModel = chatStore.currentSession().mask.modelConfig.model; const currentModel = chatStore.currentSession().mask.modelConfig.model;
const currentProviderName =
chatStore.currentSession().mask.modelConfig?.providerName ||
ServiceProvider.OpenAI;
// drag side bar // drag side bar
const { onDragStart, shouldNarrow } = useDragSideBar(); const { onDragStart, shouldNarrow } = useDragSideBar();
const navigate = useNavigate(); const navigate = useNavigate();
@ -249,7 +254,11 @@ export function SideBar(props: { className?: string }) {
text={shouldNarrow ? undefined : Locale.Home.NewChat} text={shouldNarrow ? undefined : Locale.Home.NewChat}
onClick={() => { onClick={() => {
if (config.dontShowMaskSplashScreen) { if (config.dontShowMaskSplashScreen) {
chatStore.newSession(undefined, currentModel); chatStore.newSession(
undefined,
currentModel,
currentProviderName,
);
navigate(Path.Chat); navigate(Path.Chat);
} else { } else {
navigate(Path.NewChat); navigate(Path.NewChat);

View File

@ -514,7 +514,7 @@ export function ModalSelector<T extends CheckGroupValueType>(props: {
onClose?: () => void; onClose?: () => void;
multiple?: boolean; multiple?: boolean;
}) { }) {
// console.log("-----", props); console.log("-----", props);
const getCheckCardAvatar = (value: string): React.ReactNode => { const getCheckCardAvatar = (value: string): React.ReactNode => {
if (value.startsWith("gpt")) { if (value.startsWith("gpt")) {

View File

@ -224,14 +224,22 @@ export const useChatStore = createPersistStore(
}); });
}, },
newSession(mask?: Mask, currentModel?: Mask["modelConfig"]["model"]) { newSession(
mask?: Mask,
currentModel?: Mask["modelConfig"]["model"],
currentProviderName?: ServiceProvider,
) {
const session = createEmptySession(); const session = createEmptySession();
const config = useAppConfig.getState(); const config = useAppConfig.getState();
// console.log("------", session, "2222", config); // console.log("------", session, "2222", config);
// 继承当前会话的模型 // 继承当前会话的模型,
// 新增继承模型提供者
if (currentModel) { if (currentModel) {
session.mask.modelConfig.model = currentModel; session.mask.modelConfig.model = currentModel;
} }
if (currentProviderName) {
session.mask.modelConfig.providerName = currentProviderName;
}
if (mask) { if (mask) {
const config = useAppConfig.getState(); const config = useAppConfig.getState();
const globalModelConfig = config.modelConfig; const globalModelConfig = config.modelConfig;

View File

@ -140,7 +140,7 @@ export const useAppConfig = createPersistStore(
}), }),
{ {
name: StoreKey.Config, name: StoreKey.Config,
version: 3.96, version: 3.97,
migrate(persistedState, version) { migrate(persistedState, version) {
const state = persistedState as ChatConfig; const state = persistedState as ChatConfig;
@ -176,7 +176,7 @@ export const useAppConfig = createPersistStore(
// return { ...DEFAULT_CONFIG }; // return { ...DEFAULT_CONFIG };
// } // }
if (version < 3.96) { if (version < 3.97) {
state.modelConfig = DEFAULT_CONFIG.modelConfig; state.modelConfig = DEFAULT_CONFIG.modelConfig;
// state.modelConfig.template = // state.modelConfig.template =
// state.modelConfig.template !== DEFAULT_INPUT_TEMPLATE // state.modelConfig.template !== DEFAULT_INPUT_TEMPLATE

View File

@ -62,7 +62,7 @@ export function collectModelTable(
modelTable[fullName]["available"] = available; modelTable[fullName]["available"] = available;
// swap name and displayName for bytedance // swap name and displayName for bytedance
if (providerName === "bytedance") { if (providerName === "bytedance") {
[name, displayName] = [displayName, name]; [name, displayName] = [displayName, modelName];
modelTable[fullName]["name"] = name; modelTable[fullName]["name"] = name;
} }
if (displayName) { if (displayName) {