feat: optimize rag

Author: Hk-Gosuto
Date: 2024-07-07 15:41:58 +08:00
Parent: f260f11755
Commit: 712022d8c7
19 changed files with 332 additions and 176 deletions


@@ -20,7 +20,10 @@ import { FileInfo } from "@/app/client/platforms/utils";
import mime from "mime";
import LocalFileStorage from "@/app/utils/local_file_storage";
import S3FileStorage from "@/app/utils/s3_file_storage";
import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant";
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
import { SupabaseVectorStore } from "@langchain/community/vectorstores/supabase";
import { createClient } from "@supabase/supabase-js";
import { Embeddings } from "langchain/dist/embeddings/base";
interface RequestBody {
sessionId: string;
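
The hunk above swaps the Qdrant vector store import for Supabase, adds an Ollama embeddings option, and pulls the `Embeddings` base class from `langchain/dist/embeddings/base`. Importing from `dist/` is generally discouraged; depending on the installed LangChain version, the same type is exposed through a public entry point. A minimal sketch, assuming a version that ships `@langchain/core`:

```ts
// Sketch: the same imports via public entry points (assumes @langchain/core is installed).
import { SupabaseVectorStore } from "@langchain/community/vectorstores/supabase";
import { OllamaEmbeddings } from "@langchain/community/embeddings/ollama";
import { createClient } from "@supabase/supabase-js";
import { Embeddings } from "@langchain/core/embeddings"; // instead of langchain/dist/embeddings/base
```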
@@ -67,6 +70,11 @@ async function handle(req: NextRequest) {
if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}
const privateKey = process.env.SUPABASE_PRIVATE_KEY;
if (!privateKey) throw new Error(`Expected env var SUPABASE_PRIVATE_KEY`);
const url = process.env.SUPABASE_URL;
if (!url) throw new Error(`Expected env var SUPABASE_URL`);
try {
const authResult = auth(req, ModelProvider.GPT);
if (authResult.error) {
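
The new guard clauses above fail fast when the Supabase credentials are missing. If more providers are added later, a small helper could deduplicate the pattern; `requireEnv` below is a hypothetical sketch, not part of this commit:

```ts
// Hypothetical helper (not in this commit): read a required env var or throw.
function requireEnv(name: string): string {
  const value = process.env[name];
  if (!value) throw new Error(`Expected env var ${name}`);
  return value;
}

const privateKey = requireEnv("SUPABASE_PRIVATE_KEY");
const url = requireEnv("SUPABASE_URL");
```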
@@ -81,18 +89,25 @@ async function handle(req: NextRequest) {
const apiKey = getOpenAIApiKey(token);
const baseUrl = getOpenAIBaseUrl(reqBody.baseUrl);
const serverConfig = getServerSideConfig();
// const pinecone = new Pinecone();
// const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!);
const embeddings = new OpenAIEmbeddings(
{
modelName: serverConfig.ragEmbeddingModel,
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
let embeddings: Embeddings;
if (process.env.OLLAMA_BASE_URL) {
embeddings = new OllamaEmbeddings({
model: serverConfig.ragEmbeddingModel,
baseUrl: process.env.OLLAMA_BASE_URL,
});
} else {
embeddings = new OpenAIEmbeddings(
{
modelName: serverConfig.ragEmbeddingModel,
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
}
// https://js.langchain.com/docs/integrations/vectorstores/pinecone
// https://js.langchain.com/docs/integrations/vectorstores/qdrant
// process files
let partial = "";
for (let i = 0; i < reqBody.fileInfos.length; i++) {
const fileInfo = reqBody.fileInfos[i];
const contentType = mime.getType(fileInfo.fileName);
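
The hunk above picks `OllamaEmbeddings` when `OLLAMA_BASE_URL` is set and falls back to `OpenAIEmbeddings` otherwise. Factoring that choice into a helper would keep the handler shorter; the sketch below only mirrors the constructor calls already shown in the diff, and `getEmbeddings` plus its `apiKey`/`baseUrl` parameters are assumptions, not part of the commit:

```ts
// Sketch of a provider-selection helper, mirroring the constructors used in the hunk above.
function getEmbeddings(
  ragEmbeddingModel: string,
  apiKey: string,
  baseUrl?: string,
): Embeddings {
  if (process.env.OLLAMA_BASE_URL) {
    // Local embeddings served by an Ollama instance.
    return new OllamaEmbeddings({
      model: ragEmbeddingModel,
      baseUrl: process.env.OLLAMA_BASE_URL,
    });
  }
  // Default: OpenAI-compatible embeddings endpoint.
  return new OpenAIEmbeddings(
    { modelName: ragEmbeddingModel, openAIApiKey: apiKey },
    { basePath: baseUrl },
  );
}
```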
@@ -134,26 +149,25 @@ async function handle(req: NextRequest) {
chunkOverlap: chunkOverlap,
});
const splits = await textSplitter.splitDocuments(docs);
const vectorStore = await QdrantVectorStore.fromDocuments(
const client = createClient(url, privateKey);
const vectorStore = await SupabaseVectorStore.fromDocuments(
splits,
embeddings,
{
url: process.env.QDRANT_URL,
apiKey: process.env.QDRANT_API_KEY,
collectionName: reqBody.sessionId,
client,
tableName: "documents",
queryName: "match_documents",
},
);
// await PineconeStore.fromDocuments(splits, embeddings, {
// pineconeIndex,
// maxConcurrency: 5,
// });
// const vectorStore = await PineconeStore.fromExistingIndex(embeddings, {
// pineconeIndex,
// });
partial = splits
.slice(0, 2)
.map((v) => v.pageContent)
.join("\n");
}
return NextResponse.json(
{
sessionId: reqBody.sessionId,
partial: partial,
},
{
status: 200,
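
With documents written through `SupabaseVectorStore.fromDocuments`, the `documents` table and `match_documents` function are expected to exist on the Supabase side (pgvector setup per the LangChain Supabase integration guide, https://js.langchain.com/docs/integrations/vectorstores/supabase). Unlike the previous Qdrant setup, where `collectionName: reqBody.sessionId` kept each chat's chunks in a separate collection, a single table now holds everything, so per-session retrieval would need metadata filtering on the query side. A hedged sketch of how a later request might read the embeddings back; this readback is assumed usage, not part of this hunk:

```ts
// Sketch: reading back from the same table in a later request (assumed usage, not in this hunk).
const store = await SupabaseVectorStore.fromExistingIndex(embeddings, {
  client,
  tableName: "documents",
  queryName: "match_documents",
});
// Top-4 chunks most similar to a user query.
const hits = await store.similaritySearch("What does the uploaded file say about pricing?", 4);
```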