mirror of
https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web.git
synced 2025-11-14 05:03:43 +08:00
Improve inference support for the Llama and Mistral models
183
app/utils/aws.ts
@@ -75,7 +75,6 @@ export interface SignParams {
  body: string;
  service: string;
  isStreaming?: boolean;
  additionalHeaders?: Record<string, string>;
}

function hmac(
@@ -160,7 +159,6 @@ export async function sign({
  body,
  service,
  isStreaming = true,
  additionalHeaders = {},
}: SignParams): Promise<Record<string, string>> {
  try {
    const endpoint = new URL(url);
@@ -181,7 +179,6 @@ export async function sign({
      host: endpoint.host,
      "x-amz-content-sha256": payloadHash,
      "x-amz-date": amzDate,
      ...additionalHeaders,
    };

    if (isStreaming) {
@@ -311,32 +308,25 @@ export function getBedrockEndpoint(
  return endpoint;
}

export function getModelHeaders(modelId: string): Record<string, string> {
  if (!modelId) {
    throw new Error("Model ID is required for headers");
  }

  const headers: Record<string, string> = {};

  if (
    modelId.startsWith("us.meta.llama") ||
    modelId.startsWith("mistral.mistral")
  ) {
    headers["content-type"] = "application/json";
    headers["accept"] = "application/json";
  }

  return headers;
}

export function extractMessage(res: any, modelId: string = ""): string {
  if (!res) {
    console.error("[AWS Extract Error] extractMessage Empty response");
    return "";
  }
  console.log("[Response] extractMessage response: ", res);

  // Handle Mistral model response format
  if (modelId.toLowerCase().includes("mistral")) {
    return res?.outputs?.[0]?.text || "";
  }

  // Handle Llama model response format
  if (modelId.toLowerCase().includes("llama")) {
    return res?.generation || "";
  }

  // Handle Claude and other models
  return res?.content?.[0]?.text || "";
}

export async function* transformBedrockStream(
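For context, a minimal sketch of how the extractMessage dispatch above behaves; the model IDs and payloads are illustrative, and the import path assumes a caller sitting next to app/utils/aws.ts:

import { extractMessage } from "./aws";

// Mistral responses carry text under outputs[0].text
extractMessage(
  { outputs: [{ text: "Bonjour" }] },
  "mistral.mistral-large-2402-v1:0", // illustrative model ID
); // -> "Bonjour"

// Llama responses carry text under generation
extractMessage(
  { generation: "Hello" },
  "us.meta.llama3-1-70b-instruct-v1:0", // illustrative model ID
); // -> "Hello"

// Claude and other models fall through to content[0].text
extractMessage(
  { content: [{ text: "Hi" }] },
  "anthropic.claude-3-sonnet", // illustrative model ID
); // -> "Hi"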
@@ -344,58 +334,105 @@ export async function* transformBedrockStream(
  modelId: string,
) {
  const reader = stream.getReader();
  let buffer = "";
  let accumulatedText = "";
  let toolCallStarted = false;
  let currentToolCall = null;

  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) {
        if (buffer) {
          yield `data: ${JSON.stringify({
            delta: { text: buffer },
          })}\n\n`;
        }
        break;
      }

      const parsed = parseEventData(value);
      if (!parsed) continue;
      console.log("[Bedrock] parseEventData:", parsed);

      // Handle Titan models
      if (modelId.startsWith("amazon.titan")) {
        const text = parsed.outputText || "";
        if (text) {
          yield `data: ${JSON.stringify({
            delta: { text },
          })}\n\n`;
        }
      }
      // Handle Llama3 models
      else if (modelId.startsWith("us.meta.llama3")) {
        let text = "";
        if (parsed.generation) {
          text = parsed.generation;
        } else if (parsed.output) {
          text = parsed.output;
        } else if (typeof parsed === "string") {
          text = parsed;
        }

        if (text) {
          // Clean up any control characters or invalid JSON characters
          text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
          yield `data: ${JSON.stringify({
            delta: { text },
          })}\n\n`;
        }
      }
      // Handle Mistral models
      else if (modelId.startsWith("mistral.mistral")) {
        // If we have content, accumulate it
        if (
          parsed.choices?.[0]?.message?.role === "assistant" &&
          parsed.choices?.[0]?.message?.content
        ) {
          accumulatedText += parsed.choices?.[0]?.message?.content;
          console.log("[Bedrock] accumulatedText:", accumulatedText);

          // Check for a tool call in the accumulated text
          if (!toolCallStarted && accumulatedText.includes("```json")) {
            const jsonMatch = accumulatedText.match(
              /```json\s*({[\s\S]*?})\s*```/,
            );
            if (jsonMatch) {
              try {
                const toolData = JSON.parse(jsonMatch[1]);
                currentToolCall = {
                  id: `tool-${Date.now()}`,
                  name: toolData.name,
                  arguments: toolData.arguments,
                };

                // Emit tool call start
                yield `data: ${JSON.stringify({
                  type: "content_block_start",
                  content_block: {
                    type: "tool_use",
                    id: currentToolCall.id,
                    name: currentToolCall.name,
                  },
                })}\n\n`;

                // Emit tool arguments
                yield `data: ${JSON.stringify({
                  type: "content_block_delta",
                  delta: {
                    type: "input_json_delta",
                    partial_json: JSON.stringify(currentToolCall.arguments),
                  },
                })}\n\n`;

                // Emit tool call stop
                yield `data: ${JSON.stringify({
                  type: "content_block_stop",
                })}\n\n`;

                // Clear the accumulated text after processing the tool call
                accumulatedText = accumulatedText.replace(
                  /```json\s*{[\s\S]*?}\s*```/,
                  "",
                );
                toolCallStarted = false;
                currentToolCall = null;
              } catch (e) {
                console.error("Failed to parse tool JSON:", e);
              }
            }
          }

          // Emit the text content if it is not empty
          if (parsed.choices?.[0]?.message?.content.trim()) {
            yield `data: ${JSON.stringify({
              delta: { text: parsed.choices?.[0]?.message?.content },
            })}\n\n`;
          }

          // Handle the stop reason if present
          if (parsed.choices?.[0]?.stop_reason) {
            yield `data: ${JSON.stringify({
              delta: { stop_reason: parsed.choices[0].stop_reason },
            })}\n\n`;
          }
        }
      }
      // Handle other Llama models
      else if (modelId.toLowerCase().includes("llama")) {
        if (parsed.generation) {
          yield `data: ${JSON.stringify({
            delta: { text: parsed.generation },
          })}\n\n`;
        }
        if (parsed.stop_reason) {
          yield `data: ${JSON.stringify({
            delta: { stop_reason: parsed.stop_reason },
          })}\n\n`;
        }
      }
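The Mistral branch above watches the accumulated assistant text for a fenced ```json block and reinterprets it as a tool call. A standalone sketch of that extraction step, using the same regex with a made-up payload:

// Hypothetical assistant output that embeds a tool call as a ```json block
const accumulatedText =
  "Let me check that.\n```json\n{ \"name\": \"get_weather\", \"arguments\": { \"city\": \"Paris\" } }\n```";

// Same pattern the stream transformer uses
const jsonMatch = accumulatedText.match(/```json\s*({[\s\S]*?})\s*```/);
if (jsonMatch) {
  const toolData = JSON.parse(jsonMatch[1]);
  console.log(toolData.name);      // "get_weather"
  console.log(toolData.arguments); // { city: "Paris" }
}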
@@ -423,6 +460,22 @@ export async function* transformBedrockStream(
        yield `data: ${JSON.stringify(parsed)}\n\n`;
      } else if (parsed.type === "content_block_stop") {
        yield `data: ${JSON.stringify(parsed)}\n\n`;
      } else {
        // Handle regular text responses
        const text = parsed.response || parsed.output || "";
        if (text) {
          yield `data: ${JSON.stringify({
            delta: { text },
          })}\n\n`;
        }
      }
    } else {
      // Handle other model text responses
      const text = parsed.outputText || parsed.generation || "";
      if (text) {
        yield `data: ${JSON.stringify({
          delta: { text },
        })}\n\n`;
      }
    }
  }
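For orientation, a minimal sketch of consuming the generator end to end, assuming the Bedrock response body is exposed as a web ReadableStream; the helper name and handler shape are illustrative, not part of this commit:

import { transformBedrockStream } from "./aws";

// Hypothetical route helper: re-emit the transformed chunks as an SSE response
async function streamBedrockToClient(
  bedrockBody: ReadableStream,
  modelId: string,
): Promise<Response> {
  const encoder = new TextEncoder();
  const sse = new ReadableStream({
    async start(controller) {
      // Each yielded chunk is already a complete "data: ..." SSE line
      for await (const chunk of transformBedrockStream(bedrockBody, modelId)) {
        controller.enqueue(encoder.encode(chunk));
      }
      controller.close();
    },
  });
  return new Response(sse, {
    headers: { "content-type": "text/event-stream" },
  });
}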