Improve inference support for the Llama and Mistral models

glay
2024-11-24 23:54:04 +08:00
parent 2ccdd1706a
commit 6f7a635030
4 changed files with 204 additions and 89 deletions


@@ -75,7 +75,6 @@ export interface SignParams {
body: string;
service: string;
isStreaming?: boolean;
-  additionalHeaders?: Record<string, string>;
}
function hmac(
@@ -160,7 +159,6 @@ export async function sign({
body,
service,
isStreaming = true,
-  additionalHeaders = {},
}: SignParams): Promise<Record<string, string>> {
try {
const endpoint = new URL(url);
@@ -181,7 +179,6 @@ export async function sign({
host: endpoint.host,
"x-amz-content-sha256": payloadHash,
"x-amz-date": amzDate,
-    ...additionalHeaders,
};
if (isStreaming) {
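
The three removals above all serve one change: the `additionalHeaders` pass-through is dropped from the `SignParams` interface, from `sign()`'s destructured arguments, and from the canonical header set, so per-model headers no longer flow through SigV4 signing. A minimal sketch of the simplified call site; apart from body, service, and isStreaming, the field names below are assumptions, since the rest of `SignParams` sits outside this diff:

// Hedged sketch, not the project's actual call site.
const signedHeaders = await sign({
  method: "POST",
  url: "https://bedrock-runtime.us-east-1.amazonaws.com/model/mistral.mistral-large-2402-v1:0/invoke-with-response-stream",
  region: "us-east-1",
  accessKeyId: process.env.AWS_ACCESS_KEY_ID ?? "",
  secretKey: process.env.AWS_SECRET_ACCESS_KEY ?? "",
  body: JSON.stringify({ prompt: "Hello" }),
  service: "bedrock",
  isStreaming: true,
});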
@@ -311,32 +308,25 @@ export function getBedrockEndpoint(
return endpoint;
}
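
For reference, these are the two Bedrock runtime URL shapes `getBedrockEndpoint` presumably returns (its construction is outside this hunk); region and model id are illustrative:

// Non-streaming invocation:
//   https://bedrock-runtime.us-east-1.amazonaws.com/model/meta.llama3-1-8b-instruct-v1:0/invoke
// Streaming invocation:
//   https://bedrock-runtime.us-east-1.amazonaws.com/model/meta.llama3-1-8b-instruct-v1:0/invoke-with-response-stream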
export function getModelHeaders(modelId: string): Record<string, string> {
if (!modelId) {
throw new Error("Model ID is required for headers");
}
const headers: Record<string, string> = {};
if (
modelId.startsWith("us.meta.llama") ||
modelId.startsWith("mistral.mistral")
) {
headers["content-type"] = "application/json";
headers["accept"] = "application/json";
}
return headers;
}
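
A quick usage sketch of the helper above; the model ids are illustrative but match the two prefixes it tests for:

getModelHeaders("us.meta.llama3-1-70b-instruct-v1:0");
// => { "content-type": "application/json", accept: "application/json" }
getModelHeaders("mistral.mistral-large-2402-v1:0");
// => { "content-type": "application/json", accept: "application/json" }
getModelHeaders("anthropic.claude-3-5-sonnet-20240620-v1:0");
// => {} (other families get no extra headers)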
export function extractMessage(res: any, modelId: string = ""): string {
if (!res) {
console.error("[AWS Extract Error] extractMessage Empty response");
return "";
}
-  console.log("[Response] extractMessage response: ", res);
-  return res?.content?.[0]?.text;
-  return "";
+  // Handle Mistral model response format
+  if (modelId.toLowerCase().includes("mistral")) {
+    return res?.outputs?.[0]?.text || "";
+  }
+  // Handle Llama model response format
+  if (modelId.toLowerCase().includes("llama")) {
+    return res?.generation || "";
+  }
+  // Handle Claude and other models
+  return res?.content?.[0]?.text || "";
}
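
The new branches reflect how the non-streaming response body differs per model family. Illustrative payloads (the accessors come from the code above; the values are invented):

// Mistral: completion text lives under outputs[0].text
extractMessage({ outputs: [{ text: "Hi" }] }, "mistral.mistral-large-2402-v1:0"); // "Hi"
// Llama: completion text lives under generation
extractMessage({ generation: "Hi" }, "us.meta.llama3-1-8b-instruct-v1:0"); // "Hi"
// Claude and the default: first content block's text
extractMessage({ content: [{ text: "Hi" }] }, "anthropic.claude-3-haiku-20240307-v1:0"); // "Hi"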
export async function* transformBedrockStream(
@@ -344,58 +334,105 @@ export async function* transformBedrockStream(
modelId: string,
) {
const reader = stream.getReader();
-  let buffer = "";
+  let accumulatedText = "";
+  let toolCallStarted = false;
+  let currentToolCall = null;
try {
while (true) {
const { done, value } = await reader.read();
-      if (done) {
-        if (buffer) {
-          yield `data: ${JSON.stringify({
-            delta: { text: buffer },
-          })}\n\n`;
-        }
-        break;
-      }
+      if (done) break;
const parsed = parseEventData(value);
if (!parsed) continue;
// Handle Titan models
if (modelId.startsWith("amazon.titan")) {
const text = parsed.outputText || "";
if (text) {
yield `data: ${JSON.stringify({
delta: { text },
})}\n\n`;
}
}
-      // Handle LLaMA3 models
-      else if (modelId.startsWith("us.meta.llama3")) {
-        let text = "";
-        if (parsed.generation) {
-          text = parsed.generation;
-        } else if (parsed.output) {
-          text = parsed.output;
-        } else if (typeof parsed === "string") {
-          text = parsed;
-        }
-        if (text) {
-          // Clean up any control characters or invalid JSON characters
-          text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
-          yield `data: ${JSON.stringify({
-            delta: { text },
-          })}\n\n`;
-        }
-      }
-      // Handle Mistral models
-      else if (modelId.startsWith("mistral.mistral")) {
-        const text =
-          parsed.output || parsed.outputs?.[0]?.text || parsed.completion || "";
-        if (text) {
-          yield `data: ${JSON.stringify({
-            delta: { text },
-          })}\n\n`;
-        }
-      }
+      console.log("parseEventData=========================");
+      console.log(parsed);
+      if (modelId.toLowerCase().includes("mistral")) {
+        // If we have content, accumulate it
+        if (
+          parsed.choices?.[0]?.message?.role === "assistant" &&
+          parsed.choices?.[0]?.message?.content
+        ) {
+          accumulatedText += parsed.choices?.[0]?.message?.content;
+          console.log("accumulatedText=========================");
+          console.log(accumulatedText);
+          // Check for tool call in the accumulated text
+          if (!toolCallStarted && accumulatedText.includes("```json")) {
+            const jsonMatch = accumulatedText.match(
+              /```json\s*({[\s\S]*?})\s*```/,
+            );
+            if (jsonMatch) {
+              try {
+                const toolData = JSON.parse(jsonMatch[1]);
+                currentToolCall = {
+                  id: `tool-${Date.now()}`,
+                  name: toolData.name,
+                  arguments: toolData.arguments,
+                };
+                // Emit tool call start
+                yield `data: ${JSON.stringify({
+                  type: "content_block_start",
+                  content_block: {
+                    type: "tool_use",
+                    id: currentToolCall.id,
+                    name: currentToolCall.name,
+                  },
+                })}\n\n`;
+                // Emit tool arguments
+                yield `data: ${JSON.stringify({
+                  type: "content_block_delta",
+                  delta: {
+                    type: "input_json_delta",
+                    partial_json: JSON.stringify(currentToolCall.arguments),
+                  },
+                })}\n\n`;
+                // Emit tool call stop
+                yield `data: ${JSON.stringify({
+                  type: "content_block_stop",
+                })}\n\n`;
+                // Clear the accumulated text after processing the tool call
+                accumulatedText = accumulatedText.replace(
+                  /```json\s*{[\s\S]*?}\s*```/,
+                  "",
+                );
+                toolCallStarted = false;
+                currentToolCall = null;
+              } catch (e) {
+                console.error("Failed to parse tool JSON:", e);
+              }
+            }
+          }
+          // emit the text content if it's not empty
+          if (parsed.choices?.[0]?.message?.content.trim()) {
+            yield `data: ${JSON.stringify({
+              delta: { text: parsed.choices?.[0]?.message?.content },
+            })}\n\n`;
+          }
+          // Handle stop reason if present
+          if (parsed.choices?.[0]?.stop_reason) {
+            yield `data: ${JSON.stringify({
+              delta: { stop_reason: parsed.choices[0].stop_reason },
+            })}\n\n`;
+          }
+        }
+      }
+      // Handle Llama models
+      else if (modelId.toLowerCase().includes("llama")) {
+        if (parsed.generation) {
+          yield `data: ${JSON.stringify({
+            delta: { text: parsed.generation },
+          })}\n\n`;
+        }
+        if (parsed.stop_reason) {
+          yield `data: ${JSON.stringify({
+            delta: { stop_reason: parsed.stop_reason },
+          })}\n\n`;
+        }
+      }
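
The Mistral branch above treats tool calls as fenced ```json blocks embedded in the assistant text (presumably requested by a prompt elsewhere in this change) and re-emits them as Anthropic-style content_block events. A standalone sketch of just the extraction step, with invented sample text:

const accumulated =
  'Let me check.\n```json\n{"name": "web_search", "arguments": {"query": "weather"}}\n```';
const jsonMatch = accumulated.match(/```json\s*({[\s\S]*?})\s*```/);
if (jsonMatch) {
  const toolData = JSON.parse(jsonMatch[1]);
  console.log(toolData.name); // "web_search"
  console.log(toolData.arguments); // { query: "weather" }
}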
@@ -423,6 +460,22 @@ export async function* transformBedrockStream(
yield `data: ${JSON.stringify(parsed)}\n\n`;
} else if (parsed.type === "content_block_stop") {
yield `data: ${JSON.stringify(parsed)}\n\n`;
+        } else {
+          // Handle regular text responses
+          const text = parsed.response || parsed.output || "";
+          if (text) {
+            yield `data: ${JSON.stringify({
+              delta: { text },
+            })}\n\n`;
+          }
+        }
+      } else {
+        // Handle other model text responses
+        const text = parsed.outputText || parsed.generation || "";
+        if (text) {
+          yield `data: ${JSON.stringify({
+            delta: { text },
+          })}\n\n`;
+        }
}
}
}
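
Whichever branch fires, the generator yields uniform `data: <json>` frames whose payload carries delta.text, delta.stop_reason, or an Anthropic-style content_block event, so a single consumer can read every model family. A minimal consumer sketch; treating `res` as a fetch() Response holding the Bedrock event stream is an assumption for the example:

const modelId = "mistral.mistral-large-2402-v1:0";
for await (const frame of transformBedrockStream(res.body!, modelId)) {
  const payload = JSON.parse(frame.slice("data: ".length));
  if (payload.delta?.text) process.stdout.write(payload.delta.text);
  if (payload.delta?.stop_reason) console.log("\n[stop]", payload.delta.stop_reason);
}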