feat: claude function call

This commit is contained in:
Hk-Gosuto
2024-08-17 13:05:49 +00:00
parent 0a643dc71d
commit 8c5e92d66a
5 changed files with 119 additions and 80 deletions

View File

@@ -10,7 +10,12 @@ import {
createToolCallingAgent,
createReactAgent,
} from "langchain/agents";
import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant";
import {
ACCESS_CODE_PREFIX,
ANTHROPIC_BASE_URL,
OPENAI_BASE_URL,
ServiceProvider,
} from "@/app/constant";
// import * as langchainTools from "langchain/tools";
import * as langchainTools from "@/app/api/langchain-tools/langchian-tool-index";
@@ -33,7 +38,7 @@ import {
ChatPromptTemplate,
MessagesPlaceholder,
} from "@langchain/core/prompts";
import { ChatOpenAI } from "@langchain/openai";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { ChatAnthropic } from "@langchain/anthropic";
import {
BaseMessage,
@@ -45,6 +50,7 @@ import {
} from "@langchain/core/messages";
import { MultimodalContent } from "@/app/client/api";
import { GoogleCustomSearch } from "@/app/api/langchain-tools/langchian-tool-index";
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
export interface RequestMessage {
role: string;
@@ -202,29 +208,81 @@ export class AgentApi {
});
}
async getOpenAIApiKey(token: string) {
getApiKey(token: string, provider: ServiceProvider) {
const serverConfig = getServerSideConfig();
const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let apiKey = serverConfig.apiKey;
if (isApiKey && token) {
apiKey = token;
return token;
}
return apiKey;
if (provider === ServiceProvider.OpenAI) return serverConfig.apiKey;
if (provider === ServiceProvider.Anthropic)
return serverConfig.anthropicApiKey;
throw new Error("Unsupported provider");
}
async getOpenAIBaseUrl(reqBaseUrl: string | undefined) {
getBaseUrl(reqBaseUrl: string | undefined, provider: ServiceProvider) {
const serverConfig = getServerSideConfig();
let baseUrl = "https://api.openai.com/v1";
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
let baseUrl = "";
if (provider === ServiceProvider.OpenAI) {
baseUrl = OPENAI_BASE_URL;
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
}
if (provider === ServiceProvider.Anthropic) {
baseUrl = ANTHROPIC_BASE_URL;
if (serverConfig.anthropicUrl) baseUrl = serverConfig.anthropicUrl;
}
if (reqBaseUrl?.startsWith("http://") || reqBaseUrl?.startsWith("https://"))
baseUrl = reqBaseUrl;
if (!baseUrl.endsWith("/v1"))
if (!baseUrl.endsWith("/v1") && provider === ServiceProvider.OpenAI)
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
console.log("[openai baseUrl]", baseUrl);
return baseUrl;
}
getToolBaseLanguageModel(
reqBody: RequestBody,
apiKey: string,
baseUrl: string,
) {
if (reqBody.provider === ServiceProvider.Anthropic) {
return new ChatAnthropic({
temperature: 0,
modelName: reqBody.model,
apiKey: apiKey,
clientOptions: {
baseURL: baseUrl,
},
});
}
return new ChatOpenAI(
{
temperature: 0,
modelName: reqBody.model,
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
}
getToolEmbeddings(reqBody: RequestBody, apiKey: string, baseUrl: string) {
if (reqBody.provider === ServiceProvider.Anthropic) {
if (process.env.OLLAMA_BASE_URL) {
return new OllamaEmbeddings({
model: process.env.RAG_EMBEDDING_MODEL,
baseUrl: process.env.OLLAMA_BASE_URL,
});
} else {
return null;
}
}
return new OpenAIEmbeddings(
{
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
}
getLLM(reqBody: RequestBody, apiKey: string, baseUrl: string) {
const serverConfig = getServerSideConfig();
if (reqBody.isAzure || serverConfig.isAzure) {
@@ -266,7 +324,6 @@ export class AgentApi {
temperature: reqBody.temperature,
streaming: reqBody.stream,
topP: reqBody.top_p,
// maxTokens: 1024,
clientOptions: {
baseURL: baseUrl,
},
@@ -300,22 +357,9 @@ export class AgentApi {
const authToken = req.headers.get(authHeaderName) ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim();
let apiKey = await this.getOpenAIApiKey(token);
let apiKey = this.getApiKey(token, reqBody.provider);
if (isAzure) apiKey = token;
let baseUrl = "https://api.openai.com/v1";
if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
if (
reqBody.baseUrl?.startsWith("http://") ||
reqBody.baseUrl?.startsWith("https://")
) {
baseUrl = reqBody.baseUrl;
}
if (
reqBody.provider === ServiceProvider.OpenAI &&
!baseUrl.endsWith("/v1")
) {
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
}
let baseUrl = this.getBaseUrl(reqBody.baseUrl, reqBody.provider);
if (!reqBody.isAzure && serverConfig.isAzure) {
baseUrl = serverConfig.azureUrl || baseUrl;
}