feat: support claude function call

Hk-Gosuto
2024-08-04 08:32:26 +00:00
parent a20c57b0e8
commit a0fc9bd316
9 changed files with 293 additions and 436 deletions
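
This commit replaces the OpenAI-only agent pipeline (manual convertToOpenAITool binding, a hand-rolled RunnableSequence, and OpenAIToolsAgentOutputParser) with LangChain's provider-agnostic createToolCallingAgent, and adds a getLLM factory that returns ChatOpenAI, Azure ChatOpenAI, or ChatAnthropic based on the request's new provider field. A minimal sketch of the resulting pattern, assuming a current @langchain/* setup (the model names and calculator tool are illustrative, not taken from this repo):

```ts
import { ChatOpenAI } from "@langchain/openai";
import { ChatAnthropic } from "@langchain/anthropic";
import { AgentExecutor, createToolCallingAgent } from "langchain/agents";
import {
  ChatPromptTemplate,
  MessagesPlaceholder,
} from "@langchain/core/prompts";
import { Calculator } from "@langchain/community/tools/calculator";

// Pick a chat model per provider, as the new getLLM helper does
// (model names here are illustrative placeholders).
const llm =
  process.env.PROVIDER === "Anthropic"
    ? new ChatAnthropic({ model: "claude-3-5-sonnet-20240620" })
    : new ChatOpenAI({ modelName: "gpt-4o-mini" });

// createToolCallingAgent needs chat_history / input / agent_scratchpad
// slots; the diff uses MessagesPlaceholder for all three, while a plain
// {input} template (below) is the more common variant.
const prompt = ChatPromptTemplate.fromMessages([
  new MessagesPlaceholder("chat_history"),
  ["human", "{input}"],
  new MessagesPlaceholder("agent_scratchpad"),
]);

const tools = [new Calculator()];
// createToolCallingAgent works with any chat model that implements
// bindTools(), so one code path now serves both OpenAI and Claude.
const agent = await createToolCallingAgent({ llm, tools, prompt });
const executor = new AgentExecutor({ agent, tools });
const result = await executor.invoke({ input: "What is 7 * 6?", chat_history: [] });
console.log(result.output);
```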

View File

@@ -2,8 +2,6 @@ import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
 import { ACCESS_CODE_PREFIX, ModelProvider } from "@/app/constant";
 import { OpenAIEmbeddings } from "@langchain/openai";
-import { Pinecone } from "@pinecone-database/pinecone";
-import { PineconeStore } from "@langchain/pinecone";
 import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant";
 import { getServerSideConfig } from "@/app/config/server";

View File

@@ -13,8 +13,6 @@ import { OpenAIWhisperAudio } from "langchain/document_loaders/fs/openai_whisper
 // import { PPTXLoader } from "langchain/document_loaders/fs/pptx";
 import { SRTLoader } from "langchain/document_loaders/fs/srt";
 import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
-import { Pinecone } from "@pinecone-database/pinecone";
-import { PineconeStore } from "@langchain/pinecone";
 import { getServerSideConfig } from "@/app/config/server";
 import { FileInfo } from "@/app/client/platforms/utils";
 import mime from "mime";

View File

@@ -4,7 +4,11 @@ import { getServerSideConfig } from "@/app/config/server";
 import { BaseCallbackHandler } from "@langchain/core/callbacks/base";
 import { BufferMemory, ChatMessageHistory } from "langchain/memory";
-import { AgentExecutor, AgentStep } from "langchain/agents";
+import {
+  AgentExecutor,
+  AgentStep,
+  createToolCallingAgent,
+} from "langchain/agents";
 import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant";
 // import * as langchainTools from "langchain/tools";
@@ -29,6 +33,7 @@ import {
   MessagesPlaceholder,
 } from "@langchain/core/prompts";
 import { ChatOpenAI } from "@langchain/openai";
+import { ChatAnthropic } from "@langchain/anthropic";
 import {
   BaseMessage,
   FunctionMessage,
@@ -61,6 +66,7 @@ export interface RequestBody {
   maxIterations: number;
   returnIntermediateSteps: boolean;
   useTools: (undefined | string)[];
+  provider: ServiceProvider;
 }
 
 export class ResponseBody {
@@ -218,6 +224,50 @@ export class AgentApi {
     return baseUrl;
   }
 
+  getLLM(reqBody: RequestBody, apiKey: string, baseUrl: string) {
+    const serverConfig = getServerSideConfig();
+    if (reqBody.isAzure || serverConfig.isAzure)
+      return new ChatOpenAI({
+        temperature: reqBody.temperature,
+        streaming: reqBody.stream,
+        topP: reqBody.top_p,
+        presencePenalty: reqBody.presence_penalty,
+        frequencyPenalty: reqBody.frequency_penalty,
+        azureOpenAIApiKey: apiKey,
+        azureOpenAIApiVersion: reqBody.isAzure
+          ? reqBody.azureApiVersion
+          : serverConfig.azureApiVersion,
+        azureOpenAIApiDeploymentName: reqBody.model,
+        azureOpenAIBasePath: baseUrl,
+      });
+    if (reqBody.provider === ServiceProvider.OpenAI)
+      return new ChatOpenAI(
+        {
+          modelName: reqBody.model,
+          openAIApiKey: apiKey,
+          temperature: reqBody.temperature,
+          streaming: reqBody.stream,
+          topP: reqBody.top_p,
+          presencePenalty: reqBody.presence_penalty,
+          frequencyPenalty: reqBody.frequency_penalty,
+        },
+        { basePath: baseUrl },
+      );
+    if (reqBody.provider === ServiceProvider.Anthropic)
+      return new ChatAnthropic({
+        model: reqBody.model,
+        apiKey: apiKey,
+        temperature: reqBody.temperature,
+        streaming: reqBody.stream,
+        topP: reqBody.top_p,
+        // maxTokens: 1024,
+        clientOptions: {
+          baseURL: baseUrl,
+        },
+      });
+    throw new Error("Unsupported model provider");
+  }
+
   async getApiHandler(
     req: NextRequest,
     reqBody: RequestBody,
@@ -344,87 +394,38 @@ export class AgentApi {
       pastMessages.push(new AIMessage(message.content));
     });
 
-    let llm = new ChatOpenAI(
-      {
-        modelName: reqBody.model,
-        openAIApiKey: apiKey,
-        temperature: reqBody.temperature,
-        streaming: reqBody.stream,
-        topP: reqBody.top_p,
-        presencePenalty: reqBody.presence_penalty,
-        frequencyPenalty: reqBody.frequency_penalty,
-      },
-      { basePath: baseUrl },
-    );
+    let llm = this.getLLM(reqBody, apiKey, baseUrl);
 
-    if (reqBody.isAzure || serverConfig.isAzure) {
-      llm = new ChatOpenAI({
-        temperature: reqBody.temperature,
-        streaming: reqBody.stream,
-        topP: reqBody.top_p,
-        presencePenalty: reqBody.presence_penalty,
-        frequencyPenalty: reqBody.frequency_penalty,
-        azureOpenAIApiKey: apiKey,
-        azureOpenAIApiVersion: reqBody.isAzure
-          ? reqBody.azureApiVersion
-          : serverConfig.azureApiVersion,
-        azureOpenAIApiDeploymentName: reqBody.model,
-        azureOpenAIBasePath: baseUrl,
-      });
-    }
-
-    const memory = new BufferMemory({
-      memoryKey: "history",
-      inputKey: "question",
-      outputKey: "answer",
-      returnMessages: true,
-      chatHistory: new ChatMessageHistory(pastMessages),
-    });
 
     const MEMORY_KEY = "chat_history";
     const prompt = ChatPromptTemplate.fromMessages([
       new MessagesPlaceholder(MEMORY_KEY),
       new MessagesPlaceholder("input"),
       new MessagesPlaceholder("agent_scratchpad"),
     ]);
 
-    const modelWithTools = llm.bind({
-      tools: tools.map(convertToOpenAITool),
-    });
-
-    const runnableAgent = RunnableSequence.from([
-      {
-        input: (i) => i.input,
-        agent_scratchpad: (i: { input: string; steps: ToolsAgentStep[] }) => {
-          return formatToOpenAIToolMessages(i.steps);
-        },
-        chat_history: async (i: {
-          input: string;
-          steps: ToolsAgentStep[];
-        }) => {
-          const { history } = await memory.loadMemoryVariables({});
-          return history;
-        },
-      },
-      prompt,
-      modelWithTools,
-      new OpenAIToolsAgentOutputParser(),
-    ]).withConfig({ runName: "OpenAIToolsAgent" });
-
-    const executor = AgentExecutor.fromAgentAndTools({
-      agent: runnableAgent,
-      tools,
-    });
-
     const lastMessageContent = reqBody.messages.slice(-1)[0].content;
-    const lastHumanMessage =
-      typeof lastMessageContent === "string"
-        ? new HumanMessage(lastMessageContent)
-        : new HumanMessage({ content: lastMessageContent });
-
-    executor
+    const agent = await createToolCallingAgent({
+      llm,
+      tools,
+      prompt,
+    });
+    const agentExecutor = new AgentExecutor({
+      agent,
+      tools,
+    });
+
+    await agentExecutor
       .invoke(
         {
-          input: [lastHumanMessage],
+          input: lastMessageContent,
+          chat_history: pastMessages,
           signal: this.controller.signal,
         },
-        {
-          callbacks: [handler],
-        },
+        { callbacks: [handler] },
       )
       .catch((error) => {
         if (this.controller.signal.aborted) {

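The switch to createToolCallingAgent is what makes Claude usable here: ChatAnthropic implements bindTools(), which the generic agent relies on in place of the removed OpenAI-specific tool conversion. A stand-alone sketch of that capability, assuming @langchain/anthropic and zod are installed (the tool and model name are made up for illustration):

```ts
import { ChatAnthropic } from "@langchain/anthropic";
import { tool } from "@langchain/core/tools";
import { z } from "zod";

// A toy tool; its name, schema, and behavior are hypothetical.
const getWeather = tool(async ({ city }) => `Sunny in ${city}`, {
  name: "get_weather",
  description: "Look up the weather for a city",
  schema: z.object({ city: z.string() }),
});

const claude = new ChatAnthropic({
  model: "claude-3-haiku-20240307", // illustrative model name
}).bindTools([getWeather]);

const msg = await claude.invoke("What's the weather in Paris?");
// Tool requests surface uniformly as msg.tool_calls, e.g.
// [{ name: "get_weather", args: { city: "Paris" }, ... }]
console.log(msg.tool_calls);
```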
View File

@@ -3,7 +3,7 @@ import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
 import { auth } from "@/app/api/auth";
 import { EdgeTool } from "../../../../langchain-tools/edge_tools";
 import { ModelProvider } from "@/app/constant";
-import { OpenAI, OpenAIEmbeddings } from "@langchain/openai";
+import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
 
 async function handle(req: NextRequest) {
   if (req.method === "OPTIONS") {
@@ -30,7 +30,7 @@ async function handle(req: NextRequest) {
   const apiKey = await agentApi.getOpenAIApiKey(token);
   const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl);
 
-  const model = new OpenAI(
+  const model = new ChatOpenAI(
     {
       temperature: 0,
       modelName: reqBody.model,
@@ -97,4 +97,4 @@ export const preferredRegion = [
"sfo1",
"sin1",
"syd1",
];
];
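
The OpenAI → ChatOpenAI swap in these tool handlers is required for the same reason: the legacy OpenAI class wraps the text-completions API and has no tool-calling support, whereas ChatOpenAI targets chat completions and implements bindTools(). A one-line illustration (the model name is a placeholder):

```ts
import { ChatOpenAI } from "@langchain/openai";

// Chat-completions client; unlike the legacy `OpenAI` (text-completions)
// class, it supports .bindTools() and thus tool-calling agents.
const model = new ChatOpenAI({ modelName: "gpt-4o-mini", temperature: 0 });
```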

View File

@@ -3,7 +3,7 @@ import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
 import { auth } from "@/app/api/auth";
 import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools";
 import { ModelProvider } from "@/app/constant";
-import { OpenAI, OpenAIEmbeddings } from "@langchain/openai";
+import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
 import { Embeddings } from "langchain/dist/embeddings/base";
 import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
@@ -32,7 +32,7 @@ async function handle(req: NextRequest) {
   const apiKey = await agentApi.getOpenAIApiKey(token);
   const baseUrl = await agentApi.getOpenAIBaseUrl(reqBody.baseUrl);
 
-  const model = new OpenAI(
+  const model = new ChatOpenAI(
     {
       temperature: 0,
       modelName: reqBody.model,
@@ -117,4 +117,4 @@ export const preferredRegion = [
"sfo1",
"sin1",
"syd1",
];
];