feat: optimize rag

This commit is contained in:
Hk-Gosuto
2024-07-07 15:41:58 +08:00
parent f260f11755
commit 712022d8c7
19 changed files with 332 additions and 176 deletions

View File

@@ -0,0 +1,78 @@
import { Tool } from "@langchain/core/tools";
import { CallbackManagerForToolRun } from "@langchain/core/callbacks/manager";
import { BaseLanguageModel } from "langchain/dist/base_language";
import { formatDocumentsAsString } from "langchain/util/document";
import { Embeddings } from "langchain/dist/embeddings/base.js";
import { getServerSideConfig } from "@/app/config/server";
import { SupabaseVectorStore } from "@langchain/community/vectorstores/supabase";
import { createClient } from "@supabase/supabase-js";
import { z } from "zod";
import { StructuredTool } from "@langchain/core/tools";
export class MyFilesBrowser extends StructuredTool {
static lc_name() {
return "MyFilesBrowser";
}
get lc_namespace() {
return [...super.lc_namespace, "myfilesbrowser"];
}
private sessionId: string;
private model: BaseLanguageModel;
private embeddings: Embeddings;
constructor(
sessionId: string,
model: BaseLanguageModel,
embeddings: Embeddings,
) {
super();
this.sessionId = sessionId;
this.model = model;
this.embeddings = embeddings;
}
schema = z.object({
queries: z.array(z.string()).describe("A query list."),
});
/** @ignore */
async _call({ queries }: z.infer<typeof this.schema>) {
const serverConfig = getServerSideConfig();
if (!serverConfig.isEnableRAG)
throw new Error("env ENABLE_RAG not configured");
const privateKey = process.env.SUPABASE_PRIVATE_KEY;
if (!privateKey) throw new Error(`Expected env var SUPABASE_PRIVATE_KEY`);
const url = process.env.SUPABASE_URL;
if (!url) throw new Error(`Expected env var SUPABASE_URL`);
const client = createClient(url, privateKey);
const vectorStore = new SupabaseVectorStore(this.embeddings, {
client,
tableName: "documents",
queryName: "match_documents",
});
let context;
const returnCunt = serverConfig.ragReturnCount
? parseInt(serverConfig.ragReturnCount, 10)
: 4;
console.log("[myfiles_browser]", { queries, returnCunt });
let documents: any[] = [];
for (var i = 0; i < queries.length; i++) {
let results = await vectorStore.similaritySearch(queries[i], returnCunt, {
sessionId: this.sessionId,
});
results.forEach((item) => documents.push(item));
}
context = formatDocumentsAsString(documents);
console.log("[myfiles_browser]", { context });
return context;
}
name = "myfiles_browser";
description = `queries to a search over the file(s) uploaded in the current conversation and displays the results.`;
}

View File

@@ -10,7 +10,7 @@ import { WolframAlphaTool } from "@/app/api/langchain-tools/wolframalpha";
import { BilibiliVideoInfoTool } from "./bilibili_vid_info";
import { BilibiliVideoSearchTool } from "./bilibili_vid_search";
import { BilibiliMusicRecognitionTool } from "./bilibili_music_recognition";
import { RAGSearch } from "./rag_search";
import { MyFilesBrowser } from "./myfiles_browser";
import { BilibiliVideoConclusionTool } from "./bilibili_vid_conclusion";
export class NodeJSTool {
@@ -59,7 +59,7 @@ export class NodeJSTool {
const bilibiliVideoSearchTool = new BilibiliVideoSearchTool();
const bilibiliVideoConclusionTool = new BilibiliVideoConclusionTool();
const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool();
let tools = [
let tools: any = [
calculatorTool,
webBrowserTool,
dallEAPITool,
@@ -73,7 +73,9 @@ export class NodeJSTool {
bilibiliVideoConclusionTool,
];
if (!!process.env.ENABLE_RAG) {
tools.push(new RAGSearch(this.sessionId, this.model, this.ragEmbeddings));
tools.push(
new MyFilesBrowser(this.sessionId, this.model, this.ragEmbeddings),
);
}
return tools;
}

View File

@@ -1,79 +0,0 @@
import { Tool } from "@langchain/core/tools";
import { CallbackManagerForToolRun } from "@langchain/core/callbacks/manager";
import { BaseLanguageModel } from "langchain/dist/base_language";
import { formatDocumentsAsString } from "langchain/util/document";
import { Embeddings } from "langchain/dist/embeddings/base.js";
import { RunnableSequence } from "@langchain/core/runnables";
import { StringOutputParser } from "@langchain/core/output_parsers";
import { Pinecone } from "@pinecone-database/pinecone";
import { PineconeStore } from "@langchain/pinecone";
import { getServerSideConfig } from "@/app/config/server";
import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant";
export class RAGSearch extends Tool {
static lc_name() {
return "RAGSearch";
}
get lc_namespace() {
return [...super.lc_namespace, "ragsearch"];
}
private sessionId: string;
private model: BaseLanguageModel;
private embeddings: Embeddings;
constructor(
sessionId: string,
model: BaseLanguageModel,
embeddings: Embeddings,
) {
super();
this.sessionId = sessionId;
this.model = model;
this.embeddings = embeddings;
}
/** @ignore */
async _call(inputs: string, runManager?: CallbackManagerForToolRun) {
const serverConfig = getServerSideConfig();
if (!serverConfig.isEnableRAG)
throw new Error("env ENABLE_RAG not configured");
// const pinecone = new Pinecone();
// const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!);
// const vectorStore = await PineconeStore.fromExistingIndex(this.embeddings, {
// pineconeIndex,
// });
const vectorStore = await QdrantVectorStore.fromExistingCollection(
this.embeddings,
{
url: process.env.QDRANT_URL,
apiKey: process.env.QDRANT_API_KEY,
collectionName: this.sessionId,
},
);
let context;
const returnCunt = serverConfig.ragReturnCount
? parseInt(serverConfig.ragReturnCount, 10)
: 4;
console.log("[rag-search]", { inputs, returnCunt });
// const results = await vectorStore.similaritySearch(inputs, returnCunt, {
// sessionId: this.sessionId,
// });
const results = await vectorStore.similaritySearch(inputs, returnCunt);
context = formatDocumentsAsString(results);
console.log("[rag-search]", { context });
return context;
// const input = `Text:${context}\n\nQuestion:${inputs}\n\nI need you to answer the question based on the text.`;
// console.log("[rag-search]", input);
// const chain = RunnableSequence.from([this.model, new StringOutputParser()]);
// return chain.invoke(input, runManager?.getChild());
}
name = "rag-search";
description = `It is used to query documents entered by the user.The input content is the keywords extracted from the user's question, and multiple keywords are separated by spaces and passed in.`;
}