feat(bedrock): Integrate AWS Bedrock as a new LLM provider

Adds support for using models hosted on AWS Bedrock, specifically Anthropic Claude models.

Key changes:
- Added '@aws-sdk/client-bedrock-runtime' dependency.
- Updated constants, server config, and auth logic for Bedrock.
- Implemented a backend API handler (app/api/bedrock/index.ts) that communicates with the Bedrock API, handles streaming and non-streaming responses, and formats output to be OpenAI-compatible.
- Updated the dynamic API router to dispatch requests to the Bedrock handler.
- Created a frontend Bedrock client and wired it into the client factory.
- Updated the environment template with the required Bedrock variables (AWS keys, region, enable flag) and an example of using CUSTOM_MODELS to alias Bedrock models; a minimal configuration sketch follows.
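
For reference, a minimal configuration sketch for enabling Bedrock (all values below are placeholders; the model id must be one your AWS account can actually invoke on Bedrock):

    ENABLE_AWS_BEDROCK=true
    AWS_ACCESS_KEY_ID=<access-key-id>
    AWS_SECRET_ACCESS_KEY=<secret-access-key>
    AWS_REGION=us-east-1
    CUSTOM_MODELS=-all,+us.anthropic.claude-3-5-sonnet-20241022-v2:0@bedrock=sonnet
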
AC 2025-04-06 00:41:56 +08:00
parent 48469bd8ca
commit fc9688a1f7
9 changed files with 443 additions and 19 deletions


@@ -57,7 +57,7 @@ DISABLE_FAST_LINK=
 # (optional)
 # Default: Empty
 # To control custom models, use + to add a custom model, use - to hide a model, use name=displayName to customize model name, separated by comma.
-CUSTOM_MODELS=
+CUSTOM_MODELS=-all,+gpt-4o-2024-11-20@openai=gpt-4o,+gpt-4o-mini@openai,+us.anthropic.claude-3-5-sonnet-20241022-v2:0@bedrock=sonnet
 # (optional)
 # Default: Empty
@@ -81,3 +81,22 @@ SILICONFLOW_API_KEY=
 ### siliconflow Api url (optional)
 SILICONFLOW_URL=
+# --- AWS Bedrock Section ---
+# Ensure these lines for keys either have placeholder values like below,
+# are commented out entirely, or removed if they shouldn't be in the template.
+# AWS Access Key for Bedrock API (Example: Use placeholder or comment out)
+# AWS_ACCESS_KEY_ID=
+# AWS Secret Access Key for Bedrock API (Example: Use placeholder or comment out)
+# AWS_SECRET_ACCESS_KEY=
+# AWS Region for Bedrock API (e.g. us-east-1, us-west-2)
+AWS_REGION=
+# Enable AWS Bedrock models (set to "true" to enable)
+ENABLE_AWS_BEDROCK=
+# Custom endpoint URL for AWS Bedrock (optional)
+AWS_BEDROCK_ENDPOINT=


@@ -14,6 +14,7 @@ import { handle as deepseekHandler } from "../../deepseek";
 import { handle as siliconflowHandler } from "../../siliconflow";
 import { handle as xaiHandler } from "../../xai";
 import { handle as chatglmHandler } from "../../glm";
+import { handle as bedrockHandler } from "../../bedrock";
 import { handle as proxyHandler } from "../../proxy";
 async function handle(
@@ -50,6 +51,8 @@ async function handle(
       return chatglmHandler(req, { params });
     case ApiPath.SiliconFlow:
       return siliconflowHandler(req, { params });
+    case ApiPath.Bedrock:
+      return bedrockHandler(req, { params });
     case ApiPath.OpenAI:
       return openaiHandler(req, { params });
     default:


@@ -56,14 +56,6 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
   // if user does not provide an api key, inject system api key
   if (!apiKey) {
     const serverConfig = getServerSideConfig();
-    // const systemApiKey =
-    //   modelProvider === ModelProvider.GeminiPro
-    //     ? serverConfig.googleApiKey
-    //     : serverConfig.isAzure
-    //       ? serverConfig.azureApiKey
-    //       : serverConfig.apiKey;
     let systemApiKey: string | undefined;
     switch (modelProvider) {
@@ -104,6 +96,11 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
       case ModelProvider.SiliconFlow:
         systemApiKey = serverConfig.siliconFlowApiKey;
         break;
+      case ModelProvider.Bedrock:
+        console.log(
+          "[Auth] Using AWS credentials for Bedrock, no API key override.",
+        );
+        return { error: false };
       case ModelProvider.GPT:
       default:
         if (req.nextUrl.pathname.includes("azure/deployments")) {
@@ -117,7 +114,10 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
       console.log("[Auth] use system api key");
       req.headers.set("Authorization", `Bearer ${systemApiKey}`);
     } else {
-      console.log("[Auth] admin did not provide an api key");
+      console.log(
+        "[Auth] admin did not provide an api key for provider:",
+        modelProvider,
+      );
     }
   } else {
     console.log("[Auth] use user api key");

app/api/bedrock/index.ts (new file, 241 lines)

@@ -0,0 +1,241 @@
import { ModelProvider, Bedrock as BedrockConfig } from "@/app/constant";
import { getServerSideConfig } from "@/app/config/server";
import { prettyObject } from "@/app/utils/format";
import { NextRequest, NextResponse } from "next/server";
import { auth } from "../auth";
import { nanoid } from "nanoid"; // used to generate OpenAI-style completion ids
import {
BedrockRuntimeClient,
InvokeModelWithResponseStreamCommand,
InvokeModelCommand,
} from "@aws-sdk/client-bedrock-runtime";
const ALLOWED_PATH = new Set([BedrockConfig.ChatPath]);
// Helper to get AWS Credentials
function getAwsCredentials() {
const config = getServerSideConfig();
if (!config.isBedrock) {
throw new Error("AWS Bedrock is not configured properly");
}
return {
accessKeyId: config.bedrockAccessKeyId as string,
secretAccessKey: config.bedrockSecretAccessKey as string,
};
}
export async function handle(
req: NextRequest,
{ params }: { params: { path: string[] } },
) {
console.log("[Bedrock Route] params ", params);
if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}
const subpath = params.path.join("/");
if (!ALLOWED_PATH.has(subpath)) {
console.log("[Bedrock Route] forbidden path ", subpath);
return NextResponse.json(
{ error: true, msg: "you are not allowed to request " + subpath },
{ status: 403 },
);
}
// Auth check specifically for Bedrock (might not need header API key)
const authResult = auth(req, ModelProvider.Bedrock);
if (authResult.error) {
return NextResponse.json(authResult, { status: 401 });
}
try {
const config = getServerSideConfig();
if (!config.isBedrock) {
// This check might be redundant due to getAwsCredentials, but good practice
return NextResponse.json(
{ error: true, msg: "AWS Bedrock is not configured properly" },
{ status: 500 },
);
}
const bedrockRegion = config.bedrockRegion as string;
const bedrockEndpoint = config.bedrockEndpoint;
const client = new BedrockRuntimeClient({
region: bedrockRegion,
credentials: getAwsCredentials(),
endpoint: bedrockEndpoint || undefined,
});
const body = await req.json();
console.log("[Bedrock] request body: ", body);
const {
messages,
model,
stream = false,
temperature = 0.7,
max_tokens,
} = body;
// --- Payload formatting for Claude on Bedrock ---
const isClaudeModel = model.includes("anthropic.claude");
if (!isClaudeModel) {
return NextResponse.json(
{ error: true, msg: "Unsupported Bedrock model: " + model },
{ status: 400 },
);
}
const systemPrompts = messages.filter((msg: any) => msg.role === "system");
const userAssistantMessages = messages.filter(
(msg: any) => msg.role !== "system",
);
const payload = {
anthropic_version: "bedrock-2023-05-31",
max_tokens: max_tokens || 4096,
temperature: temperature,
messages: userAssistantMessages.map((msg: any) => ({
role: msg.role, // 'user' or 'assistant'
content:
typeof msg.content === "string"
? [{ type: "text", text: msg.content }]
: msg.content, // Assuming MultimodalContent format is compatible
})),
...(systemPrompts.length > 0 && {
system: systemPrompts.map((msg: any) => msg.content).join("\n"),
}),
};
// --- End Payload Formatting ---
if (stream) {
const command = new InvokeModelWithResponseStreamCommand({
modelId: model,
contentType: "application/json",
accept: "application/json",
body: JSON.stringify(payload),
});
const response = await client.send(command);
if (!response.body) {
throw new Error("Empty response stream from Bedrock");
}
const responseBody = response.body;
const encoder = new TextEncoder();
const decoder = new TextDecoder();
const readableStream = new ReadableStream({
async start(controller) {
try {
for await (const event of responseBody) {
if (event.chunk?.bytes) {
const chunkData = JSON.parse(decoder.decode(event.chunk.bytes));
let responseText = "";
let finishReason: string | null = null;
if (
chunkData.type === "content_block_delta" &&
chunkData.delta.type === "text_delta"
) {
responseText = chunkData.delta.text || "";
} else if (chunkData.type === "message_stop") {
finishReason =
chunkData["amazon-bedrock-invocationMetrics"]
?.outputTokenCount > 0
? "stop"
: "length"; // Example logic
}
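// Other stream event types (e.g. message_start, content_block_start) simply
// fall through here with an empty delta and no finish reason.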
// Format as OpenAI SSE chunk
const sseData = {
id: `chatcmpl-${nanoid()}`,
object: "chat.completion.chunk",
created: Math.floor(Date.now() / 1000),
model: model,
choices: [
{
index: 0,
delta: { content: responseText },
finish_reason: finishReason,
},
],
};
controller.enqueue(
encoder.encode(`data: ${JSON.stringify(sseData)}\n\n`),
);
if (finishReason) {
controller.enqueue(encoder.encode("data: [DONE]\n\n"));
break; // Exit loop after stop message
}
}
}
} catch (error) {
console.error("[Bedrock] Streaming error:", error);
// Signal the error to the stream consumer and stop here;
// closing an already-errored controller would throw.
controller.error(error);
return;
}
controller.close();
},
});
return new NextResponse(readableStream, {
headers: {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
},
});
} else {
// Non-streaming response
const command = new InvokeModelCommand({
modelId: model,
contentType: "application/json",
accept: "application/json",
body: JSON.stringify(payload),
});
const response = await client.send(command);
const responseBody = JSON.parse(new TextDecoder().decode(response.body));
// Format response to match OpenAI
const formattedResponse = {
id: `chatcmpl-${nanoid()}`,
object: "chat.completion",
created: Math.floor(Date.now() / 1000),
model: model,
choices: [
{
index: 0,
message: {
role: "assistant",
content: responseBody.content?.[0]?.text ?? "",
},
finish_reason: "stop", // Assuming stop for non-streamed
},
],
usage: {
prompt_tokens:
responseBody["amazon-bedrock-invocationMetrics"]?.inputTokenCount ??
-1,
completion_tokens:
responseBody["amazon-bedrock-invocationMetrics"]
?.outputTokenCount ?? -1,
total_tokens:
(responseBody["amazon-bedrock-invocationMetrics"]
?.inputTokenCount ?? 0) +
(responseBody["amazon-bedrock-invocationMetrics"]
?.outputTokenCount ?? 0) || -1,
},
};
return NextResponse.json(formattedResponse);
}
} catch (e) {
console.error("[Bedrock] API Handler Error:", e);
return NextResponse.json(prettyObject(e), { status: 500 });
}
}
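
For illustration, each line this handler streams back should look roughly like the following OpenAI-style chunk (id, timestamp, and content are placeholders), followed by a final "data: [DONE]" line once a message_stop event is seen:

    data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1743870000,"model":"us.anthropic.claude-3-5-sonnet-20241022-v2:0","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
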


@@ -24,6 +24,7 @@ import { DeepSeekApi } from "./platforms/deepseek";
 import { XAIApi } from "./platforms/xai";
 import { ChatGLMApi } from "./platforms/glm";
 import { SiliconflowApi } from "./platforms/siliconflow";
+import { BedrockApi } from "./platforms/bedrock";
 export const ROLES = ["system", "user", "assistant"] as const;
 export type MessageRole = (typeof ROLES)[number];
@@ -173,6 +174,9 @@ export class ClientApi {
       case ModelProvider.SiliconFlow:
         this.llm = new SiliconflowApi();
         break;
+      case ModelProvider.Bedrock:
+        this.llm = new BedrockApi();
+        break;
       default:
         this.llm = new ChatGPTApi();
     }
@@ -356,7 +360,7 @@ export function getHeaders(ignoreHeaders: boolean = false) {
   return headers;
 }
-export function getClientApi(provider: ServiceProvider): ClientApi {
+export function getClientApi(provider: ServiceProvider | string): ClientApi {
   switch (provider) {
     case ServiceProvider.Google:
       return new ClientApi(ModelProvider.GeminiPro);
@@ -382,6 +386,9 @@ export function getClientApi(provider: ServiceProvider): ClientApi {
       return new ClientApi(ModelProvider.ChatGLM);
     case ServiceProvider.SiliconFlow:
       return new ClientApi(ModelProvider.SiliconFlow);
+    case ServiceProvider.Bedrock:
+    case "AWS Bedrock":
+      return new ClientApi(ModelProvider.Bedrock);
     default:
       return new ClientApi(ModelProvider.GPT);
   }


@@ -0,0 +1,140 @@
"use client";
import { ApiPath, Bedrock } from "@/app/constant";
import { LLMApi, ChatOptions, LLMModel, LLMUsage, SpeechOptions } from "../api";
import { getHeaders } from "../api";
import { fetch } from "@/app/utils/stream";
export class BedrockApi implements LLMApi {
path(path: string): string {
// Route requests to our backend handler
const apiPath = `${ApiPath.Bedrock}/${path}`;
console.log("[BedrockApi] Constructed API path:", apiPath);
return apiPath;
}
async chat(options: ChatOptions) {
const messages = options.messages;
const modelConfig = options.config;
const controller = new AbortController();
options.onController?.(controller);
try {
const chatPath = this.path(Bedrock.ChatPath);
console.log("[BedrockApi] Requesting path:", chatPath);
const chatPayload = {
method: "POST",
body: JSON.stringify({
model: modelConfig.model,
messages,
temperature: modelConfig.temperature,
stream: !!modelConfig.stream,
max_tokens: 4096, // Example: You might want to make this configurable
}),
signal: controller.signal,
headers: getHeaders(), // getHeaders should handle Bedrock (no auth needed)
};
console.log("[BedrockApi] Request payload (excluding messages):", {
model: modelConfig.model,
temperature: modelConfig.temperature,
stream: !!modelConfig.stream,
});
// Handle stream response
if (modelConfig.stream) {
const response = await fetch(chatPath, chatPayload);
const reader = response.body?.getReader();
const decoder = new TextDecoder();
let messageBuffer = "";
if (!reader) {
throw new Error("Response body reader is not available");
}
while (true) {
// Loop until stream is done
const { done, value } = await reader.read();
if (done) break;
const text = decoder.decode(value, { stream: true }); // Decode chunk
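// NOTE: this parsing assumes each read() returns whole "data:" lines; an SSE
// event split across two chunks would fail JSON.parse below and be skipped.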
const lines = text.split("\n");
for (const line of lines) {
if (!line.startsWith("data:")) continue;
const jsonData = line.substring("data:".length).trim();
if (jsonData === "[DONE]") break; // End of stream
if (!jsonData) continue;
try {
const data = JSON.parse(jsonData);
const content = data.choices?.[0]?.delta?.content ?? "";
const finishReason = data.choices?.[0]?.finish_reason;
if (content) {
messageBuffer += content;
options.onUpdate?.(messageBuffer, content);
}
if (finishReason) {
// Potentially handle finish reason if needed
console.log(
"[BedrockApi] Stream finished with reason:",
finishReason,
);
break; // Exit inner loop on finish signal within a chunk
}
} catch (e) {
console.error(
"[BedrockApi] Error parsing stream chunk:",
jsonData,
e,
);
}
}
}
reader.releaseLock(); // Release reader lock
options.onFinish(messageBuffer, response);
} else {
// Handle non-streaming response
const response = await fetch(chatPath, chatPayload);
if (!response.ok) {
const errorBody = await response.text();
console.error(
"[BedrockApi] Non-stream error response:",
response.status,
errorBody,
);
throw new Error(
`Request failed with status ${response.status}: ${errorBody}`,
);
}
const responseJson = await response.json();
const content = responseJson.choices?.[0]?.message?.content ?? "";
options.onFinish(content, response);
}
} catch (e) {
console.error("[BedrockApi] Chat request failed:", e);
options.onError?.(e as Error);
}
}
async usage(): Promise<LLMUsage> {
// Bedrock usage reporting might require separate implementation if available
return {
used: 0,
total: Number.MAX_SAFE_INTEGER, // Indicate no limit or unknown
};
}
async models(): Promise<LLMModel[]> {
// Fetching models dynamically from Bedrock is complex and usually not needed
// Rely on the hardcoded models in constant.ts
return [];
}
async speech(options: SpeechOptions): Promise<ArrayBuffer> {
// Implement if Bedrock TTS is needed
throw new Error("Speech synthesis not supported for Bedrock yet");
}
}
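
For context, a rough usage sketch of how this client would be obtained and driven through the factory above (imports omitted; the real ChatOptions/config types likely require more fields than shown here):

    const api = getClientApi(ServiceProvider.Bedrock);
    api.llm.chat({
      messages: [{ role: "user", content: "Hello from Bedrock" }],
      config: {
        model: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
        temperature: 0.7,
        stream: true,
      },
      onUpdate: (message, chunk) => console.log(chunk),
      onFinish: (message, res) => console.log("final:", message),
      onError: (err) => console.error(err),
    });
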


@@ -163,19 +163,18 @@ export const getServerSideConfig = () => {
   const isXAI = !!process.env.XAI_API_KEY;
   const isChatGLM = !!process.env.CHATGLM_API_KEY;
   const isSiliconFlow = !!process.env.SILICONFLOW_API_KEY;
-  // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? "";
-  // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim());
-  // const randomIndex = Math.floor(Math.random() * apiKeys.length);
-  // const apiKey = apiKeys[randomIndex];
-  // console.log(
-  //   `[Server Config] using ${randomIndex + 1} of ${apiKeys.length} api key`,
-  // );
+  const isBedrock =
+    process.env.ENABLE_AWS_BEDROCK === "true" &&
+    !!process.env.AWS_ACCESS_KEY_ID &&
+    !!process.env.AWS_SECRET_ACCESS_KEY &&
+    !!process.env.AWS_REGION;
   const allowedWebDavEndpoints = (
     process.env.WHITE_WEBDAV_ENDPOINTS ?? ""
   ).split(",");
-  return {
+  const config = {
     baseUrl: process.env.BASE_URL,
     apiKey: getApiKey(process.env.OPENAI_API_KEY),
     openaiOrgId: process.env.OPENAI_ORG_ID,
@@ -246,6 +245,12 @@ export const getServerSideConfig = () => {
     siliconFlowUrl: process.env.SILICONFLOW_URL,
     siliconFlowApiKey: getApiKey(process.env.SILICONFLOW_API_KEY),
+    isBedrock,
+    bedrockRegion: process.env.AWS_REGION,
+    bedrockAccessKeyId: process.env.AWS_ACCESS_KEY_ID,
+    bedrockSecretAccessKey: process.env.AWS_SECRET_ACCESS_KEY,
+    bedrockEndpoint: process.env.AWS_BEDROCK_ENDPOINT,
     gtmId: process.env.GTM_ID,
     gaId: process.env.GA_ID || DEFAULT_GA_ID,
@@ -266,4 +271,5 @@ export const getServerSideConfig = () => {
     allowedWebDavEndpoints,
     enableMcp: process.env.ENABLE_MCP === "true",
   };
+  return config;
 };


@@ -72,6 +72,7 @@ export enum ApiPath {
   ChatGLM = "/api/chatglm",
   DeepSeek = "/api/deepseek",
   SiliconFlow = "/api/siliconflow",
+  Bedrock = "/api/bedrock",
 }
 export enum SlotID {
@@ -130,6 +131,7 @@ export enum ServiceProvider {
   ChatGLM = "ChatGLM",
   DeepSeek = "DeepSeek",
   SiliconFlow = "SiliconFlow",
+  Bedrock = "Bedrock",
 }
 // Google API safety settings, see https://ai.google.dev/gemini-api/docs/safety-settings
@@ -156,6 +158,7 @@ export enum ModelProvider {
   ChatGLM = "ChatGLM",
   DeepSeek = "DeepSeek",
   SiliconFlow = "SiliconFlow",
+  Bedrock = "Bedrock",
 }
 export const Stability = {
@@ -266,6 +269,10 @@ export const SiliconFlow = {
   ListModelPath: "v1/models?&sub_type=chat",
 };
+export const Bedrock = {
+  ChatPath: "v1/chat/completions",
+};
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
 // export const DEFAULT_SYSTEM_TEMPLATE = `
 // You are ChatGPT, a large language model trained by {{ServiceProvider}}.
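
Taken together with ApiPath.Bedrock above, the frontend client ends up POSTing to /api/bedrock/v1/chat/completions, which is exactly the subpath allowed by ALLOWED_PATH in the new handler.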


@@ -21,6 +21,7 @@
     "test:ci": "node --no-warnings --experimental-vm-modules $(yarn bin jest) --ci"
   },
   "dependencies": {
+    "@aws-sdk/client-bedrock-runtime": "^3.782.0",
     "@fortaine/fetch-event-source": "^3.0.6",
     "@hello-pangea/dnd": "^16.5.0",
     "@modelcontextprotocol/sdk": "^1.0.4",