feat: support tool call

This commit is contained in:
Hk-Gosuto
2023-12-29 09:43:37 +08:00
parent d050fe636f
commit 8cdbc231ca
4 changed files with 99 additions and 50 deletions

View File

@@ -7,7 +7,10 @@ import { BaseCallbackHandler } from "langchain/callbacks";
import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema";
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import { initializeAgentExecutorWithOptions } from "langchain/agents";
import {
AgentExecutor,
initializeAgentExecutorWithOptions,
} from "langchain/agents";
import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant";
import * as langchainTools from "langchain/tools";
@@ -17,6 +20,14 @@ import { DynamicTool, Tool } from "langchain/tools";
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
import { useAccessStore } from "@/app/store";
import { DynamicStructuredTool, formatToOpenAITool } from "langchain/tools";
import { formatToOpenAIToolMessages } from "langchain/agents/format_scratchpad/openai_tools";
import {
OpenAIToolsAgentOutputParser,
type ToolsAgentStep,
} from "langchain/agents/openai/output_parser";
import { RunnableSequence } from "langchain/schema/runnable";
import { ChatPromptTemplate, MessagesPlaceholder } from "langchain/prompts";
export interface RequestMessage {
role: string;
@@ -92,9 +103,9 @@ export class AgentApi {
await writer.close();
},
async handleChainEnd(outputs, runId, parentRunId, tags) {
console.log("[handleChainEnd]");
await writer.ready;
await writer.close();
// console.log("[handleChainEnd]");
// await writer.ready;
// await writer.close();
},
async handleLLMEnd() {
// await writer.ready;
@@ -111,10 +122,10 @@ export class AgentApi {
);
await writer.close();
},
handleLLMStart(llm, _prompts: string[]) {
async handleLLMStart(llm, _prompts: string[]) {
// console.log("handleLLMStart: I'm the second handler!!", { llm });
},
handleChainStart(chain) {
async handleChainStart(chain) {
// console.log("handleChainStart: I'm the second handler!!", { chain });
},
async handleAgentAction(action) {
@@ -141,14 +152,16 @@ export class AgentApi {
await writer.close();
}
},
handleToolStart(tool, input) {
async handleToolStart(tool, input) {
// console.log("[handleToolStart]", { tool });
},
async handleToolEnd(output, runId, parentRunId, tags) {
// console.log("[handleToolEnd]", { output, runId, parentRunId, tags });
},
handleAgentEnd(action, runId, parentRunId, tags) {
// console.log("[handleAgentEnd]");
async handleAgentEnd(action, runId, parentRunId, tags) {
console.log("[handleAgentEnd]");
await writer.ready;
await writer.close();
},
});
}
@@ -288,13 +301,13 @@ export class AgentApi {
pastMessages.push(new AIMessage(message.content));
});
const memory = new BufferMemory({
memoryKey: "chat_history",
returnMessages: true,
inputKey: "input",
outputKey: "output",
chatHistory: new ChatMessageHistory(pastMessages),
});
// const memory = new BufferMemory({
// memoryKey: "chat_history",
// returnMessages: true,
// inputKey: "input",
// outputKey: "output",
// chatHistory: new ChatMessageHistory(pastMessages),
// });
let llm = new ChatOpenAI(
{
@@ -324,13 +337,48 @@ export class AgentApi {
azureOpenAIBasePath: baseUrl,
});
}
const executor = await initializeAgentExecutorWithOptions(tools, llm, {
agentType: "openai-functions",
returnIntermediateSteps: reqBody.returnIntermediateSteps,
maxIterations: reqBody.maxIterations,
memory: memory,
const memory = new BufferMemory({
memoryKey: "history",
inputKey: "question",
outputKey: "answer",
returnMessages: true,
chatHistory: new ChatMessageHistory(pastMessages),
});
const prompt = ChatPromptTemplate.fromMessages([
new MessagesPlaceholder("chat_history"),
["human", "{input}"],
new MessagesPlaceholder("agent_scratchpad"),
]);
const modelWithTools = llm.bind({ tools: tools.map(formatToOpenAITool) });
const runnableAgent = RunnableSequence.from([
{
input: (i: { input: string; steps: ToolsAgentStep[] }) => i.input,
agent_scratchpad: (i: { input: string; steps: ToolsAgentStep[] }) =>
formatToOpenAIToolMessages(i.steps),
chat_history: async (_: {
input: string;
steps: ToolsAgentStep[];
}) => {
const { history } = await memory.loadMemoryVariables({});
return history;
},
},
prompt,
modelWithTools,
new OpenAIToolsAgentOutputParser(),
]).withConfig({ runName: "OpenAIToolsAgent" });
const executor = AgentExecutor.fromAgentAndTools({
agent: runnableAgent,
tools,
});
// const executor = await initializeAgentExecutorWithOptions(tools, llm, {
// agentType: "openai-functions",
// returnIntermediateSteps: reqBody.returnIntermediateSteps,
// maxIterations: reqBody.maxIterations,
// memory: memory,
// });
executor.call(
{