feat: optimize RAG

Hk-Gosuto
2024-07-07 15:41:58 +08:00
parent f260f11755
commit 712022d8c7
19 changed files with 332 additions and 176 deletions

View File

@@ -0,0 +1,78 @@
import { StructuredTool } from "@langchain/core/tools";
import { BaseLanguageModel } from "@langchain/core/language_models/base";
import { Embeddings } from "@langchain/core/embeddings";
import { Document } from "@langchain/core/documents";
import { formatDocumentsAsString } from "langchain/util/document";
import { getServerSideConfig } from "@/app/config/server";
import { SupabaseVectorStore } from "@langchain/community/vectorstores/supabase";
import { createClient } from "@supabase/supabase-js";
import { z } from "zod";
export class MyFilesBrowser extends StructuredTool {
static lc_name() {
return "MyFilesBrowser";
}
get lc_namespace() {
return [...super.lc_namespace, "myfilesbrowser"];
}
private sessionId: string;
private model: BaseLanguageModel;
private embeddings: Embeddings;
constructor(
sessionId: string,
model: BaseLanguageModel,
embeddings: Embeddings,
) {
super();
this.sessionId = sessionId;
this.model = model;
this.embeddings = embeddings;
}
schema = z.object({
queries: z.array(z.string()).describe("A list of search queries to run over the uploaded files."),
});
/** @ignore */
async _call({ queries }: z.infer<typeof this.schema>) {
const serverConfig = getServerSideConfig();
if (!serverConfig.isEnableRAG)
throw new Error("env ENABLE_RAG not configured");
const privateKey = process.env.SUPABASE_PRIVATE_KEY;
if (!privateKey) throw new Error(`Expected env var SUPABASE_PRIVATE_KEY`);
const url = process.env.SUPABASE_URL;
if (!url) throw new Error(`Expected env var SUPABASE_URL`);
const client = createClient(url, privateKey);
const vectorStore = new SupabaseVectorStore(this.embeddings, {
client,
tableName: "documents",
queryName: "match_documents",
});
const returnCount = serverConfig.ragReturnCount
? parseInt(serverConfig.ragReturnCount, 10)
: 4;
console.log("[myfiles_browser]", { queries, returnCount });
// Run each query against this session's documents and pool the matches.
const documents: Document[] = [];
for (const query of queries) {
const results = await vectorStore.similaritySearch(query, returnCount, {
sessionId: this.sessionId,
});
documents.push(...results);
}
const context = formatDocumentsAsString(documents);
console.log("[myfiles_browser]", { context });
return context;
}
name = "myfiles_browser";
description = `Issues queries to search over the file(s) uploaded in the current conversation and displays the results.`;
}
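For orientation, a minimal standalone usage sketch of the tool outside the agent loop — the session id, model, and embeddings below are illustrative placeholders, not values from this commit:

import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
// Hypothetical wiring: the embeddings must match those used at ingest time,
// and the sessionId must match the metadata tag on the stored documents.
const tool = new MyFilesBrowser(
"session-123",
new ChatOpenAI(),
new OpenAIEmbeddings(),
);
// StructuredTool validates the input against the zod schema before _call runs.
const context = await tool.invoke({ queries: ["france gdp 1970"] });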

View File

@@ -10,7 +10,7 @@ import { WolframAlphaTool } from "@/app/api/langchain-tools/wolframalpha";
import { BilibiliVideoInfoTool } from "./bilibili_vid_info";
import { BilibiliVideoSearchTool } from "./bilibili_vid_search";
import { BilibiliMusicRecognitionTool } from "./bilibili_music_recognition";
import { RAGSearch } from "./rag_search";
import { MyFilesBrowser } from "./myfiles_browser";
import { BilibiliVideoConclusionTool } from "./bilibili_vid_conclusion";
export class NodeJSTool {
@@ -59,7 +59,7 @@ export class NodeJSTool {
const bilibiliVideoSearchTool = new BilibiliVideoSearchTool();
const bilibiliVideoConclusionTool = new BilibiliVideoConclusionTool();
const bilibiliMusicRecognitionTool = new BilibiliMusicRecognitionTool();
let tools = [
let tools: any = [
calculatorTool,
webBrowserTool,
dallEAPITool,
@@ -73,7 +73,9 @@ export class NodeJSTool {
bilibiliVideoConclusionTool,
];
if (!!process.env.ENABLE_RAG) {
tools.push(new RAGSearch(this.sessionId, this.model, this.ragEmbeddings));
tools.push(
new MyFilesBrowser(this.sessionId, this.model, this.ragEmbeddings),
);
}
return tools;
}
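The `any` annotation above sidesteps typing the mixed tool list; a stricter alternative, as a sketch (it assumes every entry extends StructuredTool, which `Tool` itself does in @langchain/core):

import type { StructuredToolInterface } from "@langchain/core/tools";
let tools: StructuredToolInterface[] = [
calculatorTool,
webBrowserTool,
// ...remaining tools as above
];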

View File

@@ -1,79 +0,0 @@
import { Tool } from "@langchain/core/tools";
import { CallbackManagerForToolRun } from "@langchain/core/callbacks/manager";
import { BaseLanguageModel } from "langchain/dist/base_language";
import { formatDocumentsAsString } from "langchain/util/document";
import { Embeddings } from "langchain/dist/embeddings/base.js";
import { RunnableSequence } from "@langchain/core/runnables";
import { StringOutputParser } from "@langchain/core/output_parsers";
import { Pinecone } from "@pinecone-database/pinecone";
import { PineconeStore } from "@langchain/pinecone";
import { getServerSideConfig } from "@/app/config/server";
import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant";
export class RAGSearch extends Tool {
static lc_name() {
return "RAGSearch";
}
get lc_namespace() {
return [...super.lc_namespace, "ragsearch"];
}
private sessionId: string;
private model: BaseLanguageModel;
private embeddings: Embeddings;
constructor(
sessionId: string,
model: BaseLanguageModel,
embeddings: Embeddings,
) {
super();
this.sessionId = sessionId;
this.model = model;
this.embeddings = embeddings;
}
/** @ignore */
async _call(inputs: string, runManager?: CallbackManagerForToolRun) {
const serverConfig = getServerSideConfig();
if (!serverConfig.isEnableRAG)
throw new Error("env ENABLE_RAG not configured");
// const pinecone = new Pinecone();
// const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!);
// const vectorStore = await PineconeStore.fromExistingIndex(this.embeddings, {
// pineconeIndex,
// });
const vectorStore = await QdrantVectorStore.fromExistingCollection(
this.embeddings,
{
url: process.env.QDRANT_URL,
apiKey: process.env.QDRANT_API_KEY,
collectionName: this.sessionId,
},
);
let context;
const returnCount = serverConfig.ragReturnCount
? parseInt(serverConfig.ragReturnCount, 10)
: 4;
console.log("[rag-search]", { inputs, returnCount });
// const results = await vectorStore.similaritySearch(inputs, returnCount, {
// sessionId: this.sessionId,
// });
const results = await vectorStore.similaritySearch(inputs, returnCount);
context = formatDocumentsAsString(results);
console.log("[rag-search]", { context });
return context;
// const input = `Text:${context}\n\nQuestion:${inputs}\n\nI need you to answer the question based on the text.`;
// console.log("[rag-search]", input);
// const chain = RunnableSequence.from([this.model, new StringOutputParser()]);
// return chain.invoke(input, runManager?.getChild());
}
name = "rag-search";
description = `Used to query documents uploaded by the user. The input is the keywords extracted from the user's question; separate multiple keywords with spaces.`;
}

View File

@@ -20,7 +20,10 @@ import { FileInfo } from "@/app/client/platforms/utils";
import mime from "mime";
import LocalFileStorage from "@/app/utils/local_file_storage";
import S3FileStorage from "@/app/utils/s3_file_storage";
import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant";
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
import { SupabaseVectorStore } from "@langchain/community/vectorstores/supabase";
import { createClient } from "@supabase/supabase-js";
import { Embeddings } from "langchain/dist/embeddings/base";
interface RequestBody {
sessionId: string;
@@ -67,6 +70,11 @@ async function handle(req: NextRequest) {
if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}
const privateKey = process.env.SUPABASE_PRIVATE_KEY;
if (!privateKey) throw new Error(`Expected env var SUPABASE_PRIVATE_KEY`);
const url = process.env.SUPABASE_URL;
if (!url) throw new Error(`Expected env var SUPABASE_URL`);
try {
const authResult = auth(req, ModelProvider.GPT);
if (authResult.error) {
@@ -81,18 +89,25 @@ async function handle(req: NextRequest) {
const apiKey = getOpenAIApiKey(token);
const baseUrl = getOpenAIBaseUrl(reqBody.baseUrl);
const serverConfig = getServerSideConfig();
// const pinecone = new Pinecone();
// const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!);
const embeddings = new OpenAIEmbeddings(
{
modelName: serverConfig.ragEmbeddingModel,
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
let embeddings: Embeddings;
if (process.env.OLLAMA_BASE_URL) {
embeddings = new OllamaEmbeddings({
model: serverConfig.ragEmbeddingModel,
baseUrl: process.env.OLLAMA_BASE_URL,
});
} else {
embeddings = new OpenAIEmbeddings(
{
modelName: serverConfig.ragEmbeddingModel,
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
}
// https://js.langchain.com/docs/integrations/vectorstores/supabase
// process files
let partial = "";
for (let i = 0; i < reqBody.fileInfos.length; i++) {
const fileInfo = reqBody.fileInfos[i];
const contentType = mime.getType(fileInfo.fileName);
@@ -134,26 +149,25 @@ async function handle(req: NextRequest) {
chunkOverlap: chunkOverlap,
});
const splits = await textSplitter.splitDocuments(docs);
const vectorStore = await QdrantVectorStore.fromDocuments(
const client = createClient(url, privateKey);
const vectorStore = await SupabaseVectorStore.fromDocuments(
splits,
embeddings,
{
url: process.env.QDRANT_URL,
apiKey: process.env.QDRANT_API_KEY,
collectionName: reqBody.sessionId,
client,
tableName: "documents",
queryName: "match_documents",
},
);
// await PineconeStore.fromDocuments(splits, embeddings, {
// pineconeIndex,
// maxConcurrency: 5,
// });
// const vectorStore = await PineconeStore.fromExistingIndex(embeddings, {
// pineconeIndex,
// });
partial = splits
.slice(0, 2)
.map((v) => v.pageContent)
.join("\n");
}
return NextResponse.json(
{
sessionId: reqBody.sessionId,
partial: partial,
},
{
status: 200,
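Two assumptions are worth flagging for this hunk. First, the LangChain Supabase integration expects the `documents` table and the `match_documents` Postgres function to already exist in the database (see the integration docs referenced in the comment above). Second, the per-session filter that MyFilesBrowser applies only matches if each split carries the session id in its metadata; that tagging is not visible in this hunk, so the sketch below is hypothetical and the commit may handle it elsewhere:

// Hypothetical sketch: stamp each chunk with the session id before ingestion
// so that similaritySearch(query, k, { sessionId }) can filter per session.
for (const doc of splits) {
doc.metadata = { ...doc.metadata, sessionId: reqBody.sessionId };
}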

View File

@@ -4,6 +4,8 @@ import { auth } from "@/app/api/auth";
import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools";
import { ModelProvider } from "@/app/constant";
import { OpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { Embeddings } from "langchain/dist/embeddings/base";
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
async function handle(req: NextRequest) {
if (req.method === "OPTIONS") {
@@ -44,13 +46,22 @@ async function handle(req: NextRequest) {
},
{ basePath: baseUrl },
);
const ragEmbeddings = new OpenAIEmbeddings(
{
modelName: process.env.RAG_EMBEDDING_MODEL ?? "text-embedding-3-large",
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
let ragEmbeddings: Embeddings;
if (process.env.OLLAMA_BASE_URL) {
ragEmbeddings = new OllamaEmbeddings({
model: process.env.RAG_EMBEDDING_MODEL,
baseUrl: process.env.OLLAMA_BASE_URL,
});
} else {
ragEmbeddings = new OpenAIEmbeddings(
{
modelName:
process.env.RAG_EMBEDDING_MODEL ?? "text-embedding-3-large",
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
}
var dalleCallback = async (data: string) => {
var response = new ResponseBody();

View File

@@ -116,7 +116,7 @@ export abstract class LLMApi {
abstract speech(options: SpeechOptions): Promise<ArrayBuffer>;
abstract transcription(options: TranscriptionOptions): Promise<string>;
abstract toolAgentChat(options: AgentChatOptions): Promise<void>;
abstract createRAGStore(options: CreateRAGStoreOptions): Promise<void>;
abstract createRAGStore(options: CreateRAGStoreOptions): Promise<string>;
abstract usage(): Promise<LLMUsage>;
abstract models(): Promise<LLMModel[]>;
}

View File

@@ -89,7 +89,7 @@ export class ClaudeApi implements LLMApi {
toolAgentChat(options: AgentChatOptions): Promise<void> {
throw new Error("Method not implemented.");
}
createRAGStore(options: CreateRAGStoreOptions): Promise<void> {
createRAGStore(options: CreateRAGStoreOptions): Promise<string> {
throw new Error("Method not implemented.");
}
extractMessage(res: any) {

View File

@@ -20,7 +20,7 @@ import {
} from "@/app/utils";
export class GeminiProApi implements LLMApi {
createRAGStore(options: CreateRAGStoreOptions): Promise<void> {
createRAGStore(options: CreateRAGStoreOptions): Promise<string> {
throw new Error("Method not implemented.");
}
transcription(options: TranscriptionOptions): Promise<string> {

View File

@@ -373,7 +373,7 @@ export class ChatGPTApi implements LLMApi {
}
}
async createRAGStore(options: CreateRAGStoreOptions): Promise<void> {
async createRAGStore(options: CreateRAGStoreOptions): Promise<string> {
try {
const accessStore = useAccessStore.getState();
const isAzure = accessStore.provider === ServiceProvider.Azure;
@@ -395,9 +395,12 @@ export class ChatGPTApi implements LLMApi {
};
const res = await fetch(path, chatPayload);
if (res.status !== 200) throw new Error(await res.text());
const resJson = await res.json();
return resJson.partial;
} catch (e) {
console.log("[Request] failed to make a chat reqeust", e);
options.onError?.(e as Error);
return "";
}
}
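The `resJson.partial` read here matches the payload the RAG store route above returns; as an interface sketch (hypothetical type name, fields taken from the route's NextResponse.json call):

// Sketch of the response consumed above; `partial` is the first two chunks
// of the split document, joined with newlines, produced by the store route.
interface CreateRAGStoreResponse {
sessionId: string;
partial: string;
}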

View File

@@ -1,10 +1,13 @@
import { getHeaders } from "../api";
import { getClientApi } from "@/app/utils";
import { ClientApi, getHeaders } from "../api";
import { ChatSession } from "@/app/store";
export interface FileInfo {
originalFilename: string;
fileName: string;
filePath: string;
size: number;
partial?: string;
}
export class FileApi {
@@ -31,4 +34,15 @@ export class FileApi {
filePath: resJson.filePath,
};
}
async uploadForRag(file: any, session: ChatSession): Promise<FileInfo> {
const fileInfo = await this.upload(file);
const api: ClientApi = getClientApi(session.mask.modelConfig.model);
const partial = await api.llm.createRAGStore({
chatSessionId: session.id,
fileInfos: [fileInfo],
});
fileInfo.partial = partial;
return fileInfo;
}
}
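A call-site sketch; this mirrors what the chat UI below does inside its file input's onchange handler (`file` comes from an <input type="file">, `session` is the active ChatSession):

// Upload the picked file, build the vector store for this session, and get
// back the FileInfo enriched with the `partial` document preview.
const api = new ClientApi();
const fileInfo = await api.file.uploadForRag(file, session);
console.log(fileInfo.partial);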

View File

@@ -1398,32 +1398,25 @@ function _Chat() {
const fileInput = document.createElement("input");
fileInput.type = "file";
fileInput.accept = ".pdf,.txt,.md,.json,.csv,.docx,.srt,.mp3";
fileInput.multiple = true;
fileInput.multiple = false;
fileInput.onchange = (event: any) => {
setUploading(true);
const files = event.target.files;
const file = event.target.files[0];
const api = new ClientApi();
const fileDatas: FileInfo[] = [];
for (let i = 0; i < files.length; i++) {
const file = event.target.files[i];
api.file
.upload(file)
.then((fileInfo) => {
console.log(fileInfo);
fileDatas.push(fileInfo);
if (
fileDatas.length === 3 ||
fileDatas.length === files.length
) {
setUploading(false);
res(fileDatas);
}
})
.catch((e) => {
setUploading(false);
rej(e);
});
}
api.file
.uploadForRag(file, session)
.then((fileInfo) => {
console.log(fileInfo);
fileDatas.push(fileInfo);
session.attachFiles.push(fileInfo);
setUploading(false);
res(fileDatas);
})
.catch((e) => {
setUploading(false);
rej(e);
});
};
fileInput.click();
})),
@@ -1694,7 +1687,7 @@ function _Chat() {
parentRef={scrollRef}
defaultShow={i >= messages.length - 6}
/>
{message.fileInfos && message.fileInfos.length > 0 && (
{/* {message.fileInfos && message.fileInfos.length > 0 && (
<nav
className={styles["chat-message-item-files"]}
style={
@@ -1716,7 +1709,7 @@ function _Chat() {
);
})}
</nav>
)}
)} */}
{getMessageImages(message).length == 1 && (
<img
className={styles["chat-message-item-image"]}

View File

@@ -236,3 +236,30 @@ export const internalAllowedWebDavEndpoints = [
"https://webdav.yandex.com",
"https://app.koofr.net/dav/Koofr",
];
export const MYFILES_BROWSER_TOOLS_SYSTEM_PROMPT = `
# Tools
## myfiles_browser
You have the tool 'myfiles_browser' with the following function:
it issues queries to search the file(s) uploaded in the current conversation and displays the results.
This tool is for browsing the files uploaded by the user.
Parts of the documents uploaded by users will be automatically included in the conversation. Only use this tool when those parts don't contain the information needed to fulfill the user's request.
If the user asks to summarize a document, summarize it from the parts already included in the conversation.
Think carefully about how the information you find relates to the user's request. Respond as soon as you find information that clearly answers the request.
Issue multiple queries to the 'myfiles_browser' command only when the user's question needs to be decomposed to find different facts. In other scenarios, prefer providing a single query. Avoid single-word queries that are extremely broad and will return unrelated results.
Here are some examples of how to use the 'myfiles_browser' command:
User: What was the GDP of France and Italy in the 1970s? => myfiles_browser(["france gdp 1970", "italy gdp 1970"])
User: What does the report say about the GPT4 performance on MMLU? => myfiles_browser(["GPT4 MMLU performance"])
User: How can I integrate customer relationship management system with third-party email marketing tools? => myfiles_browser(["customer management system marketing integration"])
User: What are the best practices for data security and privacy for our cloud storage services? => myfiles_browser(["cloud storage security and privacy"])
The user has uploaded the following files:
`;

View File

@@ -13,6 +13,7 @@ import {
StoreKey,
SUMMARIZE_MODEL,
GEMINI_SUMMARIZE_MODEL,
MYFILES_BROWSER_TOOLS_SYSTEM_PROMPT,
} from "../constant";
import { ClientApi, RequestMessage, MultimodalContent } from "../client/api";
import { ChatControllerPool } from "../client/controller";
@@ -69,6 +70,8 @@ export interface ChatSession {
clearContextIndex?: number;
mask: Mask;
attachFiles: FileInfo[];
}
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
@@ -92,6 +95,8 @@ function createEmptySession(): ChatSession {
lastSummarizeIndex: 0,
mask: createEmptyMask(),
attachFiles: [],
};
}
@@ -354,6 +359,10 @@ export const useChatStore = createPersistStore(
}),
);
}
// add file link
if (attachFiles && attachFiles.length > 0) {
mContent += ` [${attachFiles[0].originalFilename}](${attachFiles[0].filePath})`;
}
let userMessage: ChatMessage = createMessage({
role: "user",
content: mContent,
@@ -365,7 +374,9 @@ export const useChatStore = createPersistStore(
model: modelConfig.model,
toolMessages: [],
});
var api: ClientApi = getClientApi(modelConfig.model);
const isEnableRAG =
session.attachFiles && session.attachFiles.length > 0;
// get recent messages
const recentMessages = get().getMessagesWithMemory();
const sendMessages = recentMessages.concat(userMessage);
@@ -391,8 +402,6 @@ export const useChatStore = createPersistStore(
session.messages.push(savedUserMessage);
session.messages.push(botMessage);
});
const isEnableRAG = attachFiles && attachFiles?.length > 0;
var api: ClientApi = getClientApi(modelConfig.model);
if (
config.pluginConfig.enable &&
session.mask.usePlugins &&
@@ -401,8 +410,13 @@ export const useChatStore = createPersistStore(
modelConfig.model != "gpt-4-vision-preview"
) {
console.log("[ToolAgent] start");
const pluginToolNames = allPlugins.map((m) => m.toolName);
if (isEnableRAG) pluginToolNames.push("rag-search");
let pluginToolNames = allPlugins.map((m) => m.toolName);
if (isEnableRAG) {
// Other plugins can interfere with RAG retrieval, so when files are
// attached, hand the agent only the myfiles_browser tool.
pluginToolNames = ["myfiles_browser"];
}
const agentCall = () => {
api.llm.toolAgentChat({
chatSessionId: session.id,
@@ -469,19 +483,7 @@ export const useChatStore = createPersistStore(
},
});
};
if (attachFiles && attachFiles.length > 0) {
await api.llm
.createRAGStore({
chatSessionId: session.id,
fileInfos: attachFiles,
})
.then(() => {
console.log("[RAG]", "Vector db created");
agentCall();
});
} else {
agentCall();
}
agentCall();
} else {
// make request
api.llm.chat({
@@ -565,13 +567,23 @@ export const useChatStore = createPersistStore(
session.mask.modelConfig.model.startsWith("gpt-");
var systemPrompts: ChatMessage[] = [];
var template = DEFAULT_SYSTEM_TEMPLATE;
if (session.attachFiles && session.attachFiles.length > 0) {
template += MYFILES_BROWSER_TOOLS_SYSTEM_PROMPT;
session.attachFiles.forEach((file) => {
template += `filename: \`${file.originalFilename}\`
partialDocument: \`\`\`
${file.partial}
\`\`\``;
});
}
systemPrompts = shouldInjectSystemPrompts
? [
createMessage({
role: "system",
content: fillTemplateWith("", {
...modelConfig,
template: DEFAULT_SYSTEM_TEMPLATE,
template: template,
}),
}),
]
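For reference, the environment variables this commit reads, collected from the hunks above (a summary, not new configuration):

// Environment variables referenced across this commit:
// ENABLE_RAG           - gates registration of the myfiles_browser tool
// SUPABASE_URL         - Supabase project URL, required by the RAG routes
// SUPABASE_PRIVATE_KEY - Supabase service key, required by the RAG routes
// OLLAMA_BASE_URL      - when set, embeddings come from Ollama instead of OpenAI
// RAG_EMBEDDING_MODEL  - embedding model name ("text-embedding-3-large" default on the OpenAI path)
// QDRANT_URL / QDRANT_API_KEY - only used by the removed rag_search.ts path
// (plus serverConfig.ragReturnCount / ragEmbeddingModel, read via getServerSideConfig)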