Merge branch 'feature-openai-tool-call'

This commit is contained in:
Hk-Gosuto
2023-12-29 18:23:39 +08:00
4 changed files with 97 additions and 82 deletions

View File

@@ -38,37 +38,6 @@ export class EdgeTool {
}
async getCustomTools(): Promise<any[]> {
// let searchTool: Tool = new DuckDuckGo();
// if (process.env.CHOOSE_SEARCH_ENGINE) {
// switch (process.env.CHOOSE_SEARCH_ENGINE) {
// case "google":
// searchTool = new GoogleSearch();
// break;
// case "baidu":
// searchTool = new BaiduSearch();
// break;
// }
// }
// if (process.env.BING_SEARCH_API_KEY) {
// let bingSearchTool = new langchainTools["BingSerpAPI"](
// process.env.BING_SEARCH_API_KEY,
// );
// searchTool = new DynamicTool({
// name: "bing_search",
// description: bingSearchTool.description,
// func: async (input: string) => bingSearchTool.call(input),
// });
// }
// if (process.env.SERPAPI_API_KEY) {
// let serpAPITool = new langchainTools["SerpAPI"](
// process.env.SERPAPI_API_KEY,
// );
// searchTool = new DynamicTool({
// name: "google_search",
// description: serpAPITool.description,
// func: async (input: string) => serpAPITool.call(input),
// });
// }
const webBrowserTool = new WebBrowser({
model: this.model,
embeddings: this.embeddings,
@@ -79,7 +48,6 @@ export class EdgeTool {
this.baseUrl,
this.callback,
);
dallEAPITool.returnDirect = true;
const stableDiffusionTool = new StableDiffusionWrapper();
const arxivAPITool = new ArxivAPIWrapper();
return [

View File

@@ -7,7 +7,10 @@ import { BaseCallbackHandler } from "langchain/callbacks";
import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema";
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import { initializeAgentExecutorWithOptions } from "langchain/agents";
import {
AgentExecutor,
initializeAgentExecutorWithOptions,
} from "langchain/agents";
import { ACCESS_CODE_PREFIX, ServiceProvider } from "@/app/constant";
import * as langchainTools from "langchain/tools";
@@ -17,6 +20,14 @@ import { DynamicTool, Tool } from "langchain/tools";
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
import { useAccessStore } from "@/app/store";
import { DynamicStructuredTool, formatToOpenAITool } from "langchain/tools";
import { formatToOpenAIToolMessages } from "langchain/agents/format_scratchpad/openai_tools";
import {
OpenAIToolsAgentOutputParser,
type ToolsAgentStep,
} from "langchain/agents/openai/output_parser";
import { RunnableSequence } from "langchain/schema/runnable";
import { ChatPromptTemplate, MessagesPlaceholder } from "langchain/prompts";
export interface RequestMessage {
role: string;
@@ -92,9 +103,9 @@ export class AgentApi {
await writer.close();
},
async handleChainEnd(outputs, runId, parentRunId, tags) {
console.log("[handleChainEnd]");
await writer.ready;
await writer.close();
// console.log("[handleChainEnd]");
// await writer.ready;
// await writer.close();
},
async handleLLMEnd() {
// await writer.ready;
@@ -111,10 +122,10 @@ export class AgentApi {
);
await writer.close();
},
handleLLMStart(llm, _prompts: string[]) {
async handleLLMStart(llm, _prompts: string[]) {
// console.log("handleLLMStart: I'm the second handler!!", { llm });
},
handleChainStart(chain) {
async handleChainStart(chain) {
// console.log("handleChainStart: I'm the second handler!!", { chain });
},
async handleAgentAction(action) {
@@ -141,14 +152,16 @@ export class AgentApi {
await writer.close();
}
},
handleToolStart(tool, input) {
async handleToolStart(tool, input) {
// console.log("[handleToolStart]", { tool });
},
async handleToolEnd(output, runId, parentRunId, tags) {
// console.log("[handleToolEnd]", { output, runId, parentRunId, tags });
},
handleAgentEnd(action, runId, parentRunId, tags) {
// console.log("[handleAgentEnd]");
async handleAgentEnd(action, runId, parentRunId, tags) {
console.log("[handleAgentEnd]");
await writer.ready;
await writer.close();
},
});
}
@@ -243,7 +256,6 @@ export class AgentApi {
});
}
if (process.env.GOOGLE_CSE_ID && process.env.GOOGLE_API_KEY) {
console.log("use googleCustomSearchTool");
let googleCustomSearchTool = new langchainTools["GoogleCustomSearch"]();
searchTool = new DynamicTool({
name: "google_custom_search",
@@ -289,13 +301,13 @@ export class AgentApi {
pastMessages.push(new AIMessage(message.content));
});
const memory = new BufferMemory({
memoryKey: "chat_history",
returnMessages: true,
inputKey: "input",
outputKey: "output",
chatHistory: new ChatMessageHistory(pastMessages),
});
// const memory = new BufferMemory({
// memoryKey: "chat_history",
// returnMessages: true,
// inputKey: "input",
// outputKey: "output",
// chatHistory: new ChatMessageHistory(pastMessages),
// });
let llm = new ChatOpenAI(
{
@@ -325,13 +337,48 @@ export class AgentApi {
azureOpenAIBasePath: baseUrl,
});
}
const executor = await initializeAgentExecutorWithOptions(tools, llm, {
agentType: "openai-functions",
returnIntermediateSteps: reqBody.returnIntermediateSteps,
maxIterations: reqBody.maxIterations,
memory: memory,
const memory = new BufferMemory({
memoryKey: "history",
inputKey: "question",
outputKey: "answer",
returnMessages: true,
chatHistory: new ChatMessageHistory(pastMessages),
});
const prompt = ChatPromptTemplate.fromMessages([
new MessagesPlaceholder("chat_history"),
["human", "{input}"],
new MessagesPlaceholder("agent_scratchpad"),
]);
const modelWithTools = llm.bind({ tools: tools.map(formatToOpenAITool) });
const runnableAgent = RunnableSequence.from([
{
input: (i: { input: string; steps: ToolsAgentStep[] }) => i.input,
agent_scratchpad: (i: { input: string; steps: ToolsAgentStep[] }) =>
formatToOpenAIToolMessages(i.steps),
chat_history: async (_: {
input: string;
steps: ToolsAgentStep[];
}) => {
const { history } = await memory.loadMemoryVariables({});
return history;
},
},
prompt,
modelWithTools,
new OpenAIToolsAgentOutputParser(),
]).withConfig({ runName: "OpenAIToolsAgent" });
const executor = AgentExecutor.fromAgentAndTools({
agent: runnableAgent,
tools,
});
// const executor = await initializeAgentExecutorWithOptions(tools, llm, {
// agentType: "openai-functions",
// returnIntermediateSteps: reqBody.returnIntermediateSteps,
// maxIterations: reqBody.maxIterations,
// memory: memory,
// });
executor.call(
{