Improve inference functionality for Llama and Mistral models

glay
2024-11-24 23:54:04 +08:00
parent 2ccdd1706a
commit 6f7a635030
4 changed files with 204 additions and 89 deletions


@@ -37,7 +37,7 @@ export class BedrockApi implements LLMApi {
     if (model.startsWith("amazon.titan")) {
       const inputText = messages
         .map((message) => {
-          return `${message.role}: ${message.content}`;
+          return `${message.role}: ${getMessageTextContent(message)}`;
         })
         .join("\n\n");
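For context: this hunk swaps direct access to `message.content` for the shared `getMessageTextContent` helper, so multimodal messages (whose content is an array of parts) still serialize to plain text. A minimal sketch of the assumed helper behavior, with hypothetical type names; not the repo's actual implementation:

// Sketch only: assumed behavior of getMessageTextContent; the real helper
// lives elsewhere in this repo and may differ in detail.
type TextPart = { type: "text"; text?: string };
type ImagePart = { type: "image_url"; image_url: { url: string } };
type RequestMessage = {
  role: string;
  content: string | (TextPart | ImagePart)[];
};

function getMessageTextContentSketch(message: RequestMessage): string {
  // Plain string content is returned as-is.
  if (typeof message.content === "string") return message.content;
  // For multimodal content, return the first text part (or empty string).
  const textPart = message.content.find(
    (part): part is TextPart => part.type === "text",
  );
  return textPart?.text ?? "";
}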
@@ -52,32 +52,59 @@ export class BedrockApi implements LLMApi {
     }
 
     // Handle LLaMA models
-    if (model.startsWith("us.meta.llama")) {
-      const prompt = messages
-        .map((message) => {
-          return `${message.role}: ${message.content}`;
-        })
-        .join("\n\n");
+    if (model.includes("meta.llama")) {
+      // Format conversation for Llama models
+      let prompt = "";
+      let systemPrompt = "";
+
+      // Extract system message if present
+      const systemMessage = messages.find((m) => m.role === "system");
+      if (systemMessage) {
+        systemPrompt = getMessageTextContent(systemMessage);
+      }
+
+      // Format the conversation
+      const conversationMessages = messages.filter(
+        (m) => m.role !== "system",
+      );
+      prompt = `<s>[INST] <<SYS>>\n${
+        systemPrompt || "You are a helpful, respectful and honest assistant."
+      }\n<</SYS>>\n\n`;
+
+      for (let i = 0; i < conversationMessages.length; i++) {
+        const message = conversationMessages[i];
+        const content = getMessageTextContent(message);
+        if (i === 0 && message.role === "user") {
+          // First user message goes in the same [INST] block as system prompt
+          prompt += `${content} [/INST]`;
+        } else {
+          if (message.role === "user") {
+            prompt += `\n\n<s>[INST] ${content} [/INST]`;
+          } else {
+            prompt += ` ${content} </s>`;
+          }
+        }
+      }
+
       return {
         prompt,
         max_gen_len: modelConfig.max_tokens || 512,
-        temperature: modelConfig.temperature || 0.6,
+        temperature: modelConfig.temperature || 0.7,
         top_p: modelConfig.top_p || 0.9,
+        stop: ["User:", "System:", "Assistant:", "\n\n"],
       };
     }
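For illustration, the new loop serializes a conversation into the Llama 2 chat template: a `<<SYS>>` system section inside the first `[INST]` block, each later user turn opening a fresh `<s>[INST] ... [/INST]` block, and each assistant turn closed with `</s>`. A hypothetical exchange (not from the commit), traced by hand through the code above:

// Hypothetical input traced through the formatting loop above.
const exampleMessages = [
  { role: "system", content: "Answer briefly." },
  { role: "user", content: "Hi" },
  { role: "assistant", content: "Hello!" },
  { role: "user", content: "How are you?" },
];

// Resulting prompt string:
const expectedPrompt =
  "<s>[INST] <<SYS>>\nAnswer briefly.\n<</SYS>>\n\n" + // system section
  "Hi [/INST] Hello! </s>" + // first user turn + assistant reply
  "\n\n<s>[INST] How are you? [/INST]"; // follow-up user turn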
     // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
-      const prompt = messages
-        .map((message) => {
-          return `${message.role}: ${message.content}`;
-        })
-        .join("\n\n");
+      // Format messages for Mistral's chat format
+      const formattedMessages = messages.map((message) => ({
+        role: message.role,
+        content: getMessageTextContent(message),
+      }));
+
       return {
-        prompt,
+        messages: formattedMessages,
         max_tokens: modelConfig.max_tokens || 4096,
         temperature: modelConfig.temperature || 0.7,
         top_p: modelConfig.top_p || 0.9,
       };
     }
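Similarly, the Mistral branch now sends Bedrock a structured `messages` array rather than a role-prefixed prompt string. A sketch of the request body the new code produces, using a hypothetical single-turn input and the fallback values visible in the diff:

// Hypothetical request body; the `||` fallbacks apply whenever
// modelConfig leaves a field unset.
const mistralBody = {
  messages: [{ role: "user", content: "What is Amazon Bedrock?" }],
  max_tokens: 4096, // modelConfig.max_tokens fallback
  temperature: 0.7, // modelConfig.temperature fallback
  top_p: 0.9, // modelConfig.top_p fallback
};

One design caveat: `||` treats any falsy value as unset, so an explicit `temperature: 0` in modelConfig would silently become 0.7; `??` would preserve it.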