Improve Mistral tool use functionality

glay
2024-11-26 10:10:34 +08:00
parent e6633753a4
commit 448babd27f
3 changed files with 331 additions and 62 deletions
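
As context for the diff below: the core of this change is that Mistral tool calls streamed from Bedrock are re-emitted in the same content_block SSE event format already used for the Claude branch. A minimal sketch of that mapping follows; it is not the exact implementation, and the chunk shape and content_block fields (id, name, partial_json) are assumptions inferred from the hunks.

// Minimal sketch, not the exact implementation: map one parsed Mistral
// streaming chunk from Bedrock into Anthropic-style SSE events, mirroring
// the tool-call handling added to transformBedrockStream in this commit.
// The chunk shape and the content_block fields (id, name, partial_json)
// are assumptions, not taken verbatim from the source.
interface MistralToolCall {
  id?: string;
  function?: { name?: string; arguments?: string };
}

interface MistralChunk {
  choices?: Array<{
    message?: { content?: string; tool_calls?: MistralToolCall[] };
    finish_reason?: string;
  }>;
}

function* mistralChunkToSse(parsed: MistralChunk): Generator<string> {
  const choice = parsed.choices?.[0];
  if (!choice) return;

  const toolCalls = choice.message?.tool_calls;
  if (toolCalls?.length) {
    for (const toolCall of toolCalls) {
      // Tool call start, mirrored as a content_block_start event
      yield `data: ${JSON.stringify({
        type: "content_block_start",
        content_block: {
          type: "tool_use",
          id: toolCall.id,
          name: toolCall.function?.name,
        },
      })}\n\n`;
      // Tool arguments, streamed as an input_json_delta
      if (toolCall.function?.arguments) {
        yield `data: ${JSON.stringify({
          type: "content_block_delta",
          delta: {
            type: "input_json_delta",
            partial_json: toolCall.function.arguments,
          },
        })}\n\n`;
      }
      // Tool call end
      yield `data: ${JSON.stringify({ type: "content_block_stop" })}\n\n`;
    }
    return;
  }

  // Plain text and stop reason fall through to simple delta events
  if (choice.message?.content) {
    yield `data: ${JSON.stringify({ delta: { text: choice.message.content } })}\n\n`;
  }
  if (choice.finish_reason) {
    yield `data: ${JSON.stringify({ delta: { stop_reason: choice.finish_reason } })}\n\n`;
  }
}

With this mapping, Mistral tool calls reach the client as the same content_block event sequence the Claude branch already emits, so downstream tool-use handling does not need a per-model path.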


@@ -245,7 +245,7 @@ export async function sign({
export function parseEventData(chunk: Uint8Array): any {
const decoder = new TextDecoder();
const text = decoder.decode(chunk);
// console.info("[AWS Parse ] parsing:", text);
try {
const parsed = JSON.parse(text);
// AWS Bedrock wraps the response in a 'body' field
@@ -282,7 +282,6 @@ export function parseEventData(chunk: Uint8Array): any {
// Handle plain text responses
if (text.trim()) {
// Clean up any malformed JSON characters
const cleanText = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
return { output: cleanText };
}
@@ -314,7 +313,6 @@ export function extractMessage(res: any, modelId: string = ""): string {
console.error("[AWS Extract Error] extractMessage Empty response");
return "";
}
// console.log("[Response] extractMessage response: ", res);
// Handle Mistral model response format
if (modelId.toLowerCase().includes("mistral")) {
@@ -329,6 +327,11 @@ export function extractMessage(res: any, modelId: string = ""): string {
return res?.generation || "";
}
// Handle Titan model response format
if (modelId.toLowerCase().includes("titan")) {
return res?.outputText || "";
}
// Handle Claude and other models
return res?.content?.[0]?.text || "";
}
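
For orientation, the per-model dispatch that this hunk extends (adding the Titan branch to extractMessage) can be sketched as below. The Mistral and Llama branch bodies are partially cut off by the hunk boundaries, so those return shapes are assumptions.

// Simplified sketch of extractMessage's per-model dispatch; branches whose
// bodies are not visible in the hunk use assumed response shapes.
function extractTextSketch(res: any, modelId: string = ""): string {
  const id = modelId.toLowerCase();
  if (id.includes("mistral")) {
    // Assumption: Mistral on Bedrock returns an OpenAI-style choices array.
    return res?.choices?.[0]?.message?.content || "";
  }
  if (id.includes("llama")) {
    return res?.generation || "";
  }
  if (id.includes("titan")) {
    return res?.outputText || "";
  }
  // Claude and other models
  return res?.content?.[0]?.text || "";
}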
@@ -338,12 +341,10 @@ export async function* transformBedrockStream(
modelId: string,
) {
const reader = stream.getReader();
let toolInput = "";
try {
while (true) {
const { done, value } = await reader.read();
if (done) break;
const parsed = parseEventData(value);
@@ -351,14 +352,40 @@ export async function* transformBedrockStream(
// console.log("parseEventData=========================");
// console.log(parsed);
// Handle Claude 3 models
if (modelId.startsWith("anthropic.claude")) {
if (parsed.type === "message_start") {
// Initialize message
continue;
} else if (parsed.type === "content_block_start") {
if (parsed.content_block?.type === "tool_use") {
yield `data: ${JSON.stringify(parsed)}\n\n`;
}
continue;
} else if (parsed.type === "content_block_delta") {
if (parsed.delta?.type === "text_delta") {
yield `data: ${JSON.stringify({
delta: { text: parsed.delta.text },
})}\n\n`;
} else if (parsed.delta?.type === "input_json_delta") {
yield `data: ${JSON.stringify(parsed)}\n\n`;
}
} else if (parsed.type === "content_block_stop") {
yield `data: ${JSON.stringify(parsed)}\n\n`;
} else if (
parsed.type === "message_delta" &&
parsed.delta?.stop_reason
) {
yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.delta.stop_reason },
})}\n\n`;
}
}
// Handle Mistral models
if (modelId.toLowerCase().includes("mistral")) {
// Handle tool calls
else if (modelId.toLowerCase().includes("mistral")) {
if (parsed.choices?.[0]?.message?.tool_calls) {
const toolCalls = parsed.choices[0].message.tool_calls;
for (const toolCall of toolCalls) {
// Emit tool call start
yield `data: ${JSON.stringify({
type: "content_block_start",
content_block: {
@@ -368,7 +395,6 @@ export async function* transformBedrockStream(
},
})}\n\n`;
// Emit tool arguments
if (toolCall.function?.arguments) {
yield `data: ${JSON.stringify({
type: "content_block_delta",
@@ -379,66 +405,51 @@ export async function* transformBedrockStream(
})}\n\n`;
}
// Emit tool call stop
yield `data: ${JSON.stringify({
type: "content_block_stop",
})}\n\n`;
}
continue;
}
// Handle regular content
const content = parsed.choices?.[0]?.message?.content;
if (content?.trim()) {
} else if (parsed.choices?.[0]?.message?.content) {
yield `data: ${JSON.stringify({
delta: { text: content },
delta: { text: parsed.choices[0].message.content },
})}\n\n`;
}
// Handle stop reason
if (parsed.choices?.[0]?.finish_reason) {
yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.choices[0].finish_reason },
})}\n\n`;
}
}
// Handle Claude models
else if (modelId.startsWith("anthropic.claude")) {
if (parsed.type === "content_block_delta") {
if (parsed.delta?.type === "text_delta") {
yield `data: ${JSON.stringify({
delta: { text: parsed.delta.text },
})}\n\n`;
} else if (parsed.delta?.type === "input_json_delta") {
yield `data: ${JSON.stringify(parsed)}\n\n`;
}
} else if (
parsed.type === "message_delta" &&
parsed.delta?.stop_reason
) {
// Handle Llama models
else if (modelId.toLowerCase().includes("llama")) {
if (parsed.generation) {
yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.delta.stop_reason },
delta: { text: parsed.generation },
})}\n\n`;
}
if (parsed.stop_reason) {
yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.stop_reason },
})}\n\n`;
} else if (
parsed.type === "content_block_start" &&
parsed.content_block?.type === "tool_use"
) {
yield `data: ${JSON.stringify(parsed)}\n\n`;
} else if (parsed.type === "content_block_stop") {
yield `data: ${JSON.stringify(parsed)}\n\n`;
} else {
// Handle regular text responses
const text = parsed.response || parsed.output || "";
if (text) {
yield `data: ${JSON.stringify({
delta: { text },
})}\n\n`;
}
}
}
// Handle other models
// Handle Titan models
else if (modelId.toLowerCase().includes("titan")) {
if (parsed.outputText) {
yield `data: ${JSON.stringify({
delta: { text: parsed.outputText },
})}\n\n`;
}
if (parsed.completionReason) {
yield `data: ${JSON.stringify({
delta: { stop_reason: parsed.completionReason },
})}\n\n`;
}
}
// Handle other models with basic text output
else {
const text = parsed.outputText || parsed.generation || "";
const text = parsed.response || parsed.output || "";
if (text) {
yield `data: ${JSON.stringify({
delta: { text },