Improve the inference results of the Mistral model

glay
2024-11-23 16:27:19 +08:00
parent a6337e9f23
commit 238eb70986
2 changed files with 81 additions and 39 deletions


@@ -74,16 +74,30 @@ export class BedrockApi implements LLMApi {
     try {
       // Handle Titan models
       if (modelId.startsWith("amazon.titan")) {
-        if (res?.delta?.text) return res.delta.text;
-        return res?.outputText || "";
+        let text = "";
+        if (res?.delta?.text) {
+          text = res.delta.text;
+        } else {
+          text = res?.outputText || "";
+        }
+        // Clean up Titan response by removing leading question mark and whitespace
+        return text.replace(/^[\s?]+/, "");
       }

       // Handle LLaMA models
       if (modelId.startsWith("us.meta.llama3")) {
-        if (res?.delta?.text) return res.delta.text;
-        if (res?.generation) return res.generation;
-        if (typeof res?.output === "string") return res.output;
-        if (typeof res === "string") return res;
+        if (res?.delta?.text) {
+          return res.delta.text;
+        }
+        if (res?.generation) {
+          return res.generation;
+        }
+        if (typeof res?.output === "string") {
+          return res.output;
+        }
+        if (typeof res === "string") {
+          return res;
+        }
         return "";
       }
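A minimal standalone sketch of the new Titan cleanup, with a hypothetical helper name and the regex taken from the diff. Note that /^[\s?]+/ strips only a leading run of whitespace and question marks; a "?" later in the text is preserved:

function cleanTitanText(text: string): string {
  // Strip the leading run of whitespace/question marks that Titan
  // sometimes emits before the real completion text.
  return text.replace(/^[\s?]+/, "");
}

cleanTitanText("? Hello");     // => "Hello"
cleanTitanText("\n?\tReady?"); // => "Ready?" (the trailing "?" is kept)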
@@ -127,9 +141,19 @@ export class BedrockApi implements LLMApi {
             ...messages,
           ]
         : messages;

+      // Format messages without role prefixes for Titan
       const inputText = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+        .map((m) => {
+          // Include system message as a prefix instruction
+          if (m.role === "system") {
+            return getMessageTextContent(m);
+          }
+          // For user/assistant messages, just include the content
+          return getMessageTextContent(m);
+        })
+        .join("\n\n");

       return {
         body: {
           inputText,
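For illustration, a sketch of the inputText this formatting yields, assuming messages are plain { role, content } pairs. Since both branches of the new map() return only the content, role labels are dropped and turns are separated by blank lines:

type Msg = { role: "system" | "user" | "assistant"; content: string };

const msgs: Msg[] = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "What is Amazon Bedrock?" },
];

// Equivalent to the map()/join("\n\n") above for plain-string content.
const inputText = msgs.map((m) => m.content).join("\n\n");
// => "You are a helpful assistant.\n\nWhat is Amazon Bedrock?"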
@@ -142,29 +166,25 @@ export class BedrockApi implements LLMApi {
       };
     }

-    // Handle LLaMA3 models - simplified format
+    // Handle LLaMA3 models
     if (model.startsWith("us.meta.llama3")) {
-      const allMessages = systemMessage
-        ? [
-            { role: "system" as MessageRole, content: systemMessage },
-            ...messages,
-          ]
-        : messages;
-      const prompt = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+      // Only include the last user message for LLaMA
+      const lastMessage = messages[messages.length - 1];
+      const prompt = getMessageTextContent(lastMessage);

       return {
         contentType: "application/json",
         accept: "application/json",
         body: {
-          prompt,
+          prompt: prompt,
           max_gen_len: modelConfig.max_tokens || 256,
           temperature: modelConfig.temperature || 0.5,
           top_p: 0.9,
         },
       };
     }
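A sketch of what the simplified LLaMA path sends for a multi-turn chat (message values are illustrative; parameter defaults are the ones from the diff). Because only the last message is used, the system prompt and earlier turns no longer reach the model:

const messages = [
  { role: "user", content: "My name is Ada." },
  { role: "assistant", content: "Nice to meet you, Ada." },
  { role: "user", content: "What is my name?" },
];

const lastMessage = messages[messages.length - 1];
const body = {
  prompt: lastMessage.content, // "What is my name?" (earlier turns are dropped)
  max_gen_len: 256,
  temperature: 0.5,
  top_p: 0.9,
};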
-    // Handle Mistral models
+    // Handle Mistral models with correct instruction format
     if (model.startsWith("mistral.mistral")) {
       const allMessages = systemMessage
         ? [
@@ -172,14 +192,21 @@ export class BedrockApi implements LLMApi {
             ...messages,
           ]
         : messages;

-      const prompt = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+      // Format messages as a conversation with instruction tags
+      const prompt = `<s>[INST] ${allMessages
+        .map((m) => getMessageTextContent(m))
+        .join("\n")} [/INST]`;

       return {
         contentType: "application/json",
         accept: "application/json",
         body: {
           prompt,
-          temperature: modelConfig.temperature || 0.7,
+          max_tokens: modelConfig.max_tokens || 4096,
+          temperature: modelConfig.temperature || 0.5,
+          top_p: 0.9,
+          top_k: 50,
         },
       };
     }
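A sketch of the prompt string the new Mistral formatting produces (contents illustrative). As an aside not covered by this diff, Mistral's instruct template normally wraps each user turn in its own [INST] ... [/INST] block; this change merges the whole history into a single block:

const contents = [
  "You are a concise assistant.",
  "Summarize Amazon Bedrock in one sentence.",
];

const prompt = `<s>[INST] ${contents.join("\n")} [/INST]`;
// => "<s>[INST] You are a concise assistant.\nSummarize Amazon Bedrock in one sentence. [/INST]"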
@@ -258,7 +285,6 @@ export class BedrockApi implements LLMApi {
       systemMessage,
       modelConfig,
     );
-    // console.log("Request body:", JSON.stringify(requestBody, null, 2));

     const controller = new AbortController();
     options.onController?.(controller);
@@ -338,7 +364,6 @@ export class BedrockApi implements LLMApi {
             } catch (e) {}
           }

           const message = this.extractMessage(chunkJson, modelConfig.model);
-          // console.log("Extracted message:", message);
           return message;
         } catch (e) {
           console.error("Error parsing chunk:", e);
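Tying the two changes together, a standalone re-creation of the Titan branch of extractMessage on a sample decoded chunk (field names come from the diff; the chunk value is illustrative):

const chunkJson: any = { outputText: "?\n Hello from Titan" };

// Prefer delta.text, fall back to outputText, then strip the leading
// "?"/whitespace run, as in the updated extractMessage.
const text: string = chunkJson?.delta?.text || chunkJson?.outputText || "";
console.log(text.replace(/^[\s?]+/, "")); // => "Hello from Titan"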