Enhance encryption security with additional safeguards.

This commit is contained in:
glay
2024-12-08 23:28:59 +08:00
parent 26b9fa97cd
commit f5ae086d3c
6 changed files with 245 additions and 141 deletions

View File

@@ -20,6 +20,7 @@ import {
} from "@/app/utils/aws";
import { prettyObject } from "@/app/utils/format";
import Locale from "@/app/locales";
import { encrypt } from "@/app/utils/aws";
const ClaudeMapper = {
assistant: "assistant",
@@ -41,6 +42,66 @@ interface Tool {
parameters?: any;
};
}
const isApp = !!getClientConfig()?.isApp;
// const isApp = true;
async function getBedrockHeaders(
modelId: string,
chatPath: string,
finalRequestBody: any,
shouldStream: boolean,
): Promise<Record<string, string>> {
const accessStore = useAccessStore.getState();
const bedrockHeaders = isApp
? await sign({
method: "POST",
url: chatPath,
region: accessStore.awsRegion,
accessKeyId: accessStore.awsAccessKey,
secretAccessKey: accessStore.awsSecretKey,
body: finalRequestBody,
service: "bedrock",
headers: {},
isStreaming: shouldStream,
})
: getHeaders();
if (!isApp) {
const { awsRegion, awsAccessKey, awsSecretKey, encryptionKey } =
accessStore;
const bedrockHeadersConfig = {
XModelID: modelId,
XEncryptionKey: encryptionKey,
ShouldStream: String(shouldStream),
Authorization: await createAuthHeader(
awsRegion,
awsAccessKey,
awsSecretKey,
encryptionKey,
),
};
Object.assign(bedrockHeaders, bedrockHeadersConfig);
}
return bedrockHeaders;
}
// Helper function to create Authorization header
async function createAuthHeader(
region: string,
accessKey: string,
secretKey: string,
encryptionKey: string,
): Promise<string> {
const encryptedValues = await Promise.all([
encrypt(region, encryptionKey),
encrypt(accessKey, encryptionKey),
encrypt(secretKey, encryptionKey),
]);
return `Bearer ${encryptedValues.join(":")}`;
}
export class BedrockApi implements LLMApi {
speech(options: SpeechOptions): Promise<ArrayBuffer> {
@@ -343,32 +404,11 @@ export class BedrockApi implements LLMApi {
let finalRequestBody = this.formatRequestBody(messages, modelConfig);
try {
const isApp = !!getClientConfig()?.isApp;
// const isApp = true;
const bedrockAPIPath = `${BEDROCK_BASE_URL}/model/${
modelConfig.model
}/invoke${shouldStream ? "-with-response-stream" : ""}`;
const chatPath = isApp ? bedrockAPIPath : ApiPath.Bedrock + "/chat";
const headers = isApp
? await sign({
method: "POST",
url: chatPath,
region: accessStore.awsRegion,
accessKeyId: accessStore.awsAccessKey,
secretAccessKey: accessStore.awsSecretKey,
body: finalRequestBody,
service: "bedrock",
isStreaming: shouldStream,
})
: getHeaders();
if (!isApp) {
headers.XModelID = modelConfig.model;
headers.XEncryptionKey = accessStore.encryptionKey;
headers.ShouldStream = shouldStream + "";
}
if (process.env.NODE_ENV !== "production") {
console.debug("[Bedrock Client] Request:", {
path: chatPath,
@@ -385,9 +425,9 @@ export class BedrockApi implements LLMApi {
useChatStore.getState().currentSession().mask?.plugin || [],
);
return bedrockStream(
modelConfig.model,
chatPath,
finalRequestBody,
headers,
funcs,
controller,
// processToolMessage, include tool_calls message and tool call results
@@ -513,9 +553,15 @@ export class BedrockApi implements LLMApi {
try {
controller.signal.onabort = () =>
options.onFinish("", new Response(null, { status: 400 }));
const newHeaders = await getBedrockHeaders(
modelConfig.model,
chatPath,
JSON.stringify(finalRequestBody),
shouldStream,
);
const res = await fetch(chatPath, {
method: "POST",
headers: headers,
headers: newHeaders,
body: JSON.stringify(finalRequestBody),
});
const contentType = res.headers.get("content-type");
@@ -547,9 +593,9 @@ export class BedrockApi implements LLMApi {
}
function bedrockStream(
modelId: string,
chatPath: string,
requestPayload: any,
headers: any,
funcs: Record<string, Function>,
controller: AbortController,
processToolMessage: (
@@ -655,7 +701,7 @@ function bedrockStream(
setTimeout(() => {
console.debug("[BedrockAPI for toolCallResult] restart");
running = false;
bedrockChatApi(chatPath, headers, requestPayload);
bedrockChatApi(modelId, chatPath, requestPayload, true);
}, 60);
});
}
@@ -671,19 +717,26 @@ function bedrockStream(
controller.signal.onabort = finish;
async function bedrockChatApi(
modelId: string,
chatPath: string,
headers: any,
requestPayload: any,
shouldStream: boolean,
) {
const requestTimeoutId = setTimeout(
() => controller.abort(),
REQUEST_TIMEOUT_MS,
);
const newHeaders = await getBedrockHeaders(
modelId,
chatPath,
JSON.stringify(requestPayload),
shouldStream,
);
try {
const res = await fetch(chatPath, {
method: "POST",
headers,
headers: newHeaders,
body: JSON.stringify(requestPayload),
redirect: "manual",
// @ts-ignore
@@ -792,5 +845,5 @@ function bedrockStream(
}
console.debug("[BedrockAPI] start");
bedrockChatApi(chatPath, headers, requestPayload);
bedrockChatApi(modelId, chatPath, requestPayload, true);
}