修改: app/api/bedrock.ts

修改:     app/client/platforms/bedrock.ts
This commit is contained in:
glay
2024-11-05 17:28:19 +08:00
parent afbf5eb541
commit 58837f6dec
2 changed files with 193 additions and 535 deletions

View File

@@ -16,6 +16,7 @@ import {
import { getMessageTextContent, isVisionModel } from "../../utils";
import { fetch } from "../../utils/stream";
import { preProcessImageContent, stream } from "../../utils/chat";
import { RequestPayload } from "./openai";
export type MultiBlockContent = {
type: "image" | "text";
@@ -39,12 +40,6 @@ const ClaudeMapper = {
} as const;
export class BedrockApi implements LLMApi {
usage(): Promise<LLMUsage> {
throw new Error("Method not implemented.");
}
models(): Promise<LLMModel[]> {
throw new Error("Method not implemented.");
}
speech(options: SpeechOptions): Promise<ArrayBuffer> {
throw new Error("Speech not implemented for Bedrock.");
}
@@ -149,34 +144,15 @@ export class BedrockApi implements LLMApi {
});
}
const [tools, funcs] = usePluginStore
.getState()
.getAsTools(useChatStore.getState().currentSession().mask?.plugin || []);
const requestBody = {
modelId: options.config.model,
messages: messages.filter((msg) => msg.content.length > 0),
messages: prompt,
inferenceConfig: {
maxTokens: modelConfig.max_tokens,
temperature: modelConfig.temperature,
topP: modelConfig.top_p,
stopSequences: [],
},
toolConfig:
Array.isArray(tools) && tools.length > 0
? {
tools: tools.map((tool: any) => ({
toolSpec: {
name: tool?.function?.name,
description: tool?.function?.description,
inputSchema: {
json: tool?.function?.parameters,
},
},
})),
toolChoice: { auto: {} },
}
: undefined,
};
const conversePath = `${ApiPath.Bedrock}/converse`;
@@ -185,83 +161,80 @@ export class BedrockApi implements LLMApi {
if (shouldStream) {
let currentToolUse: ChatMessageTool | null = null;
let index = -1;
const [tools, funcs] = usePluginStore
.getState()
.getAsTools(
useChatStore.getState().currentSession().mask?.plugin || [],
);
return stream(
conversePath,
requestBody,
getHeaders(),
Array.isArray(tools)
? tools.map((tool: any) => ({
name: tool?.function?.name,
description: tool?.function?.description,
input_schema: tool?.function?.parameters,
}))
: [],
// @ts-ignore
tools.map((tool) => ({
name: tool?.function?.name,
description: tool?.function?.description,
input_schema: tool?.function?.parameters,
})),
funcs,
controller,
// parseSSE
// parseSSE
(text: string, runTools: ChatMessageTool[]) => {
const parsed = JSON.parse(text);
const event = parsed.stream;
// console.log("parseSSE", text, runTools);
let chunkJson:
| undefined
| {
type: "content_block_delta" | "content_block_stop";
content_block?: {
type: "tool_use";
id: string;
name: string;
};
delta?: {
type: "text_delta" | "input_json_delta";
text?: string;
partial_json?: string;
};
index: number;
};
chunkJson = JSON.parse(text);
if (!event) {
console.warn("[Bedrock] Unexpected event format:", parsed);
return "";
}
if (event.messageStart) {
return "";
}
if (event.contentBlockStart?.start?.toolUse) {
const { toolUseId, name } = event.contentBlockStart.start.toolUse;
currentToolUse = {
id: toolUseId,
if (chunkJson?.content_block?.type == "tool_use") {
index += 1;
const id = chunkJson?.content_block.id;
const name = chunkJson?.content_block.name;
runTools.push({
id,
type: "function",
function: {
name,
arguments: "",
},
};
runTools.push(currentToolUse);
return "";
});
}
if (event.contentBlockDelta?.delta?.text) {
return event.contentBlockDelta.delta.text;
}
if (
event.contentBlockDelta?.delta?.toolUse?.input &&
currentToolUse?.function
chunkJson?.delta?.type == "input_json_delta" &&
chunkJson?.delta?.partial_json
) {
currentToolUse.function.arguments +=
event.contentBlockDelta.delta.toolUse.input;
return "";
// @ts-ignore
runTools[index]["function"]["arguments"] +=
chunkJson?.delta?.partial_json;
}
if (
event.internalServerException ||
event.modelStreamErrorException ||
event.validationException ||
event.throttlingException ||
event.serviceUnavailableException
) {
const errorMessage =
event.internalServerException?.message ||
event.modelStreamErrorException?.message ||
event.validationException?.message ||
event.throttlingException?.message ||
event.serviceUnavailableException?.message ||
"Unknown error";
throw new Error(errorMessage);
}
return "";
return chunkJson?.delta?.text;
},
// processToolMessage
(requestPayload: any, toolCallMessage: any, toolCallResult: any[]) => {
currentToolUse = null;
// processToolMessage, include tool_calls message and tool call results
(
requestPayload: RequestPayload,
toolCallMessage: any,
toolCallResult: any[],
) => {
// reset index value
index = -1;
// @ts-ignore
requestPayload?.messages?.splice(
// @ts-ignore
requestPayload?.messages?.length,
0,
{
@@ -277,6 +250,7 @@ export class BedrockApi implements LLMApi {
}),
),
},
// @ts-ignore
...toolCallResult.map((result) => ({
role: "user",
content: [
@@ -292,26 +266,33 @@ export class BedrockApi implements LLMApi {
options,
);
} else {
const payload = {
method: "POST",
body: JSON.stringify(requestBody),
signal: controller.signal,
headers: {
...getHeaders(), // get common headers
},
};
try {
const response = await fetch(conversePath, {
method: "POST",
headers: getHeaders(),
body: JSON.stringify(requestBody),
signal: controller.signal,
});
controller.signal.onabort = () => options.onFinish("");
if (!response.ok) {
const error = await response.text();
throw new Error(`Bedrock API error: ${error}`);
}
const res = await fetch(conversePath, payload);
const resJson = await res.json();
const responseBody = await response.json();
const content = this.extractMessage(responseBody);
options.onFinish(content);
} catch (e: any) {
console.error("[Bedrock] Chat error:", e);
throw e;
const message = this.extractMessage(resJson);
options.onFinish(message);
} catch (e) {
console.error("failed to chat", e);
options.onError?.(e as Error);
}
}
}
usage(): Promise<LLMUsage> {
throw new Error("Method not implemented.");
}
models(): Promise<LLMModel[]> {
throw new Error("Method not implemented.");
}
}