Mirror of https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web.git, synced 2025-11-14 05:03:43 +08:00
Improve the inference functionality of the LLaMA and Mistral models
@@ -85,14 +85,17 @@ export class BedrockApi implements LLMApi {
     }
 
     // Handle LLaMA models
-    if (modelId.startsWith("us.meta.llama3")) {
+    if (modelId.startsWith("us.meta.llama")) {
       if (res?.delta?.text) {
         return res.delta.text;
       }
       if (res?.generation) {
         return res.generation;
       }
-      if (typeof res?.output === "string") {
+      if (res?.outputs?.[0]?.text) {
+        return res.outputs[0].text;
+      }
+      if (res?.output) {
         return res.output;
       }
       if (typeof res === "string") {
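Note on the hunk above: widening the prefix from "us.meta.llama3" to "us.meta.llama" lets the branch match every Meta model ID, and an explicit `outputs[0].text` check now runs before the bare `output` fallback. A minimal sketch of that fallback chain, with payload shapes assumed from typical Bedrock Meta responses (not taken from this commit):

// Sketch only; the payload fields are assumptions about common Bedrock shapes.
type LlamaPayload = {
  delta?: { text?: string };
  generation?: string;
  outputs?: { text?: string }[];
  output?: string;
};

function extractLlamaText(res: LlamaPayload | string): string {
  if (typeof res === "string") return res;
  return (
    res.delta?.text ?? // streaming delta chunk
    res.generation ?? // standard invoke-model response
    res.outputs?.[0]?.text ?? // outputs-array shape
    res.output ?? // bare-output shape
    ""
  );
}

// extractLlamaText({ generation: "Hello!" }) === "Hello!"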
@@ -103,11 +106,28 @@ export class BedrockApi implements LLMApi {
 
     // Handle Mistral models
     if (modelId.startsWith("mistral.mistral")) {
-      if (res?.delta?.text) return res.delta.text;
-      return res?.outputs?.[0]?.text || res?.output || res?.completion || "";
+      if (res?.delta?.text) {
+        return res.delta.text;
+      }
+      if (res?.outputs?.[0]?.text) {
+        return res.outputs[0].text;
+      }
+      if (res?.content?.[0]?.text) {
+        return res.content[0].text;
+      }
+      if (res?.output) {
+        return res.output;
+      }
+      if (res?.completion) {
+        return res.completion;
+      }
+      if (typeof res === "string") {
+        return res;
+      }
+      return "";
     }
 
-    // Handle Claude models and fallback cases
+    // Handle Claude models
     if (res?.content?.[0]?.text) return res.content[0].text;
     if (res?.messages?.[0]?.content?.[0]?.text)
       return res.messages[0].content[0].text;
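The Mistral branch trades the single-expression fallback for explicit checks and adds `content[0].text` (the Anthropic-style shape) to the chain. Illustrative payloads, assuming the surrounding method is the response-text extractor (values invented for the example):

// Assumed examples of what the branches above would match:
const invokeResponse = { outputs: [{ text: "Bonjour !" }] }; // non-streaming response
const streamDelta = { delta: { text: "Bon" } }; // streaming chunk
// extractor(invokeResponse) -> "Bonjour !"
// extractor(streamDelta)    -> "Bon"
// extractor({})             -> ""  (final fallback)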
@@ -142,14 +162,11 @@ export class BedrockApi implements LLMApi {
           ]
         : messages;
 
       // Format messages without role prefixes for Titan
       const inputText = allMessages
         .map((m) => {
           // Include system message as a prefix instruction
-          if (m.role === "system") {
-            return getMessageTextContent(m);
-          }
           // For user/assistant messages, just include the content
           return getMessageTextContent(m);
         })
         .join("\n\n");
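With the redundant role branch gone, every message takes the same path, so the Titan input is just the message contents joined by blank lines. A small assumed example:

// Assumed messages; shows the inputText this produces:
const msgs = [
  { role: "system", content: "You are terse." },
  { role: "user", content: "Hi" },
];
const inputText = msgs.map((m) => m.content).join("\n\n");
// -> "You are terse.\n\nHi"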
@@ -166,25 +183,39 @@ export class BedrockApi implements LLMApi {
       };
     }
 
-    // Handle LLaMA3 models
-    if (model.startsWith("us.meta.llama3")) {
-      // Only include the last user message for LLaMA
-      const lastMessage = messages[messages.length - 1];
-      const prompt = getMessageTextContent(lastMessage);
+    // Handle LLaMA models
+    if (model.startsWith("us.meta.llama")) {
+      const allMessages = systemMessage
+        ? [
+            { role: "system" as MessageRole, content: systemMessage },
+            ...messages,
+          ]
+        : messages;
+
+      const prompt = allMessages
+        .map((m) => {
+          const content = getMessageTextContent(m);
+          if (m.role === "system") {
+            return `System: ${content}`;
+          } else if (m.role === "user") {
+            return `User: ${content}`;
+          } else if (m.role === "assistant") {
+            return `Assistant: ${content}`;
+          }
+          return content;
+        })
+        .join("\n\n");
 
       return {
-        contentType: "application/json",
-        accept: "application/json",
-        body: {
-          prompt: prompt,
-          max_gen_len: modelConfig.max_tokens || 256,
-          temperature: modelConfig.temperature || 0.5,
-          top_p: 0.9,
-        },
+        prompt,
+        max_gen_len: modelConfig.max_tokens || 512,
+        temperature: modelConfig.temperature || 0.6,
+        top_p: modelConfig.top_p || 0.9,
+        stop: ["User:", "System:", "Assistant:", "\n\n"],
       };
     }
 
-    // Handle Mistral models with correct instruction format
+    // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
       const allMessages = systemMessage
         ? [
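For reference, a sketch of what the rewritten LLaMA branch produces, with message values assumed. One design note on `stop`: the role prefixes keep the model from writing further conversation turns on its own, but "\n\n" will also cut a completion off at its first paragraph break.

// Assumed input:
const msgs = [
  { role: "system", content: "You are helpful." },
  { role: "user", content: "Hi" },
];
// Resulting prompt:
//   "System: You are helpful.\n\nUser: Hi"
// Resulting request body (flat, no `body` wrapper; see the dispatch change below):
//   {
//     prompt,
//     max_gen_len: 512,
//     temperature: 0.6,
//     top_p: 0.9,
//     stop: ["User:", "System:", "Assistant:", "\n\n"],
//   }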
@@ -193,25 +224,29 @@ export class BedrockApi implements LLMApi {
           ]
         : messages;
 
-      // Format messages as a conversation with instruction tags
-      const prompt = `<s>[INST] ${allMessages
-        .map((m) => getMessageTextContent(m))
-        .join("\n")} [/INST]`;
+      const formattedConversation = allMessages
+        .map((m) => {
+          const content = getMessageTextContent(m);
+          if (m.role === "system") {
+            return content;
+          } else if (m.role === "user") {
+            return content;
+          } else if (m.role === "assistant") {
+            return content;
+          }
+          return content;
+        })
+        .join("\n");
+
       // Format according to Mistral's requirements
       return {
-        contentType: "application/json",
-        accept: "application/json",
-        body: {
-          prompt,
-          max_tokens: modelConfig.max_tokens || 4096,
-          temperature: modelConfig.temperature || 0.5,
-          top_p: 0.9,
-          top_k: 50,
-        },
+        prompt: formattedConversation,
+        max_tokens: modelConfig.max_tokens || 4096,
+        temperature: modelConfig.temperature || 0.7,
       };
     }
 
-    // Handle Claude models (existing implementation)
+    // Handle Claude models
     const isClaude3 = model.startsWith("anthropic.claude-3");
     const formattedMessages = messages
       .filter(
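Two observations on the Mistral hunk: the `<s>[INST] ... [/INST]` wrapper that Mistral instruct models are conventionally prompted with is dropped in favor of a plain newline-joined conversation, and all three role branches return `content` unchanged, so the map collapses to the equivalent sketch below.

// Equivalent to the role-branched map above, since every branch returns `content`:
const formattedConversation = allMessages
  .map((m) => getMessageTextContent(m))
  .join("\n");
// body -> { prompt: formattedConversation, max_tokens: 4096, temperature: 0.7 }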
@@ -253,12 +288,14 @@ export class BedrockApi implements LLMApi {
     });
 
     return {
-      anthropic_version: "bedrock-2023-05-31",
-      max_tokens: modelConfig.max_tokens,
-      messages: formattedMessages,
-      ...(systemMessage && { system: systemMessage }),
-      temperature: modelConfig.temperature,
-      ...(isClaude3 && { top_k: 5 }),
+      body: {
+        anthropic_version: "bedrock-2023-05-31",
+        max_tokens: modelConfig.max_tokens,
+        messages: formattedMessages,
+        ...(systemMessage && { system: systemMessage }),
+        temperature: modelConfig.temperature,
+        ...(isClaude3 && { top_k: modelConfig.top_k || 50 }),
+      },
     };
   }
 
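After this hunk, Claude is the only branch whose request keeps the `body` wrapper, which the dispatch below unwraps via `requestBody.body`. A shape sketch with assumed field values:

// Assumed values; only the nesting matters here:
const claudeRequest = {
  body: {
    anthropic_version: "bedrock-2023-05-31",
    max_tokens: 4096, // from modelConfig
    messages: [{ role: "user", content: "Hi" }],
    temperature: 0.7, // from modelConfig
    top_k: 50, // claude-3 models only
  },
};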
@@ -301,6 +338,13 @@ export class BedrockApi implements LLMApi {
     const headers = getHeaders();
     headers.ModelID = modelConfig.model;
 
+    // For LLaMA and Mistral models, send the request body directly without the 'body' wrapper
+    const finalRequestBody =
+      modelConfig.model.startsWith("us.meta.llama") ||
+      modelConfig.model.startsWith("mistral.mistral")
+        ? requestBody
+        : requestBody.body;
+
     if (options.config.stream) {
       let index = -1;
       let currentToolArgs = "";
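How the selection above plays out per model family, with illustrative model IDs (assumed):

// "us.meta.llama3-1-70b-instruct-v1:0"      -> finalRequestBody = requestBody (flat)
// "mistral.mistral-7b-instruct-v0:2"        -> finalRequestBody = requestBody (flat)
// "anthropic.claude-3-sonnet-20240229-v1:0" -> finalRequestBody = requestBody.body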
@@ -312,7 +356,7 @@ export class BedrockApi implements LLMApi {
 
       return stream(
         chatPath,
-        requestBody,
+        finalRequestBody,
         headers,
         (tools as ToolDefinition[]).map((tool) => ({
           name: tool?.function?.name,
@@ -420,7 +464,7 @@ export class BedrockApi implements LLMApi {
       const res = await fetch(chatPath, {
         method: "POST",
         headers,
-        body: JSON.stringify(requestBody),
+        body: JSON.stringify(finalRequestBody),
       });
 
       const resJson = await res.json();