fix: support Azure

This commit is contained in:
Hk-Gosuto
2023-12-25 19:32:18 +08:00
parent fa2e046285
commit 24de1bb77a
7 changed files with 77 additions and 30 deletions

View File

@@ -60,7 +60,6 @@ export async function requestOpenai(req: NextRequest) {
path = makeAzurePath(path, serverConfig.azureApiVersion);
}
const fetchUrl = `${baseUrl}/${path}`;
const fetchOptions: RequestInit = {
headers: {
"Content-Type": "application/json",
@@ -78,6 +77,12 @@ export async function requestOpenai(req: NextRequest) {
duplex: "half",
signal: controller.signal,
};
const clonedBody = await req.text();
const jsonBody = JSON.parse(clonedBody) as { model?: string };
if (serverConfig.isAzure) {
baseUrl = `${baseUrl}/${jsonBody.model}`;
}
const fetchUrl = `${baseUrl}/${path}`;
// #1815 try to refuse gpt4 request
if (serverConfig.customModels && req.body) {
@@ -86,11 +91,10 @@ export async function requestOpenai(req: NextRequest) {
DEFAULT_MODELS,
serverConfig.customModels,
);
const clonedBody = await req.text();
// const clonedBody = await req.text();
// const jsonBody = JSON.parse(clonedBody) as { model?: string };
fetchOptions.body = clonedBody;
const jsonBody = JSON.parse(clonedBody) as { model?: string };
// not undefined and is false
if (modelTable[jsonBody?.model ?? ""].available === false) {
return NextResponse.json(

View File

@@ -8,7 +8,7 @@ import { BaseCallbackHandler } from "langchain/callbacks";
import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema";
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import { initializeAgentExecutorWithOptions } from "langchain/agents";
import { ACCESS_CODE_PREFIX } from "@/app/constant";
import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant";
import * as langchainTools from "langchain/tools";
import { HttpGetTool } from "@/app/api/langchain-tools/http_get";
@@ -16,6 +16,7 @@ import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search";
import { DynamicTool, Tool } from "langchain/tools";
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
import { useAccessStore } from "@/app/store";
export interface RequestMessage {
role: string;
@@ -24,6 +25,8 @@ export interface RequestMessage {
export interface RequestBody {
messages: RequestMessage[];
isAzure: boolean;
azureApiVersion?: string;
model: string;
stream?: boolean;
temperature: number;
@@ -152,10 +155,10 @@ export class AgentApi {
async getOpenAIApiKey(token: string) {
const serverConfig = getServerSideConfig();
const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX);
const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let apiKey = serverConfig.apiKey;
if (isOpenAiKey && token) {
if (isApiKey && token) {
apiKey = token;
}
return apiKey;
@@ -179,27 +182,31 @@ export class AgentApi {
customTools: any[],
) {
try {
let useTools = reqBody.useTools ?? [];
const serverConfig = getServerSideConfig();
// const reqBody: RequestBody = await req.json();
const authToken = req.headers.get("Authorization") ?? "";
const isAzure = reqBody.isAzure || serverConfig.isAzure;
const authHeaderName = isAzure ? "api-key" : "Authorization";
const authToken = req.headers.get(authHeaderName) ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim();
const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let useTools = reqBody.useTools ?? [];
let apiKey = serverConfig.apiKey;
if (isOpenAiKey && token) {
apiKey = token;
}
let apiKey = await this.getOpenAIApiKey(token);
if (isAzure) apiKey = token;
let baseUrl = "https://api.openai.com/v1";
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
if (
reqBody.baseUrl?.startsWith("http://") ||
reqBody.baseUrl?.startsWith("https://")
)
) {
baseUrl = reqBody.baseUrl;
if (!baseUrl.endsWith("/v1"))
}
if (!isAzure && !baseUrl.endsWith("/v1")) {
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
}
if (!reqBody.isAzure && serverConfig.isAzure) {
baseUrl = serverConfig.azureUrl || baseUrl;
}
console.log("[baseUrl]", baseUrl);
var handler = await this.getHandler(reqBody);
@@ -281,7 +288,7 @@ export class AgentApi {
chatHistory: new ChatMessageHistory(pastMessages),
});
const llm = new ChatOpenAI(
let llm = new ChatOpenAI(
{
modelName: reqBody.model,
openAIApiKey: apiKey,
@@ -293,6 +300,23 @@ export class AgentApi {
},
{ basePath: baseUrl },
);
if (reqBody.isAzure || serverConfig.isAzure) {
llm = new ChatOpenAI({
temperature: reqBody.temperature,
streaming: reqBody.stream,
topP: reqBody.top_p,
presencePenalty: reqBody.presence_penalty,
frequencyPenalty: reqBody.frequency_penalty,
azureOpenAIApiKey: apiKey,
azureOpenAIApiVersion: reqBody.isAzure
? reqBody.azureApiVersion
: serverConfig.azureApiVersion,
azureOpenAIApiDeploymentName: reqBody.model,
azureOpenAIBasePath: baseUrl,
});
}
const executor = await initializeAgentExecutorWithOptions(tools, llm, {
agentType: "openai-functions",
returnIntermediateSteps: reqBody.returnIntermediateSteps,