mirror of
https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web.git
synced 2025-11-13 04:33:42 +08:00
完善llama和mistral模型的推理功能
This commit is contained in:
@@ -4,7 +4,6 @@ import {
|
||||
sign,
|
||||
decrypt,
|
||||
getBedrockEndpoint,
|
||||
getModelHeaders,
|
||||
transformBedrockStream,
|
||||
parseEventData,
|
||||
BedrockCredentials,
|
||||
@@ -83,6 +82,10 @@ async function requestBedrock(req: NextRequest) {
|
||||
} catch (e) {
|
||||
throw new Error(`Invalid JSON in request body: ${e}`);
|
||||
}
|
||||
console.log(
|
||||
"[Bedrock Request] original Body:",
|
||||
JSON.stringify(bodyJson, null, 2),
|
||||
);
|
||||
|
||||
// Extract tool configuration if present
|
||||
let tools: any[] | undefined;
|
||||
@@ -97,18 +100,44 @@ async function requestBedrock(req: NextRequest) {
|
||||
modelId,
|
||||
shouldStream,
|
||||
);
|
||||
const additionalHeaders = getModelHeaders(modelId);
|
||||
|
||||
console.log("[Bedrock Request] Endpoint:", endpoint);
|
||||
console.log("[Bedrock Request] Model ID:", modelId);
|
||||
|
||||
// Only include tools for Claude models
|
||||
const isClaudeModel = modelId.toLowerCase().includes("claude3");
|
||||
// Handle tools for different models
|
||||
const isMistralModel = modelId.toLowerCase().includes("mistral");
|
||||
const isClaudeModel = modelId.toLowerCase().includes("claude");
|
||||
|
||||
const requestBody = {
|
||||
...bodyJson,
|
||||
...(isClaudeModel && tools && { tools }),
|
||||
};
|
||||
|
||||
if (tools && tools.length > 0) {
|
||||
if (isClaudeModel) {
|
||||
// Claude models already have correct tool format
|
||||
requestBody.tools = tools;
|
||||
} else if (isMistralModel) {
|
||||
// Format messages for Mistral
|
||||
if (typeof requestBody.prompt === "string") {
|
||||
requestBody.messages = [
|
||||
{ role: "user", content: requestBody.prompt },
|
||||
];
|
||||
delete requestBody.prompt;
|
||||
}
|
||||
|
||||
// Add tools in Mistral's format
|
||||
requestBody.tool_choice = "auto";
|
||||
requestBody.tools = tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.input_schema,
|
||||
},
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Sign request
|
||||
const headers = await sign({
|
||||
method: "POST",
|
||||
@@ -119,12 +148,11 @@ async function requestBedrock(req: NextRequest) {
|
||||
body: JSON.stringify(requestBody),
|
||||
service: "bedrock",
|
||||
isStreaming: shouldStream,
|
||||
additionalHeaders,
|
||||
});
|
||||
|
||||
// Make request to AWS Bedrock
|
||||
console.log(
|
||||
"[Bedrock Request] Body:",
|
||||
"[Bedrock Request] Final Body:",
|
||||
JSON.stringify(requestBody, null, 2),
|
||||
);
|
||||
const res = await fetch(endpoint, {
|
||||
@@ -173,11 +201,15 @@ async function requestBedrock(req: NextRequest) {
|
||||
|
||||
// Handle streaming response
|
||||
const transformedStream = transformBedrockStream(res.body, modelId);
|
||||
const encoder = new TextEncoder();
|
||||
const stream = new ReadableStream({
|
||||
async start(controller) {
|
||||
try {
|
||||
for await (const chunk of transformedStream) {
|
||||
controller.enqueue(new TextEncoder().encode(chunk));
|
||||
// Ensure we're sending non-empty chunks
|
||||
if (chunk && chunk.trim()) {
|
||||
controller.enqueue(encoder.encode(chunk));
|
||||
}
|
||||
}
|
||||
controller.close();
|
||||
} catch (err) {
|
||||
|
||||
Reference in New Issue
Block a user