Hk-Gosuto
2024-04-07 18:00:21 +08:00
parent 7382ce48bb
commit b00e9f0c79
17 changed files with 307 additions and 122 deletions

View File

@@ -7,6 +7,8 @@ import { RunnableSequence } from "@langchain/core/runnables";
import { StringOutputParser } from "@langchain/core/output_parsers";
import { Pinecone } from "@pinecone-database/pinecone";
import { PineconeStore } from "@langchain/pinecone";
import { getServerSideConfig } from "@/app/config/server";
import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant";
export class RAGSearch extends Tool {
static lc_name() {
@@ -34,21 +36,32 @@ export class RAGSearch extends Tool {
/** @ignore */
async _call(inputs: string, runManager?: CallbackManagerForToolRun) {
const pinecone = new Pinecone();
const pineconeIndex = pinecone.Index(process.env.PINECONE_INDEX!);
const vectorStore = await PineconeStore.fromExistingIndex(this.embeddings, {
pineconeIndex,
});
const serverConfig = getServerSideConfig();
// const pinecone = new Pinecone();
// const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!);
// const vectorStore = await PineconeStore.fromExistingIndex(this.embeddings, {
// pineconeIndex,
// });
const vectorStore = await QdrantVectorStore.fromExistingCollection(
this.embeddings,
{
url: process.env.QDRANT_URL,
apiKey: process.env.QDRANT_API_KEY,
collectionName: this.sessionId,
},
);
let context;
const returnCount = process.env.RAG_RETURN_COUNT
? parseInt(process.env.RAG_RETURN_COUNT, 10)
const returnCount = serverConfig.ragReturnCount
? parseInt(serverConfig.ragReturnCount, 10)
: 4;
const results = await vectorStore.similaritySearch(inputs, returnCount, {
sessionId: this.sessionId,
});
console.log("[rag-search]", { inputs, returnCount });
// const results = await vectorStore.similaritySearch(inputs, returnCount, {
// sessionId: this.sessionId,
// });
const results = await vectorStore.similaritySearch(inputs, returnCount);
context = formatDocumentsAsString(results);
console.log("[rag-search]", context);
console.log("[rag-search]", { context });
return context;
// const input = `Text:${context}\n\nQuestion:${inputs}\n\nI need you to answer the question based on the text.`;
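
Design note on the switch: the Pinecone path kept every session in one index and filtered matches by a sessionId metadata field, while the Qdrant path isolates each chat session in its own collection named after the session id, so no filter is needed at query time. Below is a minimal, self-contained sketch of the new retrieval flow, assuming the same environment variables used above (QDRANT_URL, QDRANT_API_KEY, RAG_EMBEDDING_MODEL, RAG_RETURN_COUNT); the function name and query are illustrative only.

import { OpenAIEmbeddings } from "@langchain/openai";
import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant";

// Sketch only: sessionId and query are placeholders.
async function searchSession(sessionId: string, query: string) {
  // OpenAIEmbeddings falls back to the OPENAI_API_KEY env var.
  const embeddings = new OpenAIEmbeddings({
    modelName: process.env.RAG_EMBEDDING_MODEL ?? "text-embedding-3-large",
  });
  // One Qdrant collection per chat session, as in the tool above.
  const vectorStore = await QdrantVectorStore.fromExistingCollection(embeddings, {
    url: process.env.QDRANT_URL,
    apiKey: process.env.QDRANT_API_KEY,
    collectionName: sessionId,
  });
  const returnCount = process.env.RAG_RETURN_COUNT
    ? parseInt(process.env.RAG_RETURN_COUNT, 10)
    : 4;
  // Collection isolation replaces the old Pinecone sessionId metadata filter.
  return vectorStore.similaritySearch(query, returnCount);
}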

View File

@@ -4,6 +4,7 @@ import { ACCESS_CODE_PREFIX, ModelProvider } from "@/app/constant";
import { OpenAIEmbeddings } from "@langchain/openai";
import { Pinecone } from "@pinecone-database/pinecone";
import { PineconeStore } from "@langchain/pinecone";
import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant";
import { getServerSideConfig } from "@/app/config/server";
interface RequestBody {
@@ -27,26 +28,40 @@ async function handle(req: NextRequest) {
const reqBody: RequestBody = await req.json();
const authToken = req.headers.get("Authorization") ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim();
const pinecone = new Pinecone();
const pineconeIndex = pinecone.Index(process.env.PINECONE_INDEX!);
const serverConfig = getServerSideConfig();
// const pinecone = new Pinecone();
// const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!);
const apiKey = getOpenAIApiKey(token);
const baseUrl = getOpenAIBaseUrl(reqBody.baseUrl);
const embeddings = new OpenAIEmbeddings(
{
modelName: process.env.RAG_EMBEDDING_MODEL ?? "text-embedding-3-large",
modelName: serverConfig.ragEmbeddingModel ?? "text-embedding-3-large",
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
const vectorStore = await PineconeStore.fromExistingIndex(embeddings, {
pineconeIndex,
});
const results = await vectorStore.similaritySearch(reqBody.query, 1, {
sessionId: reqBody.sessionId,
});
console.log(results);
return NextResponse.json(results, {
// const vectorStore = await PineconeStore.fromExistingIndex(embeddings, {
// pineconeIndex,
// });
// const results = await vectorStore.similaritySearch(reqBody.query, 4, {
// sessionId: reqBody.sessionId,
// });
const vectorStore = await QdrantVectorStore.fromExistingCollection(
embeddings,
{
url: process.env.QDRANT_URL,
apiKey: process.env.QDRANT_API_KEY,
collectionName: reqBody.sessionId,
},
);
const returnCount = serverConfig.ragReturnCount
? parseInt(serverConfig.ragReturnCount, 10)
: 4;
const response = await vectorStore.similaritySearch(
reqBody.query,
returnCount,
);
return NextResponse.json(response, {
status: 200,
});
} catch (e) {
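
For reference, a hedged sketch of how a caller might exercise this search route. The route path and the exact RequestBody fields are not visible in this diff; the sketch assumes a /api/langchain/rag/search path and the sessionId, query and baseUrl fields the handler reads above.

// Illustrative only; the path and body shape are assumptions, not confirmed by this diff.
const token = "sk-..."; // OpenAI key or access code; the handler strips the "Bearer " prefix
const res = await fetch("/api/langchain/rag/search", {
  method: "POST",
  headers: {
    "Content-Type": "application/json",
    Authorization: `Bearer ${token}`,
  },
  body: JSON.stringify({
    sessionId: "chat-session-id", // doubles as the Qdrant collection name
    query: "What does the uploaded file say about pricing?",
  }),
});
const docs = await res.json(); // array of matched documents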

View File

@@ -20,6 +20,7 @@ import { FileInfo } from "@/app/client/platforms/utils";
import mime from "mime";
import LocalFileStorage from "@/app/utils/local_file_storage";
import S3FileStorage from "@/app/utils/s3_file_storage";
import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant";
interface RequestBody {
sessionId: string;
@@ -80,16 +81,17 @@ async function handle(req: NextRequest) {
const apiKey = getOpenAIApiKey(token);
const baseUrl = getOpenAIBaseUrl(reqBody.baseUrl);
const serverConfig = getServerSideConfig();
const pinecone = new Pinecone();
const pineconeIndex = pinecone.Index(process.env.PINECONE_INDEX!);
// const pinecone = new Pinecone();
// const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!);
const embeddings = new OpenAIEmbeddings(
{
modelName: process.env.RAG_EMBEDDING_MODEL ?? "text-embedding-3-large",
modelName: serverConfig.ragEmbeddingModel,
openAIApiKey: apiKey,
},
{ basePath: baseUrl },
);
//https://js.langchain.com/docs/integrations/vectorstores/pinecone
// https://js.langchain.com/docs/integrations/vectorstores/pinecone
// https://js.langchain.com/docs/integrations/vectorstores/qdrant
// process files
for (let i = 0; i < reqBody.fileInfos.length; i++) {
const fileInfo = reqBody.fileInfos[i];
@@ -121,22 +123,33 @@ async function handle(req: NextRequest) {
};
});
// split
const chunkSize = process.env.RAG_CHUNK_SIZE
? parseInt(process.env.RAG_CHUNK_SIZE, 10)
const chunkSize = serverConfig.ragChunkSize
? parseInt(serverConfig.ragChunkSize, 10)
: 2000;
const chunkOverlap = process.env.RAG_CHUNK_OVERLAP
? parseInt(process.env.RAG_CHUNK_OVERLAP, 10)
const chunkOverlap = serverConfig.ragChunkOverlap
? parseInt(serverConfig.ragChunkOverlap, 10)
: 200;
const textSplitter = new RecursiveCharacterTextSplitter({
chunkSize: chunkSize,
chunkOverlap: chunkOverlap,
});
const splits = await textSplitter.splitDocuments(docs);
// remove history
await PineconeStore.fromDocuments(splits, embeddings, {
pineconeIndex,
maxConcurrency: 5,
});
const vectorStore = await QdrantVectorStore.fromDocuments(
splits,
embeddings,
{
url: process.env.QDRANT_URL,
apiKey: process.env.QDRANT_API_KEY,
collectionName: reqBody.sessionId,
},
);
// await PineconeStore.fromDocuments(splits, embeddings, {
// pineconeIndex,
// maxConcurrency: 5,
// });
// const vectorStore = await PineconeStore.fromExistingIndex(embeddings, {
// pineconeIndex,
// });
}
return NextResponse.json(
{
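
Net effect of this file: parsed file content is chunked and written straight into the per-session Qdrant collection. A condensed ingest sketch under the same defaults (2000-character chunks, 200 overlap); the import paths follow the LangChain JS docs linked above and may differ from this repo's exact versions.

import { OpenAIEmbeddings } from "@langchain/openai";
import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { Document } from "@langchain/core/documents";

// Sketch: sessionId, fileName and rawText stand in for the parsed file content above.
async function ingest(sessionId: string, fileName: string, rawText: string) {
  const splitter = new RecursiveCharacterTextSplitter({
    chunkSize: 2000,   // RAG_CHUNK_SIZE default
    chunkOverlap: 200, // RAG_CHUNK_OVERLAP default
  });
  const docs = await splitter.splitDocuments([
    new Document({ pageContent: rawText, metadata: { fileName } }),
  ]);
  const embeddings = new OpenAIEmbeddings({
    modelName: process.env.RAG_EMBEDDING_MODEL ?? "text-embedding-3-large",
  });
  // Creates the collection if it does not exist, then upserts the chunks.
  await QdrantVectorStore.fromDocuments(docs, embeddings, {
    url: process.env.QDRANT_URL,
    apiKey: process.env.QDRANT_API_KEY,
    collectionName: sessionId,
  });
}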

View File

@@ -115,7 +115,7 @@ export abstract class LLMApi {
abstract speech(options: SpeechOptions): Promise<ArrayBuffer>;
abstract transcription(options: TranscriptionOptions): Promise<string>;
abstract toolAgentChat(options: AgentChatOptions): Promise<void>;
abstract createRAGSore(options: CreateRAGStoreOptions): Promise<void>;
abstract createRAGStore(options: CreateRAGStoreOptions): Promise<void>;
abstract usage(): Promise<LLMUsage>;
abstract models(): Promise<LLMModel[]>;
}

View File

@@ -20,7 +20,7 @@ import {
} from "@/app/utils";
export class GeminiProApi implements LLMApi {
createRAGSore(options: CreateRAGStoreOptions): Promise<void> {
createRAGStore(options: CreateRAGStoreOptions): Promise<void> {
throw new Error("Method not implemented.");
}
transcription(options: TranscriptionOptions): Promise<string> {

View File

@@ -363,7 +363,7 @@ export class ChatGPTApi implements LLMApi {
}
}
async createRAGSore(options: CreateRAGStoreOptions): Promise<void> {
async createRAGStore(options: CreateRAGStoreOptions): Promise<void> {
try {
const accessStore = useAccessStore.getState();
const isAzure = accessStore.provider === ServiceProvider.Azure;
@@ -373,7 +373,7 @@ export class ChatGPTApi implements LLMApi {
fileInfos: options.fileInfos,
baseUrl: baseUrl,
};
console.log("[Request] openai payload: ", requestPayload);
console.log("[Request] rag store payload: ", requestPayload);
const controller = new AbortController();
options.onController?.(controller);
let path = "/api/langchain/rag/store";

View File

@@ -509,14 +509,13 @@ export function ChatActions(props: {
const [showUploadImage, setShowUploadImage] = useState(false);
const [showUploadFile, setShowUploadFile] = useState(false);
const accessStore = useAccessStore();
useEffect(() => {
const show = isVisionModel(currentModel);
setShowUploadImage(show);
const serverConfig = getServerSideConfig();
setShowUploadFile(
serverConfig.isEnableRAG && !show && isSupportRAGModel(currentModel),
);
const isEnableRAG = !!process.env.NEXT_PUBLIC_ENABLE_RAG;
setShowUploadFile(isEnableRAG && !show && isSupportRAGModel(currentModel));
if (!show) {
props.setAttachImages([]);
props.setUploading(false);
@@ -1039,7 +1038,9 @@ function _Chat() {
setIsLoading(true);
const textContent = getMessageTextContent(userMessage);
const images = getMessageImages(userMessage);
chatStore.onUserInput(textContent, images).then(() => setIsLoading(false));
chatStore
.onUserInput(textContent, images, userMessage.fileInfos)
.then(() => setIsLoading(false));
inputRef.current?.focus();
};
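
The upload-file toggle also stops calling getServerSideConfig() in the client component: server-only environment variables are not available in the browser bundle, so the check moves to a NEXT_PUBLIC_ variable that Next.js inlines at build time. A two-line sketch of the distinction, using the names from this diff:

// Inlined into the client bundle at build time; usable inside React components.
const isEnableRAG = !!process.env.NEXT_PUBLIC_ENABLE_RAG;
// Server-only settings (QDRANT_URL, RAG_* values) remain behind
// getServerSideConfig() and are read only in the API routes above.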

View File

@@ -113,5 +113,10 @@ export const getServerSideConfig = () => {
!process.env.S3_ENDPOINT,
isEnableRAG: !!process.env.NEXT_PUBLIC_ENABLE_RAG,
ragEmbeddingModel:
process.env.RAG_EMBEDDING_MODEL ?? "text-embedding-3-large",
ragChunkSize: process.env.RAG_CHUNK_SIZE ?? "2000",
ragChunkOverlap: process.env.RAG_CHUNK_OVERLAP ?? "200",
ragReturnCount: process.env.RAG_RETURN_COUNT ?? "4",
};
};
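
The new getServerSideConfig() fields all arrive as strings with defaults; the API routes above parse them at the point of use. A short sketch of that consumption (names as in this diff):

const serverConfig = getServerSideConfig();

// RAG settings are plain strings from the environment; parse where needed.
const chunkSize = parseInt(serverConfig.ragChunkSize, 10);       // "2000" -> 2000
const chunkOverlap = parseInt(serverConfig.ragChunkOverlap, 10); // "200"  -> 200
const returnCount = parseInt(serverConfig.ragReturnCount, 10);   // "4"    -> 4
const embeddingModel = serverConfig.ragEmbeddingModel;           // defaults to "text-embedding-3-large"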

View File

@@ -43,6 +43,7 @@ const DEFAULT_ACCESS_STATE = {
disableGPT4: false,
disableFastLink: false,
customModels: "",
isEnableRAG: false,
};
export const useAccessStore = createPersistStore(
@@ -55,6 +56,10 @@ export const useAccessStore = createPersistStore(
return get().needCode;
},
isEnableRAG() {
return ensure(get(), ["isEnableRAG"]);
},
isValidOpenAI() {
return ensure(get(), ["openaiApiKey"]);
},

View File

@@ -376,88 +376,96 @@ export const useChatStore = createPersistStore(
});
var api: ClientApi;
api = new ClientApi(ModelProvider.GPT);
const isEnableRAG = !!process.env.NEXT_PUBLIC_ENABLE_RAG;
if (
config.pluginConfig.enable &&
session.mask.usePlugins &&
(allPlugins.length > 0 || !!process.env.NEXT_PUBLIC_ENABLE_RAG) &&
(allPlugins.length > 0 || isEnableRAG) &&
modelConfig.model.startsWith("gpt") &&
modelConfig.model != "gpt-4-vision-preview"
) {
console.log("[ToolAgent] start");
const pluginToolNames = allPlugins.map((m) => m.toolName);
if (!!process.env.NEXT_PUBLIC_ENABLE_RAG)
pluginToolNames.push("rag-search");
if (attachFiles && attachFiles.length > 0) {
console.log("crete rag store");
await api.llm.createRAGSore({
if (isEnableRAG) pluginToolNames.push("rag-search");
const agentCall = () => {
api.llm.toolAgentChat({
chatSessionId: session.id,
fileInfos: attachFiles,
});
}
api.llm.toolAgentChat({
chatSessionId: session.id,
messages: sendMessages,
config: { ...modelConfig, stream: true },
agentConfig: { ...pluginConfig, useTools: pluginToolNames },
onUpdate(message) {
botMessage.streaming = true;
if (message) {
botMessage.content = message;
}
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
},
onToolUpdate(toolName, toolInput) {
botMessage.streaming = true;
if (toolName && toolInput) {
botMessage.toolMessages!.push({
toolName,
toolInput,
messages: sendMessages,
config: { ...modelConfig, stream: true },
agentConfig: { ...pluginConfig, useTools: pluginToolNames },
onUpdate(message) {
botMessage.streaming = true;
if (message) {
botMessage.content = message;
}
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
}
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
},
onFinish(message) {
botMessage.streaming = false;
if (message) {
botMessage.content = message;
get().onNewMessage(botMessage);
}
ChatControllerPool.remove(session.id, botMessage.id);
},
onError(error) {
const isAborted = error.message.includes("aborted");
botMessage.content +=
"\n\n" +
prettyObject({
error: true,
message: error.message,
},
onToolUpdate(toolName, toolInput) {
botMessage.streaming = true;
if (toolName && toolInput) {
botMessage.toolMessages!.push({
toolName,
toolInput,
});
}
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
botMessage.streaming = false;
userMessage.isError = !isAborted;
botMessage.isError = !isAborted;
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
ChatControllerPool.remove(
session.id,
botMessage.id ?? messageIndex,
);
},
onFinish(message) {
botMessage.streaming = false;
if (message) {
botMessage.content = message;
get().onNewMessage(botMessage);
}
ChatControllerPool.remove(session.id, botMessage.id);
},
onError(error) {
const isAborted = error.message.includes("aborted");
botMessage.content +=
"\n\n" +
prettyObject({
error: true,
message: error.message,
});
botMessage.streaming = false;
userMessage.isError = !isAborted;
botMessage.isError = !isAborted;
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
ChatControllerPool.remove(
session.id,
botMessage.id ?? messageIndex,
);
console.error("[Chat] failed ", error);
},
onController(controller) {
// collect controller for stop/retry
ChatControllerPool.addController(
session.id,
botMessage.id ?? messageIndex,
controller,
);
},
});
console.error("[Chat] failed ", error);
},
onController(controller) {
// collect controller for stop/retry
ChatControllerPool.addController(
session.id,
botMessage.id ?? messageIndex,
controller,
);
},
});
};
if (attachFiles && attachFiles.length > 0) {
await api.llm
.createRAGStore({
chatSessionId: session.id,
fileInfos: attachFiles,
})
.then(() => {
console.log("[RAG]", "Vector db created");
agentCall();
});
} else {
agentCall();
}
} else {
if (modelConfig.model.startsWith("gemini")) {
api = new ClientApi(ModelProvider.GeminiPro);