From fc9688a1f78966652fcad3accb731555c771f8b0 Mon Sep 17 00:00:00 2001
From: AC
Date: Sun, 6 Apr 2025 00:41:56 +0800
Subject: [PATCH] feat(bedrock): Integrate AWS Bedrock as a new LLM provider

Adds support for using models hosted on AWS Bedrock, specifically Anthropic
Claude models.

Key changes:
- Added '@aws-sdk/client-bedrock-runtime' dependency.
- Updated constants, server config, and auth logic for Bedrock.
- Implemented backend API handler (app/api/bedrock/index.ts) to communicate
  with the Bedrock API, handling streaming and non-streaming responses, and
  formatting output to be OpenAI compatible.
- Updated dynamic API router (app/api/[provider]/[...path]/route.ts) to
  dispatch requests to the Bedrock handler.
- Created frontend client (app/client/platforms/bedrock.ts) and updated the
  client factory (app/client/api.ts).
- Updated .env.template with the necessary Bedrock environment variables
  (AWS keys, region, enable flag) and an example of using CUSTOM_MODELS to
  alias Bedrock models.
---
 .env.template                         |  21 ++-
 app/api/[provider]/[...path]/route.ts |   3 +
 app/api/auth.ts                       |  18 +-
 app/api/bedrock/index.ts              | 241 ++++++++++++++++++++++++++
 app/client/api.ts                     |   9 +-
 app/client/platforms/bedrock.ts       | 140 +++++++++++++++
 app/config/server.ts                  |  22 ++-
 app/constant.ts                       |   7 +
 package.json                          |   1 +
 9 files changed, 443 insertions(+), 19 deletions(-)
 create mode 100644 app/api/bedrock/index.ts
 create mode 100644 app/client/platforms/bedrock.ts

diff --git a/.env.template b/.env.template
index 4efaa2ff8..2dd5265a8 100644
--- a/.env.template
+++ b/.env.template
@@ -57,7 +57,7 @@ DISABLE_FAST_LINK=
 # (optional)
 # Default: Empty
 # To control custom models, use + to add a custom model, use - to hide a model, use name=displayName to customize model name, separated by comma.
-CUSTOM_MODELS=
+CUSTOM_MODELS=-all,+gpt-4o-2024-11-20@openai=gpt-4o,+gpt-4o-mini@openai,+us.anthropic.claude-3-5-sonnet-20241022-v2:0@bedrock=sonnet
 
 # (optional)
 # Default: Empty
@@ -81,3 +81,22 @@ SILICONFLOW_API_KEY=
 
 ### siliconflow Api url (optional)
 SILICONFLOW_URL=
+
+# --- AWS Bedrock Section ---
+# Leave these values empty if Bedrock is not used; the provider is only
+# activated when ENABLE_AWS_BEDROCK is "true" and all AWS values are set.
+
+# AWS Access Key for Bedrock API
+AWS_ACCESS_KEY_ID=
+
+# AWS Secret Access Key for Bedrock API
+AWS_SECRET_ACCESS_KEY=
+
+# AWS Region for Bedrock API (e.g. us-east-1, us-west-2)
+AWS_REGION=
+
+# Enable AWS Bedrock models (set to "true" to enable)
+ENABLE_AWS_BEDROCK=
+
+# Custom endpoint URL for AWS Bedrock (optional)
+AWS_BEDROCK_ENDPOINT=
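For reviewers: with the template above, a working local configuration might look like the following sketch. The key values are placeholders, not real credentials, and the model alias mirrors the CUSTOM_MODELS example in the diff:

```env
# .env.local — illustrative values only
ENABLE_AWS_BEDROCK=true
AWS_ACCESS_KEY_ID=AKIAXXXXXXXXXXXXXXXX
AWS_SECRET_ACCESS_KEY=****************************************
AWS_REGION=us-east-1
# Hide all built-in models and expose one Claude model under the alias "sonnet"
CUSTOM_MODELS=-all,+us.anthropic.claude-3-5-sonnet-20241022-v2:0@bedrock=sonnet
```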
diff --git a/app/api/[provider]/[...path]/route.ts b/app/api/[provider]/[...path]/route.ts
index 8975bf971..2c6c8f9ab 100644
--- a/app/api/[provider]/[...path]/route.ts
+++ b/app/api/[provider]/[...path]/route.ts
@@ -14,6 +14,7 @@ import { handle as deepseekHandler } from "../../deepseek";
 import { handle as siliconflowHandler } from "../../siliconflow";
 import { handle as xaiHandler } from "../../xai";
 import { handle as chatglmHandler } from "../../glm";
+import { handle as bedrockHandler } from "../../bedrock";
 import { handle as proxyHandler } from "../../proxy";
 
 async function handle(
@@ -50,6 +51,8 @@ async function handle(
       return chatglmHandler(req, { params });
     case ApiPath.SiliconFlow:
       return siliconflowHandler(req, { params });
+    case ApiPath.Bedrock:
+      return bedrockHandler(req, { params });
     case ApiPath.OpenAI:
       return openaiHandler(req, { params });
     default:
diff --git a/app/api/auth.ts b/app/api/auth.ts
index 8c78c70c8..e5a031f35 100644
--- a/app/api/auth.ts
+++ b/app/api/auth.ts
@@ -56,14 +56,6 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
   // if user does not provide an api key, inject system api key
   if (!apiKey) {
     const serverConfig = getServerSideConfig();
-
-    // const systemApiKey =
-    //   modelProvider === ModelProvider.GeminiPro
-    //     ? serverConfig.googleApiKey
-    //     : serverConfig.isAzure
-    //       ? serverConfig.azureApiKey
-    //       : serverConfig.apiKey;
-
     let systemApiKey: string | undefined;
 
     switch (modelProvider) {
@@ -104,6 +96,11 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
       case ModelProvider.SiliconFlow:
         systemApiKey = serverConfig.siliconFlowApiKey;
         break;
+      case ModelProvider.Bedrock:
+        console.log(
+          "[Auth] Using AWS credentials for Bedrock, no API key override.",
+        );
+        return { error: false };
       case ModelProvider.GPT:
       default:
         if (req.nextUrl.pathname.includes("azure/deployments")) {
@@ -117,7 +114,10 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
       console.log("[Auth] use system api key");
       req.headers.set("Authorization", `Bearer ${systemApiKey}`);
     } else {
-      console.log("[Auth] admin did not provide an api key");
+      console.log(
+        "[Auth] admin did not provide an api key for provider:",
+        modelProvider,
+      );
     }
   } else {
     console.log("[Auth] use user api key");
diff --git a/app/api/bedrock/index.ts b/app/api/bedrock/index.ts
new file mode 100644
index 000000000..c568c886f
--- /dev/null
+++ b/app/api/bedrock/index.ts
@@ -0,0 +1,241 @@
+import { ModelProvider, Bedrock as BedrockConfig } from "@/app/constant";
+import { getServerSideConfig } from "@/app/config/server";
+import { prettyObject } from "@/app/utils/format";
+import { NextRequest, NextResponse } from "next/server";
+import { nanoid } from "nanoid"; // used for unique completion IDs
+import { auth } from "../auth";
+import {
+  BedrockRuntimeClient,
+  InvokeModelWithResponseStreamCommand,
+  InvokeModelCommand,
+} from "@aws-sdk/client-bedrock-runtime";
+
+const ALLOWED_PATH = new Set([BedrockConfig.ChatPath]);
+
+// Helper to get AWS credentials from the server-side config
+function getAwsCredentials() {
+  const config = getServerSideConfig();
+  if (!config.isBedrock) {
+    throw new Error("AWS Bedrock is not configured properly");
+  }
+  return {
+    accessKeyId: config.bedrockAccessKeyId as string,
+    secretAccessKey: config.bedrockSecretAccessKey as string,
+  };
+}
+
+export async function handle(
+  req: NextRequest,
+  { params }: { params: { path: string[] } },
+) {
+  console.log("[Bedrock Route] params ", params);
+
+  if (req.method === "OPTIONS") {
+    return NextResponse.json({ body: "OK" }, { status: 200 });
+  }
+
+  const subpath = params.path.join("/");
+
+  if (!ALLOWED_PATH.has(subpath)) {
+    console.log("[Bedrock Route] forbidden path ", subpath);
+    return NextResponse.json(
+      { error: true, msg: "you are not allowed to request " + subpath },
+      { status: 403 },
+    );
+  }
+
+  // Auth check specifically for Bedrock (no per-request API key needed)
+  const authResult = auth(req, ModelProvider.Bedrock);
+  if (authResult.error) {
+    return NextResponse.json(authResult, { status: 401 });
+  }
+
+  try {
+    const config = getServerSideConfig();
+    if (!config.isBedrock) {
+      // This check might be redundant due to getAwsCredentials, but good practice
+      return NextResponse.json(
+        { error: true, msg: "AWS Bedrock is not configured properly" },
+        { status: 500 },
+      );
+    }
+
+    const bedrockRegion = config.bedrockRegion as string;
+    const bedrockEndpoint = config.bedrockEndpoint;
+
+    const client = new BedrockRuntimeClient({
+      region: bedrockRegion,
+      credentials: getAwsCredentials(),
+      endpoint: bedrockEndpoint || undefined,
+    });
+
+    const body = await req.json();
+    console.log("[Bedrock] request body: ", body);
+
+    const {
+      messages,
+      model,
+      stream = false,
+      temperature = 0.7,
+      max_tokens,
+    } = body;
+
+    // --- Payload formatting for Claude on Bedrock ---
+    const isClaudeModel = model.includes("anthropic.claude");
+    if (!isClaudeModel) {
+      return NextResponse.json(
+        { error: true, msg: "Unsupported Bedrock model: " + model },
+        { status: 400 },
+      );
+    }
+
+    const systemPrompts = messages.filter((msg: any) => msg.role === "system");
+    const userAssistantMessages = messages.filter(
+      (msg: any) => msg.role !== "system",
+    );
+
+    const payload = {
+      anthropic_version: "bedrock-2023-05-31",
+      max_tokens: max_tokens || 4096,
+      temperature: temperature,
+      messages: userAssistantMessages.map((msg: any) => ({
+        role: msg.role, // 'user' or 'assistant'
+        content:
+          typeof msg.content === "string"
+            ? [{ type: "text", text: msg.content }]
+            : msg.content, // Assuming MultimodalContent format is compatible
+      })),
+      ...(systemPrompts.length > 0 && {
+        system: systemPrompts.map((msg: any) => msg.content).join("\n"),
+      }),
+    };
+    // --- End Payload Formatting ---
+
+    if (stream) {
+      const command = new InvokeModelWithResponseStreamCommand({
+        modelId: model,
+        contentType: "application/json",
+        accept: "application/json",
+        body: JSON.stringify(payload),
+      });
+      const response = await client.send(command);
+
+      if (!response.body) {
+        throw new Error("Empty response stream from Bedrock");
+      }
+      const responseBody = response.body;
+
+      const encoder = new TextEncoder();
+      const decoder = new TextDecoder();
+      const readableStream = new ReadableStream({
+        async start(controller) {
+          try {
+            for await (const event of responseBody) {
+              if (event.chunk?.bytes) {
+                const chunkData = JSON.parse(decoder.decode(event.chunk.bytes));
+                let responseText = "";
+                let finishReason: string | null = null;
+
+                if (
+                  chunkData.type === "content_block_delta" &&
+                  chunkData.delta.type === "text_delta"
+                ) {
+                  responseText = chunkData.delta.text || "";
+                } else if (chunkData.type === "message_stop") {
+                  // message_stop always marks the end of the message; the
+                  // detailed stop reason arrives in earlier message_delta
+                  // events, so report a plain "stop" here.
+                  finishReason = "stop";
+                }
"stop" + : "length"; // Example logic + } + + // Format as OpenAI SSE chunk + const sseData = { + id: `chatcmpl-${nanoid()}`, + object: "chat.completion.chunk", + created: Math.floor(Date.now() / 1000), + model: model, + choices: [ + { + index: 0, + delta: { content: responseText }, + finish_reason: finishReason, + }, + ], + }; + controller.enqueue( + encoder.encode(`data: ${JSON.stringify(sseData)}\n\n`), + ); + + if (finishReason) { + controller.enqueue(encoder.encode("data: [DONE]\n\n")); + break; // Exit loop after stop message + } + } + } + } catch (error) { + console.error("[Bedrock] Streaming error:", error); + controller.error(error); + } finally { + controller.close(); + } + }, + }); + + return new NextResponse(readableStream, { + headers: { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }, + }); + } else { + // Non-streaming response + const command = new InvokeModelCommand({ + modelId: model, + contentType: "application/json", + accept: "application/json", + body: JSON.stringify(payload), + }); + const response = await client.send(command); + const responseBody = JSON.parse(new TextDecoder().decode(response.body)); + + // Format response to match OpenAI + const formattedResponse = { + id: `chatcmpl-${nanoid()}`, + object: "chat.completion", + created: Math.floor(Date.now() / 1000), + model: model, + choices: [ + { + index: 0, + message: { + role: "assistant", + content: responseBody.content?.[0]?.text ?? "", + }, + finish_reason: "stop", // Assuming stop for non-streamed + }, + ], + usage: { + prompt_tokens: + responseBody["amazon-bedrock-invocationMetrics"]?.inputTokenCount ?? + -1, + completion_tokens: + responseBody["amazon-bedrock-invocationMetrics"] + ?.outputTokenCount ?? -1, + total_tokens: + (responseBody["amazon-bedrock-invocationMetrics"] + ?.inputTokenCount ?? 0) + + (responseBody["amazon-bedrock-invocationMetrics"] + ?.outputTokenCount ?? 
diff --git a/app/client/api.ts b/app/client/api.ts
index f5288593d..c642ef14b 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -24,6 +24,7 @@ import { DeepSeekApi } from "./platforms/deepseek";
 import { XAIApi } from "./platforms/xai";
 import { ChatGLMApi } from "./platforms/glm";
 import { SiliconflowApi } from "./platforms/siliconflow";
+import { BedrockApi } from "./platforms/bedrock";
 
 export const ROLES = ["system", "user", "assistant"] as const;
 export type MessageRole = (typeof ROLES)[number];
@@ -173,6 +174,9 @@ export class ClientApi {
       case ModelProvider.SiliconFlow:
         this.llm = new SiliconflowApi();
         break;
+      case ModelProvider.Bedrock:
+        this.llm = new BedrockApi();
+        break;
       default:
         this.llm = new ChatGPTApi();
     }
@@ -356,7 +360,7 @@ export function getHeaders(ignoreHeaders: boolean = false) {
   return headers;
 }
 
-export function getClientApi(provider: ServiceProvider): ClientApi {
+export function getClientApi(provider: ServiceProvider | string): ClientApi {
   switch (provider) {
     case ServiceProvider.Google:
       return new ClientApi(ModelProvider.GeminiPro);
@@ -382,6 +386,9 @@ export function getClientApi(provider: ServiceProvider): ClientApi {
       return new ClientApi(ModelProvider.ChatGLM);
     case ServiceProvider.SiliconFlow:
       return new ClientApi(ModelProvider.SiliconFlow);
+    case ServiceProvider.Bedrock:
+    case "AWS Bedrock":
+      return new ClientApi(ModelProvider.Bedrock);
     default:
       return new ClientApi(ModelProvider.GPT);
   }
diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts
new file mode 100644
index 000000000..ebf0b04db
--- /dev/null
+++ b/app/client/platforms/bedrock.ts
@@ -0,0 +1,140 @@
+"use client";
+
+import { ApiPath, Bedrock } from "@/app/constant";
+import {
+  LLMApi,
+  ChatOptions,
+  LLMModel,
+  LLMUsage,
+  SpeechOptions,
+  getHeaders,
+} from "../api";
+import { fetch } from "@/app/utils/stream";
+
+export class BedrockApi implements LLMApi {
+  path(path: string): string {
+    // Route requests to our backend handler
+    const apiPath = `${ApiPath.Bedrock}/${path}`;
+    console.log("[BedrockApi] Constructed API path:", apiPath);
+    return apiPath;
+  }
+
+  async chat(options: ChatOptions) {
+    const messages = options.messages;
+    const modelConfig = options.config;
+
+    const controller = new AbortController();
+    options.onController?.(controller);
+
+    try {
+      const chatPath = this.path(Bedrock.ChatPath);
+      console.log("[BedrockApi] Requesting path:", chatPath);
+
+      const chatPayload = {
+        method: "POST",
+        body: JSON.stringify({
+          model: modelConfig.model,
+          messages,
+          temperature: modelConfig.temperature,
+          stream: !!modelConfig.stream,
+          max_tokens: 4096, // Example: You might want to make this configurable
+        }),
+        signal: controller.signal,
+        headers: getHeaders(), // getHeaders should handle Bedrock (no auth needed)
+      };
+      console.log("[BedrockApi] Request payload (excluding messages):", {
+        model: modelConfig.model,
+        temperature: modelConfig.temperature,
+        stream: !!modelConfig.stream,
+      });
+
+      // Handle stream response
+      if (modelConfig.stream) {
+        const response = await fetch(chatPath, chatPayload);
+        if (!response.ok) {
+          const errorBody = await response.text();
+          throw new Error(
+            `Request failed with status ${response.status}: ${errorBody}`,
+          );
+        }
+        const reader = response.body?.getReader();
+        const decoder = new TextDecoder();
+        let messageBuffer = "";
+
+        if (!reader) {
+          throw new Error("Response body reader is not available");
+        }
+
+        while (true) {
+          // Loop until the stream is done
+          const { done, value } = await reader.read();
+          if (done) break;
+
+          const text = decoder.decode(value, { stream: true }); // Decode chunk
+          const lines = text.split("\n");
+
+          for (const line of lines) {
+            if (!line.startsWith("data:")) continue;
+            const jsonData = line.substring("data:".length).trim();
+            if (jsonData === "[DONE]") break; // End of stream
+            if (!jsonData) continue;
+
+            try {
+              const data = JSON.parse(jsonData);
+              const content = data.choices?.[0]?.delta?.content ?? "";
+              const finishReason = data.choices?.[0]?.finish_reason;
+
+              if (content) {
+                messageBuffer += content;
+                options.onUpdate?.(messageBuffer, content);
+              }
+              if (finishReason) {
+                // Potentially handle the finish reason if needed
+                console.log(
+                  "[BedrockApi] Stream finished with reason:",
+                  finishReason,
+                );
+                break; // Exit inner loop on finish signal within a chunk
+              }
+            } catch (e) {
+              console.error(
+                "[BedrockApi] Error parsing stream chunk:",
+                jsonData,
+                e,
+              );
+            }
+          }
+        }
+        reader.releaseLock(); // Release the reader lock
+        options.onFinish(messageBuffer, response);
+      } else {
+        // Handle non-streaming response
+        const response = await fetch(chatPath, chatPayload);
+        if (!response.ok) {
+          const errorBody = await response.text();
+          console.error(
+            "[BedrockApi] Non-stream error response:",
+            response.status,
+            errorBody,
+          );
+          throw new Error(
+            `Request failed with status ${response.status}: ${errorBody}`,
+          );
+        }
+        const responseJson = await response.json();
+        const content = responseJson.choices?.[0]?.message?.content ?? "";
+        options.onFinish(content, response);
+      }
+    } catch (e) {
+      console.error("[BedrockApi] Chat request failed:", e);
+      options.onError?.(e as Error);
+    }
+  }
+
+  async usage(): Promise<LLMUsage> {
+    // Bedrock usage reporting might require a separate implementation if available
+    return {
+      used: 0,
+      total: Number.MAX_SAFE_INTEGER, // Indicate no limit or unknown
+    };
+  }
+
+  async models(): Promise<LLMModel[]> {
+    // Fetching models dynamically from Bedrock is complex and usually not needed;
+    // rely on the models declared via CUSTOM_MODELS instead.
+    return [];
+  }
+
+  async speech(options: SpeechOptions): Promise<ArrayBuffer> {
+    // Implement if Bedrock TTS is needed
+    throw new Error("Speech synthesis not supported for Bedrock yet");
+  }
+}
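A minimal sketch of how application code reaches the new client through the factory. The callbacks are abridged, and in the real ChatOptions the config object carries the full model config rather than just these fields:

```ts
import { getClientApi } from "@/app/client/api";
import { ServiceProvider } from "@/app/constant";

const api = getClientApi(ServiceProvider.Bedrock); // ClientApi wrapping BedrockApi

api.llm.chat({
  messages: [{ role: "user", content: "Hello from Bedrock" }],
  config: {
    model: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
    temperature: 0.7,
    stream: true,
  },
  onUpdate(message) {
    console.log("partial:", message); // accumulated text so far
  },
  onFinish(message) {
    console.log("final:", message);
  },
  onError(e) {
    console.error("chat failed:", e);
  },
});
```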
diff --git a/app/config/server.ts b/app/config/server.ts
index 43d4ff833..05bd2aa8d 100644
--- a/app/config/server.ts
+++ b/app/config/server.ts
@@ -163,19 +163,18 @@ export const getServerSideConfig = () => {
   const isXAI = !!process.env.XAI_API_KEY;
   const isChatGLM = !!process.env.CHATGLM_API_KEY;
   const isSiliconFlow = !!process.env.SILICONFLOW_API_KEY;
-  // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? "";
-  // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim());
-  // const randomIndex = Math.floor(Math.random() * apiKeys.length);
-  // const apiKey = apiKeys[randomIndex];
-  // console.log(
-  //   `[Server Config] using ${randomIndex + 1} of ${apiKeys.length} api key`,
-  // );
+
+  const isBedrock =
+    process.env.ENABLE_AWS_BEDROCK === "true" &&
+    !!process.env.AWS_ACCESS_KEY_ID &&
+    !!process.env.AWS_SECRET_ACCESS_KEY &&
+    !!process.env.AWS_REGION;
 
   const allowedWebDavEndpoints = (
     process.env.WHITE_WEBDAV_ENDPOINTS ?? ""
   ).split(",");
 
-  return {
+  const config = {
     baseUrl: process.env.BASE_URL,
     apiKey: getApiKey(process.env.OPENAI_API_KEY),
     openaiOrgId: process.env.OPENAI_ORG_ID,
@@ -246,6 +245,12 @@ export const getServerSideConfig = () => {
     siliconFlowUrl: process.env.SILICONFLOW_URL,
     siliconFlowApiKey: getApiKey(process.env.SILICONFLOW_API_KEY),
 
+    isBedrock,
+    bedrockRegion: process.env.AWS_REGION,
+    bedrockAccessKeyId: process.env.AWS_ACCESS_KEY_ID,
+    bedrockSecretAccessKey: process.env.AWS_SECRET_ACCESS_KEY,
+    bedrockEndpoint: process.env.AWS_BEDROCK_ENDPOINT,
+
     gtmId: process.env.GTM_ID,
     gaId: process.env.GA_ID || DEFAULT_GA_ID,
 
@@ -266,4 +271,5 @@ export const getServerSideConfig = () => {
     allowedWebDavEndpoints,
     enableMcp: process.env.ENABLE_MCP === "true",
   };
+  return config;
 };
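Note that `isBedrock` fails closed: the flag only turns on when the enable switch and all three AWS values are present. A quick sketch of the resulting behavior under a hypothetical partial configuration:

```ts
// ENABLE_AWS_BEDROCK=true and both AWS keys set, but AWS_REGION missing:
const config = getServerSideConfig();
console.log(config.isBedrock); // false
// The Bedrock route then returns a 500 ("AWS Bedrock is not configured
// properly") instead of calling AWS with an incomplete credential set.
```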
"" ).split(","); - return { + const config = { baseUrl: process.env.BASE_URL, apiKey: getApiKey(process.env.OPENAI_API_KEY), openaiOrgId: process.env.OPENAI_ORG_ID, @@ -246,6 +245,12 @@ export const getServerSideConfig = () => { siliconFlowUrl: process.env.SILICONFLOW_URL, siliconFlowApiKey: getApiKey(process.env.SILICONFLOW_API_KEY), + isBedrock, + bedrockRegion: process.env.AWS_REGION, + bedrockAccessKeyId: process.env.AWS_ACCESS_KEY_ID, + bedrockSecretAccessKey: process.env.AWS_SECRET_ACCESS_KEY, + bedrockEndpoint: process.env.AWS_BEDROCK_ENDPOINT, + gtmId: process.env.GTM_ID, gaId: process.env.GA_ID || DEFAULT_GA_ID, @@ -266,4 +271,5 @@ export const getServerSideConfig = () => { allowedWebDavEndpoints, enableMcp: process.env.ENABLE_MCP === "true", }; + return config; }; diff --git a/app/constant.ts b/app/constant.ts index c1b135485..e1d9d78e9 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -72,6 +72,7 @@ export enum ApiPath { ChatGLM = "/api/chatglm", DeepSeek = "/api/deepseek", SiliconFlow = "/api/siliconflow", + Bedrock = "/api/bedrock", } export enum SlotID { @@ -130,6 +131,7 @@ export enum ServiceProvider { ChatGLM = "ChatGLM", DeepSeek = "DeepSeek", SiliconFlow = "SiliconFlow", + Bedrock = "Bedrock", } // Google API safety settings, see https://ai.google.dev/gemini-api/docs/safety-settings @@ -156,6 +158,7 @@ export enum ModelProvider { ChatGLM = "ChatGLM", DeepSeek = "DeepSeek", SiliconFlow = "SiliconFlow", + Bedrock = "Bedrock", } export const Stability = { @@ -266,6 +269,10 @@ export const SiliconFlow = { ListModelPath: "v1/models?&sub_type=chat", }; +export const Bedrock = { + ChatPath: "v1/chat/completions", +}; + export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang // export const DEFAULT_SYSTEM_TEMPLATE = ` // You are ChatGPT, a large language model trained by {{ServiceProvider}}. diff --git a/package.json b/package.json index ceb92d7fc..722d1c51c 100644 --- a/package.json +++ b/package.json @@ -21,6 +21,7 @@ "test:ci": "node --no-warnings --experimental-vm-modules $(yarn bin jest) --ci" }, "dependencies": { + "@aws-sdk/client-bedrock-runtime": "^3.782.0", "@fortaine/fetch-event-source": "^3.0.6", "@hello-pangea/dnd": "^16.5.0", "@modelcontextprotocol/sdk": "^1.0.4",