Improve Mistral tool use support and fix Llama 3 message format issues

commit e6633753a4
parent 15d0600642
Author: glay
Date: 2024-11-25 20:08:21 +08:00

4 changed files with 151 additions and 147 deletions


@@ -21,7 +21,22 @@ const ClaudeMapper = {
   system: "user",
 } as const;
 
+const MistralMapper = {
+  system: "system",
+  user: "user",
+  assistant: "assistant",
+} as const;
+
 type ClaudeRole = keyof typeof ClaudeMapper;
+type MistralRole = keyof typeof MistralMapper;
+
+interface Tool {
+  function?: {
+    name?: string;
+    description?: string;
+    parameters?: any;
+  };
+}
 
 export class BedrockApi implements LLMApi {
   speech(options: SpeechOptions): Promise<ArrayBuffer> {
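
For orientation, a hypothetical value that satisfies the new Tool interface; the weather tool itself is invented for the example, and Tool is the interface declared in the hunk above:

// A made-up tool definition that satisfies the Tool interface above.
const exampleTool: Tool = {
  function: {
    name: "get_weather",
    description: "Look up the current weather for a city",
    parameters: {
      type: "object",
      properties: { city: { type: "string" } },
      required: ["city"],
    },
  },
};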
@@ -30,7 +45,6 @@ export class BedrockApi implements LLMApi {
   formatRequestBody(messages: ChatOptions["messages"], modelConfig: any) {
     const model = modelConfig.model;
     const visionModel = isVisionModel(modelConfig.model);
-
     // Handle Titan models
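
formatRequestBody dispatches on the Bedrock model-id prefix. A condensed sketch of that routing follows; the amazon.titan prefix is an assumption based on the comment, while the other three checks appear verbatim in this commit's hunks:

// Condensed sketch of the per-model routing in formatRequestBody.
function routeModel(
  model: string,
): "titan" | "llama" | "mistral" | "claude" | "unknown" {
  if (model.includes("amazon.titan")) return "titan"; // prefix assumed
  if (model.includes("meta.llama")) return "llama";
  if (model.startsWith("mistral.mistral")) return "mistral";
  if (model.includes("anthropic.claude")) return "claude";
  return "unknown";
}

routeModel("meta.llama3-8b-instruct-v1:0"); // "llama"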
@@ -53,37 +67,27 @@ export class BedrockApi implements LLMApi {
     // Handle LLaMA models
     if (model.includes("meta.llama")) {
       // Format conversation for Llama models
-      let prompt = "";
-      let systemPrompt = "";
+      let prompt = "<|begin_of_text|>";
 
       // Extract system message if present
       const systemMessage = messages.find((m) => m.role === "system");
       if (systemMessage) {
-        systemPrompt = getMessageTextContent(systemMessage);
+        prompt += `<|start_header_id|>system<|end_header_id|>\n${getMessageTextContent(
+          systemMessage,
+        )}<|eot_id|>`;
       }
 
       // Format the conversation
       const conversationMessages = messages.filter((m) => m.role !== "system");
-      prompt = `<s>[INST] <<SYS>>\n${
-        systemPrompt || "You are a helpful, respectful and honest assistant."
-      }\n<</SYS>>\n\n`;
-
-      for (let i = 0; i < conversationMessages.length; i++) {
-        const message = conversationMessages[i];
+      for (const message of conversationMessages) {
+        const role = message.role === "assistant" ? "assistant" : "user";
         const content = getMessageTextContent(message);
-        if (i === 0 && message.role === "user") {
-          // First user message goes in the same [INST] block as system prompt
-          prompt += `${content} [/INST]`;
-        } else {
-          if (message.role === "user") {
-            prompt += `\n\n<s>[INST] ${content} [/INST]`;
-          } else {
-            prompt += ` ${content} </s>`;
-          }
-        }
+        prompt += `<|start_header_id|>${role}<|end_header_id|>\n${content}<|eot_id|>`;
       }
 
+      // Add the final assistant header to prompt completion
+      prompt += "<|start_header_id|>assistant<|end_header_id|>";
+
       return {
         prompt,
         max_gen_len: modelConfig.max_tokens || 512,
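
To see what the rewritten branch produces, here is a self-contained sketch of the Llama 3 prompt format; message content is simplified to plain strings, where the real code goes through getMessageTextContent:

// Standalone sketch of the Llama 3 prompt format built above.
interface SimpleMessage {
  role: "system" | "user" | "assistant";
  content: string;
}

function formatLlama3Prompt(messages: SimpleMessage[]): string {
  let prompt = "<|begin_of_text|>";
  const system = messages.find((m) => m.role === "system");
  if (system) {
    prompt += `<|start_header_id|>system<|end_header_id|>\n${system.content}<|eot_id|>`;
  }
  for (const m of messages.filter((m) => m.role !== "system")) {
    const role = m.role === "assistant" ? "assistant" : "user";
    prompt += `<|start_header_id|>${role}<|end_header_id|>\n${m.content}<|eot_id|>`;
  }
  // The trailing assistant header cues the model to generate the next
  // assistant turn.
  return prompt + "<|start_header_id|>assistant<|end_header_id|>";
}

formatLlama3Prompt([
  { role: "system", content: "Be brief." },
  { role: "user", content: "Hi" },
]);
// "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nBe brief.<|eot_id|>
//  <|start_header_id|>user<|end_header_id|>\nHi<|eot_id|>
//  <|start_header_id|>assistant<|end_header_id|>"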
@@ -94,9 +98,8 @@ export class BedrockApi implements LLMApi {
     // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
       // Format messages for Mistral's chat format
       const formattedMessages = messages.map((message) => ({
-        role: message.role,
+        role: MistralMapper[message.role as MistralRole] || "user",
         content: getMessageTextContent(message),
       }));
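
A sketch of the normalisation this hunk performs; the sample messages are made up, and getMessageTextContent is replaced by a plain content field:

// MistralMapper and MistralRole mirror the declarations in this diff.
const MistralMapper = {
  system: "system",
  user: "user",
  assistant: "assistant",
} as const;
type MistralRole = keyof typeof MistralMapper;

const sample = [
  { role: "system", content: "Be terse." },
  { role: "tool", content: "18°C" }, // role with no mapper entry
];

const formatted = sample.map((message) => ({
  // Unknown roles degrade to "user" instead of producing an invalid request.
  role: MistralMapper[message.role as MistralRole] || "user",
  content: message.content,
}));

formatted.map((m) => m.role); // ["system", "user"]

Mapping through a typed table rather than passing message.role straight through keeps unexpected roles from reaching the Bedrock Mistral endpoint.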
@@ -234,6 +237,11 @@ export class BedrockApi implements LLMApi {
     });
 
     const finalRequestBody = this.formatRequestBody(messages, modelConfig);
+    console.log(
+      "[Bedrock Client] Request Body:",
+      JSON.stringify(finalRequestBody, null, 2),
+    );
+
     if (shouldStream) {
       let index = -1;
       const [tools, funcs] = usePluginStore
@@ -253,6 +261,7 @@ export class BedrockApi implements LLMApi {
           })),
           funcs,
           controller,
+          // parseSSE
           (text: string, runTools: ChatMessageTool[]) => {
             // console.log("parseSSE", text, runTools);
             let chunkJson:
@@ -304,36 +313,73 @@ export class BedrockApi implements LLMApi {
         ) => {
           // reset index value
           index = -1;
-          // @ts-ignore
-          requestPayload?.messages?.splice(
-            // @ts-ignore
-            requestPayload?.messages?.length,
-            0,
-            {
-              role: "assistant",
-              content: toolCallMessage.tool_calls.map(
-                (tool: ChatMessageTool) => ({
-                  type: "tool_use",
-                  id: tool.id,
-                  name: tool?.function?.name,
-                  input: tool?.function?.arguments
-                    ? JSON.parse(tool?.function?.arguments)
-                    : {},
-                }),
-              ),
-            },
-            // @ts-ignore
-            ...toolCallResult.map((result) => ({
-              role: "user",
-              content: [
-                {
-                  type: "tool_result",
-                  tool_use_id: result.tool_call_id,
-                  content: result.content,
-                },
-              ],
-            })),
-          );
+          const modelId = modelConfig.model;
+          const isMistral = modelId.startsWith("mistral.mistral");
+          const isClaude = modelId.includes("anthropic.claude");
+
+          if (isClaude) {
+            // Format for Claude
+            requestPayload?.messages?.splice(
+              // @ts-ignore
+              requestPayload?.messages?.length,
+              0,
+              {
+                role: "assistant",
+                content: toolCallMessage.tool_calls.map(
+                  (tool: ChatMessageTool) => ({
+                    type: "tool_use",
+                    id: tool.id,
+                    name: tool?.function?.name,
+                    input: tool?.function?.arguments
+                      ? JSON.parse(tool?.function?.arguments)
+                      : {},
+                  }),
+                ),
+              },
+              // @ts-ignore
+              ...toolCallResult.map((result) => ({
+                role: "user",
+                content: [
+                  {
+                    type: "tool_result",
+                    tool_use_id: result.tool_call_id,
+                    content: result.content,
+                  },
+                ],
+              })),
+            );
+          } else if (isMistral) {
+            // Format for Mistral
+            requestPayload?.messages?.splice(
+              // @ts-ignore
+              requestPayload?.messages?.length,
+              0,
+              {
+                role: "assistant",
+                content: "",
+                // @ts-ignore
+                tool_calls: toolCallMessage.tool_calls.map(
+                  (tool: ChatMessageTool) => ({
+                    id: tool.id,
+                    function: {
+                      name: tool?.function?.name,
+                      arguments: tool?.function?.arguments || "{}",
+                    },
+                  }),
+                ),
+              },
+              ...toolCallResult.map((result) => ({
+                role: "tool",
+                tool_call_id: result.tool_call_id,
+                content: result.content,
+              })),
+            );
+          } else {
+            console.warn(
+              `[Bedrock Client] Unhandled model type for tool calls: ${modelId}`,
+            );
+          }
         },
         options,
       );
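
The two branches splice differently shaped messages into requestPayload.messages. A side-by-side sketch of both shapes, with an invented get_weather call and result:

// Shapes spliced into requestPayload.messages by each branch.
const call = {
  id: "call_1",
  function: { name: "get_weather", arguments: '{"city":"Paris"}' },
};
const result = { tool_call_id: "call_1", content: "18°C, cloudy" };

// Claude (anthropic.claude*): a tool_use block on the assistant turn,
// then a tool_result block on a follow-up *user* turn.
const claudeTurns = [
  {
    role: "assistant",
    content: [
      {
        type: "tool_use",
        id: call.id,
        name: call.function.name,
        input: JSON.parse(call.function.arguments),
      },
    ],
  },
  {
    role: "user",
    content: [
      {
        type: "tool_result",
        tool_use_id: result.tool_call_id,
        content: result.content,
      },
    ],
  },
];

// Mistral (mistral.mistral*): OpenAI-style tool_calls on the assistant
// turn, then a dedicated "tool"-role turn carrying the result.
const mistralTurns = [
  {
    role: "assistant",
    content: "",
    tool_calls: [{ id: call.id, function: call.function }],
  },
  { role: "tool", tool_call_id: result.tool_call_id, content: result.content },
];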
@@ -368,6 +414,7 @@ export class BedrockApi implements LLMApi {
       options.onError?.(e as Error);
     }
   }
+
   path(path: string): string {
     const accessStore = useAccessStore.getState();
     let baseUrl = accessStore.useCustomConfig ? accessStore.bedrockUrl : "";