完善 Mistral 模型的推理结果

This commit is contained in:
glay
2024-11-23 16:27:19 +08:00
parent a6337e9f23
commit 238eb70986
2 changed files with 81 additions and 39 deletions

View File

@@ -110,8 +110,17 @@ async function* transformBedrockStream(
}
// Handle Mistral models
else if (modelId.startsWith("mistral.mistral")) {
const text =
parsed.output || parsed.outputs?.[0]?.text || parsed.completion || "";
let text = "";
if (parsed.outputs?.[0]?.text) {
text = parsed.outputs[0].text;
} else if (parsed.output) {
text = parsed.output;
} else if (parsed.completion) {
text = parsed.completion;
} else if (typeof parsed === "string") {
text = parsed;
}
if (text) {
yield `data: ${JSON.stringify({
delta: { text },
@@ -176,7 +185,17 @@ function validateRequest(body: any, modelId: string): void {
throw new Error("prompt is required for LLaMA3 models");
}
} else if (modelId.startsWith("mistral.mistral")) {
if (!bodyContent.prompt) throw new Error("Mistral requires a prompt");
if (!bodyContent.prompt) {
throw new Error("prompt is required for Mistral models");
}
if (
!bodyContent.prompt.startsWith("<s>[INST]") ||
!bodyContent.prompt.includes("[/INST]")
) {
throw new Error(
"Mistral prompt must be wrapped in <s>[INST] and [/INST] tags",
);
}
} else if (modelId.startsWith("amazon.titan")) {
if (!bodyContent.inputText) throw new Error("Titan requires inputText");
}
@@ -247,29 +266,27 @@ async function requestBedrock(req: NextRequest) {
} else {
endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`;
}
requestBody = JSON.stringify(bodyJson.body || bodyJson);
// console.log("Request to AWS Bedrock:", {
// endpoint,
// modelId,
// body: requestBody,
// });
// Set content type and accept headers for Mistral models
const headers = await sign({
method: "POST",
url: endpoint,
region: awsRegion,
accessKeyId: awsAccessKey,
secretAccessKey: awsSecretKey,
body: requestBody,
body: JSON.stringify(bodyJson.body || bodyJson),
service: "bedrock",
isStreaming: shouldStream !== "false",
...(modelId.startsWith("mistral.mistral") && {
contentType: "application/json",
accept: "application/json",
}),
});
const res = await fetch(endpoint, {
method: "POST",
headers,
body: requestBody,
body: JSON.stringify(bodyJson.body || bodyJson),
redirect: "manual",
// @ts-ignore
duplex: "half",