diff --git a/.env.template b/.env.template
index b2a0438d9..bed71a9c3 100644
--- a/.env.template
+++ b/.env.template
@@ -54,10 +54,18 @@ ANTHROPIC_API_KEY=
 ### anthropic claude Api version. (optional)
 ANTHROPIC_API_VERSION=
-
-
 ### anthropic claude Api url (optional)
 ANTHROPIC_URL=
 
+### AWS Bedrock API key (optional)
+AWS_API_KEY=
+
+### AWS Bedrock API url (optional)
+AWS_URL=
+
+### AWS Bedrock API version (optional)
+AWS_API_VERSION=
+
+
 ### (optional)
 WHITE_WEBDEV_ENDPOINTS=
\ No newline at end of file
diff --git a/app/api/auth.ts b/app/api/auth.ts
index b750f2d17..f61e3f81f 100644
--- a/app/api/auth.ts
+++ b/app/api/auth.ts
@@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
     case ModelProvider.Claude:
       systemApiKey = serverConfig.anthropicApiKey;
       break;
+    case ModelProvider.Bedrock:
+      systemApiKey = serverConfig.awsApiKey;
+      break;
     case ModelProvider.GPT:
     default:
       if (serverConfig.isAzure) {
diff --git a/app/client/api.ts b/app/client/api.ts
index 7bee546b4..72dbd516e 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -9,6 +9,7 @@ import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
 import { ChatGPTApi } from "./platforms/openai";
 import { GeminiProApi } from "./platforms/google";
 import { ClaudeApi } from "./platforms/anthropic";
+import { BedrockApi } from "./platforms/aws";
 
 export const ROLES = ["system", "user", "assistant"] as const;
 export type MessageRole = (typeof ROLES)[number];
@@ -70,7 +71,7 @@ export abstract class LLMApi {
   abstract models(): Promise<LLMModel[]>;
 }
 
-type ProviderName = "openai" | "azure" | "claude" | "palm";
+type ProviderName = "aws" | "openai" | "azure" | "claude" | "palm";
 
 interface Model {
   name: string;
@@ -102,6 +103,9 @@ export class ClientApi {
       case ModelProvider.Claude:
         this.llm = new ClaudeApi();
         break;
+      case ModelProvider.Bedrock:
+        this.llm = new BedrockApi();
+        break;
       default:
         this.llm = new ChatGPTApi();
     }
@@ -162,11 +166,15 @@ export function getHeaders() {
   const modelConfig = useChatStore.getState().currentSession().mask.modelConfig;
   const isGoogle = modelConfig.model.startsWith("gemini");
   const isAzure = accessStore.provider === ServiceProvider.Azure;
+  const isAWS = accessStore.provider === ServiceProvider.AWS;
   const authHeader = isAzure ? "api-key" : "Authorization";
+
   const apiKey = isGoogle
     ? accessStore.googleApiKey
     : isAzure
     ? accessStore.azureApiKey
+    : isAWS
+    ? accessStore.awsApiKey
     : accessStore.openaiApiKey;
   const clientConfig = getClientConfig();
   const makeBearer = (s: string) => `${isAzure ? "" : "Bearer "}${s.trim()}`;
diff --git a/app/client/platforms/aws.ts b/app/client/platforms/aws.ts
new file mode 100644
index 000000000..118c9e527
--- /dev/null
+++ b/app/client/platforms/aws.ts
@@ -0,0 +1,330 @@
+"use client";
+import { DEFAULT_MODELS, AWS, REQUEST_TIMEOUT_MS } from "@/app/constant";
+import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+
+import {
+  ChatOptions,
+  getHeaders,
+  LLMApi,
+  LLMModel,
+  LLMUsage,
+  MultimodalContent,
+} from "../api";
+import Locale from "../../locales";
+import {
+  EventStreamContentType,
+  fetchEventSource,
+} from "@fortaine/fetch-event-source";
+import { prettyObject } from "@/app/utils/format";
+import { getClientConfig } from "@/app/config/client";
+import {
+  getMessageTextContent,
+  getMessageImages,
+  isVisionModel,
+} from "@/app/utils";
+
+export interface OpenAIListModelResponse {
+  object: string;
+  data: Array<{
+    id: string;
+    object: string;
+    root: string;
+  }>;
+}
+
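+// NOTE: this client assumes an OpenAI-compatible gateway sitting in front of
+// AWS Bedrock (see AWS.ExampleEndpoint in app/constant.ts); the request and
+// response shapes below follow the OpenAI chat API, not the native Bedrock
+// InvokeModel API.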
"" : "Bearer "}${s.trim()}`; diff --git a/app/client/platforms/aws.ts b/app/client/platforms/aws.ts new file mode 100644 index 000000000..118c9e527 --- /dev/null +++ b/app/client/platforms/aws.ts @@ -0,0 +1,330 @@ +"use client"; +import { DEFAULT_MODELS, AWS, REQUEST_TIMEOUT_MS } from "@/app/constant"; +import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; + +import { + ChatOptions, + getHeaders, + LLMApi, + LLMModel, + LLMUsage, + MultimodalContent, +} from "../api"; +import Locale from "../../locales"; +import { + EventStreamContentType, + fetchEventSource, +} from "@fortaine/fetch-event-source"; +import { prettyObject } from "@/app/utils/format"; +import { getClientConfig } from "@/app/config/client"; +import { + getMessageTextContent, + getMessageImages, + isVisionModel, +} from "@/app/utils"; + +export interface OpenAIListModelResponse { + object: string; + data: Array<{ + id: string; + object: string; + root: string; + }>; +} + +export class BedrockApi implements LLMApi { + private disableListModels = true; + + path(path: string): string { + const accessStore = useAccessStore.getState(); + + if (!accessStore.awsUrl) { + throw Error("Please set your access url."); + } + let baseUrl = accessStore.awsUrl; + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, baseUrl.length - 1); + } + console.log("[Proxy Endpoint] ", baseUrl, path); + return [baseUrl, path].join("/"); + } + + extractMessage(res: any) { + return res.choices?.at(0)?.message?.content ?? ""; + } + + async chat(options: ChatOptions) { + const visionModel = isVisionModel(options.config.model); + const messages = options.messages.map((v) => ({ + role: v.role, + content: visionModel ? v.content : getMessageTextContent(v), + })); + + const modelConfig = { + ...useAppConfig.getState().modelConfig, + ...useChatStore.getState().currentSession().mask.modelConfig, + ...{ + model: options.config.model, + }, + }; + + const requestPayload = { + messages, + stream: options.config.stream, + model: modelConfig.model, + temperature: modelConfig.temperature, + presence_penalty: modelConfig.presence_penalty, + frequency_penalty: modelConfig.frequency_penalty, + top_p: modelConfig.top_p, + max_tokens: modelConfig.max_tokens, + }; + + // add max_tokens to vision model + if (visionModel) { + Object.defineProperty(requestPayload, "max_tokens", { + enumerable: true, + configurable: true, + writable: true, + value: modelConfig.max_tokens, + }); + } + + console.log("[Request] aws bedrock payload: ", requestPayload); + + const shouldStream = !!options.config.stream; + const controller = new AbortController(); + options.onController?.(controller); + + try { + const chatPath = this.path(AWS.ChatPath); + const chatPayload = { + method: "POST", + body: JSON.stringify(requestPayload), + signal: controller.signal, + headers: getHeaders(), + }; + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + if (shouldStream) { + let responseText = ""; + let remainText = ""; + let finished = false; + + // animate response to make it looks smooth + function animateResponseText() { + if (finished || controller.signal.aborted) { + responseText += remainText; + console.log("[Response Animation] finished"); + return; + } + + if (remainText.length > 0) { + const fetchCount = Math.max(1, Math.round(remainText.length / 60)); + const fetchText = remainText.slice(0, fetchCount); + responseText += fetchText; + remainText = remainText.slice(fetchCount); + 
+    const requestPayload = {
+      messages,
+      stream: options.config.stream,
+      model: modelConfig.model,
+      temperature: modelConfig.temperature,
+      presence_penalty: modelConfig.presence_penalty,
+      frequency_penalty: modelConfig.frequency_penalty,
+      top_p: modelConfig.top_p,
+      max_tokens: modelConfig.max_tokens,
+    };
+
+    // add max_tokens to vision model
+    if (visionModel) {
+      Object.defineProperty(requestPayload, "max_tokens", {
+        enumerable: true,
+        configurable: true,
+        writable: true,
+        value: modelConfig.max_tokens,
+      });
+    }
+
+    console.log("[Request] aws bedrock payload: ", requestPayload);
+
+    const shouldStream = !!options.config.stream;
+    const controller = new AbortController();
+    options.onController?.(controller);
+
+    try {
+      const chatPath = this.path(AWS.ChatPath);
+      const chatPayload = {
+        method: "POST",
+        body: JSON.stringify(requestPayload),
+        signal: controller.signal,
+        headers: getHeaders(),
+      };
+
+      // make a fetch request
+      const requestTimeoutId = setTimeout(
+        () => controller.abort(),
+        REQUEST_TIMEOUT_MS,
+      );
+
+      if (shouldStream) {
+        let responseText = "";
+        let remainText = "";
+        let finished = false;
+
+        // animate response to make it look smooth
+        function animateResponseText() {
+          if (finished || controller.signal.aborted) {
+            responseText += remainText;
+            console.log("[Response Animation] finished");
+            return;
+          }
+
+          if (remainText.length > 0) {
+            const fetchCount = Math.max(1, Math.round(remainText.length / 60));
+            const fetchText = remainText.slice(0, fetchCount);
+            responseText += fetchText;
+            remainText = remainText.slice(fetchCount);
+            options.onUpdate?.(responseText, fetchText);
+          }
+
+          requestAnimationFrame(animateResponseText);
+        }
+
+        // start animation
+        animateResponseText();
+
+        const finish = () => {
+          if (!finished) {
+            finished = true;
+            options.onFinish(responseText + remainText);
+          }
+        };
+
+        controller.signal.onabort = finish;
+
+        fetchEventSource(chatPath, {
+          ...chatPayload,
+          async onopen(res) {
+            clearTimeout(requestTimeoutId);
+            const contentType = res.headers.get("content-type");
+            console.log("[AWS] request response content type: ", contentType);
+
+            if (contentType?.startsWith("text/plain")) {
+              responseText = await res.clone().text();
+              return finish();
+            }
+
+            if (
+              !res.ok ||
+              !res.headers
+                .get("content-type")
+                ?.startsWith(EventStreamContentType) ||
+              res.status !== 200
+            ) {
+              const responseTexts = [responseText];
+              let extraInfo = await res.clone().text();
+              try {
+                const resJson = await res.clone().json();
+                extraInfo = prettyObject(resJson);
+              } catch {}
+
+              if (res.status === 401) {
+                responseTexts.push(Locale.Error.Unauthorized);
+              }
+
+              if (extraInfo) {
+                responseTexts.push(extraInfo);
+              }
+
+              responseText = responseTexts.join("\n\n");
+
+              return finish();
+            }
+          },
+          onmessage(msg) {
+            if (msg.data === "[DONE]" || finished) {
+              return finish();
+            }
+            const text = msg.data;
+            try {
+              const json = JSON.parse(text) as {
+                choices: Array<{
+                  delta: {
+                    content: string;
+                  };
+                }>;
+              };
+              const delta = json.choices[0]?.delta?.content;
+              if (delta) {
+                remainText += delta;
+              }
+            } catch (e) {
+              console.error("[Request] parse error", text);
+            }
+          },
+          onclose() {
+            finish();
+          },
+          onerror(e) {
+            options.onError?.(e);
+            throw e;
+          },
+          openWhenHidden: true,
+        });
+      } else {
+        const res = await fetch(chatPath, chatPayload);
+        clearTimeout(requestTimeoutId);
+
+        const resJson = await res.json();
+        const message = this.extractMessage(resJson);
+        options.onFinish(message);
+      }
+    } catch (e) {
+      console.log("[Request] failed to make a chat request", e);
+      options.onError?.(e as Error);
+    }
+  }
+
+  async usage() {
+    const formatDate = (d: Date) =>
+      `${d.getFullYear()}-${(d.getMonth() + 1).toString().padStart(2, "0")}-${d
+        .getDate()
+        .toString()
+        .padStart(2, "0")}`;
+    const ONE_DAY = 1 * 24 * 60 * 60 * 1000;
+    const now = new Date();
+    const startOfMonth = new Date(now.getFullYear(), now.getMonth(), 1);
+    const startDate = formatDate(startOfMonth);
+    const endDate = formatDate(new Date(Date.now() + ONE_DAY));
+
+    const [used, subs] = await Promise.all([
+      fetch(
+        this.path(
+          `${AWS.UsagePath}?start_date=${startDate}&end_date=${endDate}`,
+        ),
+        {
+          method: "GET",
+          headers: getHeaders(),
+        },
+      ),
+      fetch(this.path(AWS.SubsPath), {
+        method: "GET",
+        headers: getHeaders(),
+      }),
+    ]);
+
+    if (used.status === 401) {
+      throw new Error(Locale.Error.Unauthorized);
+    }
+
+    if (!used.ok || !subs.ok) {
+      throw new Error("Failed to query usage from the proxy");
+    }
+
+    const response = (await used.json()) as {
+      total_usage?: number;
+      error?: {
+        type: string;
+        message: string;
+      };
+    };
+
+    const total = (await subs.json()) as {
+      hard_limit_usd?: number;
+    };
+
+    if (response.error && response.error.type) {
+      throw Error(response.error.message);
+    }
+
+    if (response.total_usage) {
+      response.total_usage = Math.round(response.total_usage) / 100;
+    }
+
+    if (total.hard_limit_usd) {
+      total.hard_limit_usd = Math.round(total.hard_limit_usd * 100) / 100;
+    }
+
+    return {
+      used: response.total_usage,
+      total: total.hard_limit_usd,
+    } as LLMUsage;
+  }
+
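+  // Remote model listing is disabled by default; DEFAULT_MODELS (which now
+  // includes the bedrockModels added in app/constant.ts) is returned instead.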
+  async models(): Promise<LLMModel[]> {
+    if (this.disableListModels) {
+      return DEFAULT_MODELS.slice();
+    }
+
+    const res = await fetch(this.path(AWS.ListModelPath), {
+      method: "GET",
+      headers: {
+        ...getHeaders(),
+      },
+    });
+
+    const resJson = (await res.json()) as OpenAIListModelResponse;
+    const chatModels = resJson.data;
+    console.log("[Models]", chatModels);
+
+    if (!chatModels) {
+      return [];
+    }
+
+    return chatModels.map((m) => ({
+      name: m.id,
+      displayName: m.id,
+      available: true,
+      provider: {
+        id: "bedrock",
+        providerName: "Bedrock",
+        providerType: "bedrock",
+      },
+    }));
+  }
+}
+export { AWS };
diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx
index 20e240d93..07c8ec277 100644
--- a/app/components/exporter.tsx
+++ b/app/components/exporter.tsx
@@ -317,7 +317,8 @@ export function PreviewActions(props: {
     if (config.modelConfig.model.startsWith("gemini")) {
       api = new ClientApi(ModelProvider.GeminiPro);
     } else if (identifyDefaultClaudeModel(config.modelConfig.model)) {
-      api = new ClientApi(ModelProvider.Claude);
+      //api = new ClientApi(ModelProvider.Claude);
+      api = new ClientApi(ModelProvider.Bedrock);
     } else {
       api = new ClientApi(ModelProvider.GPT);
     }
diff --git a/app/components/home.tsx b/app/components/home.tsx
index ffac64fda..dbf8acca6 100644
--- a/app/components/home.tsx
+++ b/app/components/home.tsx
@@ -175,7 +175,8 @@ export function useLoadData() {
     if (config.modelConfig.model.startsWith("gemini")) {
       api = new ClientApi(ModelProvider.GeminiPro);
     } else if (identifyDefaultClaudeModel(config.modelConfig.model)) {
-      api = new ClientApi(ModelProvider.Claude);
+      //api = new ClientApi(ModelProvider.Claude);
+      api = new ClientApi(ModelProvider.Bedrock);
     } else {
       api = new ClientApi(ModelProvider.GPT);
     }
diff --git a/app/components/settings.tsx b/app/components/settings.tsx
index db08b48a9..c385613a4 100644
--- a/app/components/settings.tsx
+++ b/app/components/settings.tsx
@@ -1187,6 +1187,46 @@ export function Settings() {
                   </ListItem>
                 </>
               )}
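+              {/* AWS (Bedrock proxy) settings, mirroring the Anthropic
+                  section above: an OpenAI-compatible endpoint plus key */}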
+              {accessStore.provider === ServiceProvider.AWS && (
+                <>
+                  <ListItem
+                    title={Locale.Settings.Access.AWS.Endpoint.Title}
+                    subTitle={
+                      Locale.Settings.Access.AWS.Endpoint.SubTitle +
+                      AWS.ExampleEndpoint
+                    }
+                  >
+                    <input
+                      type="text"
+                      value={accessStore.awsUrl}
+                      placeholder={AWS.ExampleEndpoint}
+                      onChange={(e) =>
+                        accessStore.update(
+                          (access) =>
+                            (access.awsUrl = e.currentTarget.value),
+                        )
+                      }
+                    ></input>
+                  </ListItem>
+                  <ListItem
+                    title={Locale.Settings.Access.AWS.ApiKey.Title}
+                    subTitle={Locale.Settings.Access.AWS.ApiKey.SubTitle}
+                  >
+                    <PasswordInput
+                      value={accessStore.awsApiKey}
+                      type="text"
+                      placeholder={Locale.Settings.Access.AWS.ApiKey.Placeholder}
+                      onChange={(e) => {
+                        accessStore.update(
+                          (access) =>
+                            (access.awsApiKey = e.currentTarget.value),
+                        );
+                      }}
+                    />
+                  </ListItem>
+                </>
+              )}
             </>
           )}
 
diff --git a/app/components/sidebar.tsx b/app/components/sidebar.tsx
index 69b2e71f8..4b867224e 100644
--- a/app/components/sidebar.tsx
+++ b/app/components/sidebar.tsx
@@ -216,11 +216,11 @@ export function SideBar(props: { className?: string }) {
           <div className={styles["sidebar-action"]}>
             <Link to={Path.Settings}>
               <IconButton icon={<SettingsIcon />} shadow />
             </Link>
           </div>
-          <div className={styles["sidebar-action"]}>
+          {/* <div className={styles["sidebar-action"]}>
             <a href={REPO_URL} target="_blank" rel="noopener noreferrer">
               <IconButton icon={<GithubIcon />} shadow />
             </a>
-          </div>
+          </div> */}
         </div>
diff --git a/app/config/server.ts b/app/config/server.ts
--- a/app/config/server.ts
+++ b/app/config/server.ts
@@ ... @@ export const getServerSideConfig = () => {
   const isAzure = !!process.env.AZURE_URL;
   const isGoogle = !!process.env.GOOGLE_API_KEY;
   const isAnthropic = !!process.env.ANTHROPIC_API_KEY;
+  const isAWS = !!process.env.AWS_API_KEY;
 
   // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? "";
   // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim());
@@ -124,6 +125,10 @@ export const getServerSideConfig = () => {
     anthropicApiVersion: process.env.ANTHROPIC_API_VERSION,
     anthropicUrl: process.env.ANTHROPIC_URL,
 
+    isAWS,
+    awsApiKey: getApiKey(process.env.AWS_API_KEY),
+    awsUrl: process.env.AWS_URL,
+
     gtmId: process.env.GTM_ID,
 
     needCode: ACCESS_CODES.size > 0,
diff --git a/app/constant.ts b/app/constant.ts
index 411e48150..8e77a65bd 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -1,4 +1,4 @@
-export const OWNER = "Yidadaa";
+export const OWNER = "Ryder Tsui";
 export const REPO = "ChatGPT-Next-Web";
 export const REPO_URL = `https://github.com/${OWNER}/${REPO}`;
 export const ISSUE_URL = `https://github.com/${OWNER}/${REPO}/issues`;
@@ -70,12 +70,14 @@ export enum ServiceProvider {
   Azure = "Azure",
   Google = "Google",
   Anthropic = "Anthropic",
+  AWS = "AWS",
 }
 
 export enum ModelProvider {
   GPT = "GPT",
   GeminiPro = "GeminiPro",
   Claude = "Claude",
+  Bedrock = "Bedrock",
 }
 
 export const Anthropic = {
@@ -101,6 +103,16 @@ export const Google = {
   ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
 };
 
+export const AWS = {
+  ExampleEndpoint: "http://localhost:8866",
+  ChatPath: "v1/chat/completions",
+  // Assumed OpenAI-style billing endpoints exposed by the proxy; these are
+  // referenced by BedrockApi.usage() in app/client/platforms/aws.ts.
+  UsagePath: "dashboard/billing/usage",
+  SubsPath: "dashboard/billing/subscription",
+  ListModelPath: "v1/models",
+};
+
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
 // export const DEFAULT_SYSTEM_TEMPLATE = `
 // You are ChatGPT, a large language model trained by {{ServiceProvider}}.
@@ -168,6 +176,14 @@ const anthropicModels = [
   "claude-3-haiku-20240307",
 ];
 
+const bedrockModels = [
+  "llama3-70b",
+  "llama3-8b",
+  "claude-3-haiku",
+  "claude-3-sonnet",
+  "claude-3-opus",
+];
+
 export const DEFAULT_MODELS = [
   ...openaiModels.map((name) => ({
     name,
@@ -196,6 +212,15 @@ export const DEFAULT_MODELS = [
       providerType: "anthropic",
     },
   })),
+  ...bedrockModels.map((name) => ({
+    name,
+    available: true,
+    provider: {
+      id: "bedrock",
+      providerName: "Bedrock",
+      providerType: "bedrock",
+    },
+  })),
 ] as const;
 
 export const CHAT_PAGE_SIZE = 15;
diff --git a/app/locales/cn.ts b/app/locales/cn.ts
index 2ff94e32d..fe04152e3 100644
--- a/app/locales/cn.ts
+++ b/app/locales/cn.ts
@@ -347,6 +347,21 @@
           SubTitle: "选择一个特定的 API 版本",
         },
       },
+      AWS: {
+        Endpoint: {
+          Title: "接口地址",
+          SubTitle: "样例:http://localhost:8866",
+        },
+        ApiKey: {
+          Title: "API Key",
+          SubTitle: "使用自定义 AWS Bedrock Key 绕过密码访问限制",
+          Placeholder: "输入您的 AWS API 密钥",
+        },
+        ApiVerion: {
+          Title: "接口版本 (aws api version)",
+          SubTitle: "选择一个特定的 API 版本",
+        },
+      },
       CustomModel: {
         Title: "自定义模型名",
         SubTitle: "增加自定义模型可选项,使用英文逗号隔开",
diff --git a/app/store/access.ts b/app/store/access.ts
index 64909609e..fe4c77679 100644
--- a/app/store/access.ts
+++ b/app/store/access.ts
@@ -42,6 +42,11 @@
   anthropicApiVersion: "2023-06-01",
   anthropicUrl: "",
 
+  // AWS
+  awsUrl: "",
+  awsApiKey: "",
+  awsApiVersion: "",
+
   // server config
   needCode: true,
   hideUserApiKey: false,
diff --git a/app/store/chat.ts b/app/store/chat.ts
index 27a7114a3..94f99b5ed 100644
--- a/app/store/chat.ts
+++ b/app/store/chat.ts
@@ -366,8 +366,9 @@ export const useChatStore = createPersistStore(
       var api: ClientApi;
       if (modelConfig.model.startsWith("gemini")) {
         api = new ClientApi(ModelProvider.GeminiPro);
-      } else if (identifyDefaultClaudeModel(modelConfig.model)) {
-        api = new ClientApi(ModelProvider.Claude);
+      } else if (modelConfig.model.startsWith("claude")) {
+        api = new ClientApi(ModelProvider.Bedrock);
+        //api = new ClientApi(ModelProvider.Claude);
       } else {
         api = new ClientApi(ModelProvider.GPT);
       }
@@ -551,7 +552,8 @@
       if (modelConfig.model.startsWith("gemini")) {
         api = new ClientApi(ModelProvider.GeminiPro);
       } else if (identifyDefaultClaudeModel(modelConfig.model)) {
-        api = new ClientApi(ModelProvider.Claude);
+        //api = new ClientApi(ModelProvider.Claude);
+        api = new ClientApi(ModelProvider.Bedrock);
       } else {
         api = new ClientApi(ModelProvider.GPT);
       }