Improve mistral tool use support and fix the llama3 message format issue

glay
2024-11-25 20:08:21 +08:00
parent 15d0600642
commit e6633753a4
4 changed files with 151 additions and 147 deletions


@@ -245,6 +245,7 @@ export async function sign({
export function parseEventData(chunk: Uint8Array): any {
const decoder = new TextDecoder();
const text = decoder.decode(chunk);
// console.info("[AWS Parse ] parsing:", text);
try {
const parsed = JSON.parse(text);
// AWS Bedrock wraps the response in a 'body' field
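// Illustrative sketch, not part of this commit: one way to unwrap a Bedrock
// stream event whose JSON payload carries a `body` field that may be either
// an object or a JSON string. The `output` fallback shape is an assumption
// made only for this example.
function unwrapBedrockEvent(chunk: Uint8Array): any {
  const text = new TextDecoder().decode(chunk);
  const parsed = JSON.parse(text);
  if (parsed.body === undefined) return parsed;
  if (typeof parsed.body === "string") {
    try {
      return JSON.parse(parsed.body);
    } catch {
      return { output: parsed.body }; // assumed fallback, see note above
    }
  }
  return parsed.body;
}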
@@ -317,7 +318,10 @@ export function extractMessage(res: any, modelId: string = ""): string {
// Handle Mistral model response format
if (modelId.toLowerCase().includes("mistral")) {
return res?.outputs?.[0]?.text || "";
if (res.choices?.[0]?.message?.content) {
return res.choices[0].message.content;
}
return res.output || "";
}
// Handle Llama model response format
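// For reference (shapes inferred from the diff above, not from Mistral or
// Bedrock documentation): the updated extractMessage now prefers an
// OpenAI-style chat payload and falls back to a plain `output` string.
//   { choices: [{ message: { content: "..." } }] }   // preferred
//   { output: "..." }                                 // fallback
function pickMistralText(res: any): string {
  if (res?.choices?.[0]?.message?.content) {
    return res.choices[0].message.content;
  }
  return res?.output || "";
}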
@@ -334,9 +338,7 @@ export async function* transformBedrockStream(
modelId: string,
) {
const reader = stream.getReader();
let accumulatedText = "";
let toolCallStarted = false;
let currentToolCall = null;
let toolInput = "";
try {
while (true) {
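// Generic sketch of the surrounding reader loop (assumption: the generator
// drains the ReadableStream roughly like this; the helper name is made up):
async function* drainChunks(stream: ReadableStream<Uint8Array>) {
  const reader = stream.getReader();
  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      if (value) yield value; // raw chunk, later decoded by parseEventData
    }
  } finally {
    reader.releaseLock();
  }
}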
@@ -349,90 +351,54 @@ export async function* transformBedrockStream(
// console.log("parseEventData=========================");
// console.log(parsed);
// Handle Mistral models
if (modelId.toLowerCase().includes("mistral")) {
// If we have content, accumulate it
if (
parsed.choices?.[0]?.message?.role === "assistant" &&
parsed.choices?.[0]?.message?.content
) {
accumulatedText += parsed.choices?.[0]?.message?.content;
// console.log("accumulatedText=========================");
// console.log(accumulatedText);
// Check for tool call in the accumulated text
if (!toolCallStarted && accumulatedText.includes("```json")) {
const jsonMatch = accumulatedText.match(
/```json\s*({[\s\S]*?})\s*```/,
);
if (jsonMatch) {
try {
const toolData = JSON.parse(jsonMatch[1]);
currentToolCall = {
id: `tool-${Date.now()}`,
name: toolData.name,
arguments: toolData.arguments,
};
// Handle tool calls
if (parsed.choices?.[0]?.message?.tool_calls) {
const toolCalls = parsed.choices[0].message.tool_calls;
for (const toolCall of toolCalls) {
// Emit tool call start
yield `data: ${JSON.stringify({
type: "content_block_start",
content_block: {
type: "tool_use",
id: toolCall.id || `tool-${Date.now()}`,
name: toolCall.function?.name,
},
})}\n\n`;
// Emit tool call start
yield `data: ${JSON.stringify({
type: "content_block_start",
content_block: {
type: "tool_use",
id: currentToolCall.id,
name: currentToolCall.name,
},
})}\n\n`;
// Emit tool arguments
yield `data: ${JSON.stringify({
type: "content_block_delta",
delta: {
type: "input_json_delta",
partial_json: JSON.stringify(currentToolCall.arguments),
},
})}\n\n`;
// Emit tool call stop
yield `data: ${JSON.stringify({
type: "content_block_stop",
})}\n\n`;
// Clear the accumulated text after processing the tool call
accumulatedText = accumulatedText.replace(
/```json\s*{[\s\S]*?}\s*```/,
"",
);
toolCallStarted = false;
currentToolCall = null;
} catch (e) {
console.error("Failed to parse tool JSON:", e);
}
// Emit tool arguments
if (toolCall.function?.arguments) {
yield `data: ${JSON.stringify({
type: "content_block_delta",
delta: {
type: "input_json_delta",
partial_json: toolCall.function.arguments,
},
})}\n\n`;
}
}
// emit the text content if it's not empty
if (parsed.choices?.[0]?.message?.content.trim()) {
// Emit tool call stop
yield `data: ${JSON.stringify({
delta: { text: parsed.choices?.[0]?.message?.content },
})}\n\n`;
}
// Handle stop reason if present
if (parsed.choices?.[0]?.stop_reason) {
yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.choices[0].stop_reason },
type: "content_block_stop",
})}\n\n`;
}
continue;
}
}
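// Worked example (hypothetical Mistral chunk, not part of this diff) of the
// mapping the block above performs onto Anthropic-style SSE blocks:
const exampleChunk = {
  choices: [{ message: { tool_calls: [{
    id: "call_1", // falls back to `tool-${Date.now()}` when missing
    function: { name: "get_weather", arguments: '{"city":"Paris"}' },
  }] } }],
};
// With a single tool call like this, the generator yields, in order:
//   data: {"type":"content_block_start","content_block":{"type":"tool_use","id":"call_1","name":"get_weather"}}
//   data: {"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"{\"city\":\"Paris\"}"}}
//   data: {"type":"content_block_stop"}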
// Handle Llama models
else if (modelId.toLowerCase().includes("llama")) {
if (parsed.generation) {
// Handle regular content
const content = parsed.choices?.[0]?.message?.content;
if (content?.trim()) {
yield `data: ${JSON.stringify({
delta: { text: parsed.generation },
delta: { text: content },
})}\n\n`;
}
if (parsed.stop_reason) {
// Handle stop reason
if (parsed.choices?.[0]?.finish_reason) {
yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.stop_reason },
delta: { stop_reason: parsed.choices[0].finish_reason },
})}\n\n`;
}
}
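// Assumption for illustration: after this commit the Llama 3 branch reads an
// OpenAI-style chat payload (choices[0].message.content / finish_reason)
// rather than Bedrock's raw { generation, stop_reason } shape. Helpers that
// tolerate both shapes could look like this (names are made up):
function pickLlamaText(parsed: any): string {
  return parsed?.choices?.[0]?.message?.content ?? parsed?.generation ?? "";
}
function pickLlamaStopReason(parsed: any): string | undefined {
  return parsed?.choices?.[0]?.finish_reason ?? parsed?.stop_reason;
}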
@@ -469,8 +435,9 @@ export async function* transformBedrockStream(
})}\n\n`;
}
}
} else {
// Handle other model text responses
}
// Handle other models
else {
const text = parsed.outputText || parsed.generation || "";
if (text) {
yield `data: ${JSON.stringify({