mirror of
https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web.git
synced 2025-11-13 04:33:42 +08:00
完善llama和mistral模型的推理功能
This commit is contained in:
@@ -15,7 +15,8 @@ function parseEventData(chunk: Uint8Array): any {
|
||||
// AWS Bedrock wraps the response in a 'body' field
|
||||
if (typeof parsed.body === "string") {
|
||||
try {
|
||||
return JSON.parse(parsed.body);
|
||||
const bodyJson = JSON.parse(parsed.body);
|
||||
return bodyJson;
|
||||
} catch (e) {
|
||||
return { output: parsed.body };
|
||||
}
|
||||
@@ -89,10 +90,12 @@ async function* transformBedrockStream(
|
||||
})}\n\n`;
|
||||
}
|
||||
}
|
||||
// Handle LLaMA3 models
|
||||
else if (modelId.startsWith("us.meta.llama3")) {
|
||||
// Handle LLaMA models
|
||||
else if (modelId.startsWith("us.meta.llama")) {
|
||||
let text = "";
|
||||
if (parsed.generation) {
|
||||
if (parsed.outputs?.[0]?.text) {
|
||||
text = parsed.outputs[0].text;
|
||||
} else if (parsed.generation) {
|
||||
text = parsed.generation;
|
||||
} else if (parsed.output) {
|
||||
text = parsed.output;
|
||||
@@ -101,8 +104,6 @@ async function* transformBedrockStream(
|
||||
}
|
||||
|
||||
if (text) {
|
||||
// Clean up any control characters or invalid JSON characters
|
||||
text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
|
||||
yield `data: ${JSON.stringify({
|
||||
delta: { text },
|
||||
})}\n\n`;
|
||||
@@ -162,6 +163,7 @@ async function* transformBedrockStream(
|
||||
function validateRequest(body: any, modelId: string): void {
|
||||
if (!modelId) throw new Error("Model ID is required");
|
||||
|
||||
// Handle nested body structure
|
||||
const bodyContent = body.body || body;
|
||||
|
||||
if (modelId.startsWith("anthropic.claude")) {
|
||||
@@ -180,22 +182,20 @@ function validateRequest(body: any, modelId: string): void {
|
||||
} else if (typeof body.prompt !== "string") {
|
||||
throw new Error("prompt is required for Claude 2 and earlier");
|
||||
}
|
||||
} else if (modelId.startsWith("us.meta.llama3")) {
|
||||
if (!bodyContent.prompt) {
|
||||
throw new Error("prompt is required for LLaMA3 models");
|
||||
} else if (modelId.startsWith("us.meta.llama")) {
|
||||
if (!bodyContent.prompt || typeof bodyContent.prompt !== "string") {
|
||||
throw new Error("prompt string is required for LLaMA models");
|
||||
}
|
||||
if (
|
||||
!bodyContent.max_gen_len ||
|
||||
typeof bodyContent.max_gen_len !== "number"
|
||||
) {
|
||||
throw new Error("max_gen_len must be a positive number for LLaMA models");
|
||||
}
|
||||
} else if (modelId.startsWith("mistral.mistral")) {
|
||||
if (!bodyContent.prompt) {
|
||||
throw new Error("prompt is required for Mistral models");
|
||||
}
|
||||
if (
|
||||
!bodyContent.prompt.startsWith("<s>[INST]") ||
|
||||
!bodyContent.prompt.includes("[/INST]")
|
||||
) {
|
||||
throw new Error(
|
||||
"Mistral prompt must be wrapped in <s>[INST] and [/INST] tags",
|
||||
);
|
||||
}
|
||||
} else if (modelId.startsWith("amazon.titan")) {
|
||||
if (!bodyContent.inputText) throw new Error("Titan requires inputText");
|
||||
}
|
||||
@@ -250,7 +250,6 @@ async function requestBedrock(req: NextRequest) {
|
||||
try {
|
||||
// Determine the endpoint and request body based on model type
|
||||
let endpoint;
|
||||
let requestBody;
|
||||
|
||||
const bodyText = await req.clone().text();
|
||||
if (!bodyText) {
|
||||
@@ -258,6 +257,10 @@ async function requestBedrock(req: NextRequest) {
|
||||
}
|
||||
|
||||
const bodyJson = JSON.parse(bodyText);
|
||||
|
||||
// Debug log the request body
|
||||
console.log("Original request body:", JSON.stringify(bodyJson, null, 2));
|
||||
|
||||
validateRequest(bodyJson, modelId);
|
||||
|
||||
// For all models, use standard endpoints
|
||||
@@ -267,26 +270,44 @@ async function requestBedrock(req: NextRequest) {
|
||||
endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`;
|
||||
}
|
||||
|
||||
// Set content type and accept headers for Mistral models
|
||||
// Set additional headers based on model type
|
||||
const additionalHeaders: Record<string, string> = {};
|
||||
if (
|
||||
modelId.startsWith("us.meta.llama") ||
|
||||
modelId.startsWith("mistral.mistral")
|
||||
) {
|
||||
additionalHeaders["content-type"] = "application/json";
|
||||
additionalHeaders["accept"] = "application/json";
|
||||
}
|
||||
|
||||
// For Mistral models, unwrap the body object
|
||||
const finalRequestBody =
|
||||
modelId.startsWith("mistral.mistral") && bodyJson.body
|
||||
? bodyJson.body
|
||||
: bodyJson;
|
||||
|
||||
// Set content type and accept headers for specific models
|
||||
const headers = await sign({
|
||||
method: "POST",
|
||||
url: endpoint,
|
||||
region: awsRegion,
|
||||
accessKeyId: awsAccessKey,
|
||||
secretAccessKey: awsSecretKey,
|
||||
body: JSON.stringify(bodyJson.body || bodyJson),
|
||||
body: JSON.stringify(finalRequestBody),
|
||||
service: "bedrock",
|
||||
isStreaming: shouldStream !== "false",
|
||||
...(modelId.startsWith("mistral.mistral") && {
|
||||
contentType: "application/json",
|
||||
accept: "application/json",
|
||||
}),
|
||||
additionalHeaders,
|
||||
});
|
||||
|
||||
// Debug log the final request body
|
||||
// console.log("Final request endpoint:", endpoint);
|
||||
// console.log(headers);
|
||||
// console.log("Final request body:", JSON.stringify(finalRequestBody, null, 2));
|
||||
|
||||
const res = await fetch(endpoint, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify(bodyJson.body || bodyJson),
|
||||
body: JSON.stringify(finalRequestBody),
|
||||
redirect: "manual",
|
||||
// @ts-ignore
|
||||
duplex: "half",
|
||||
|
||||
Reference in New Issue
Block a user