完善llama和mistral模型的推理功能

This commit is contained in:
glay
2024-11-24 23:54:04 +08:00
parent 2ccdd1706a
commit 6f7a635030
4 changed files with 204 additions and 89 deletions

View File

@@ -4,7 +4,6 @@ import {
sign,
decrypt,
getBedrockEndpoint,
getModelHeaders,
transformBedrockStream,
parseEventData,
BedrockCredentials,
@@ -83,6 +82,10 @@ async function requestBedrock(req: NextRequest) {
} catch (e) {
throw new Error(`Invalid JSON in request body: ${e}`);
}
console.log(
"[Bedrock Request] original Body:",
JSON.stringify(bodyJson, null, 2),
);
// Extract tool configuration if present
let tools: any[] | undefined;
@@ -97,18 +100,44 @@ async function requestBedrock(req: NextRequest) {
modelId,
shouldStream,
);
const additionalHeaders = getModelHeaders(modelId);
console.log("[Bedrock Request] Endpoint:", endpoint);
console.log("[Bedrock Request] Model ID:", modelId);
// Only include tools for Claude models
const isClaudeModel = modelId.toLowerCase().includes("claude3");
// Handle tools for different models
const isMistralModel = modelId.toLowerCase().includes("mistral");
const isClaudeModel = modelId.toLowerCase().includes("claude");
const requestBody = {
...bodyJson,
...(isClaudeModel && tools && { tools }),
};
if (tools && tools.length > 0) {
if (isClaudeModel) {
// Claude models already have correct tool format
requestBody.tools = tools;
} else if (isMistralModel) {
// Format messages for Mistral
if (typeof requestBody.prompt === "string") {
requestBody.messages = [
{ role: "user", content: requestBody.prompt },
];
delete requestBody.prompt;
}
// Add tools in Mistral's format
requestBody.tool_choice = "auto";
requestBody.tools = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
},
}));
}
}
// Sign request
const headers = await sign({
method: "POST",
@@ -119,12 +148,11 @@ async function requestBedrock(req: NextRequest) {
body: JSON.stringify(requestBody),
service: "bedrock",
isStreaming: shouldStream,
additionalHeaders,
});
// Make request to AWS Bedrock
console.log(
"[Bedrock Request] Body:",
"[Bedrock Request] Final Body:",
JSON.stringify(requestBody, null, 2),
);
const res = await fetch(endpoint, {
@@ -173,11 +201,15 @@ async function requestBedrock(req: NextRequest) {
// Handle streaming response
const transformedStream = transformBedrockStream(res.body, modelId);
const encoder = new TextEncoder();
const stream = new ReadableStream({
async start(controller) {
try {
for await (const chunk of transformedStream) {
controller.enqueue(new TextEncoder().encode(chunk));
// Ensure we're sending non-empty chunks
if (chunk && chunk.trim()) {
controller.enqueue(encoder.encode(chunk));
}
}
controller.close();
} catch (err) {