This commit is contained in:
Hk-Gosuto
2024-04-07 18:00:21 +08:00
parent 7382ce48bb
commit b00e9f0c79
17 changed files with 307 additions and 122 deletions

View File

@@ -376,88 +376,96 @@ export const useChatStore = createPersistStore(
});
var api: ClientApi;
api = new ClientApi(ModelProvider.GPT);
const isEnableRAG = !!process.env.NEXT_PUBLIC_ENABLE_RAG;
if (
config.pluginConfig.enable &&
session.mask.usePlugins &&
(allPlugins.length > 0 || !!process.env.NEXT_PUBLIC_ENABLE_RAG) &&
(allPlugins.length > 0 || isEnableRAG) &&
modelConfig.model.startsWith("gpt") &&
modelConfig.model != "gpt-4-vision-preview"
) {
console.log("[ToolAgent] start");
const pluginToolNames = allPlugins.map((m) => m.toolName);
if (!!process.env.NEXT_PUBLIC_ENABLE_RAG)
pluginToolNames.push("rag-search");
if (attachFiles && attachFiles.length > 0) {
console.log("crete rag store");
await api.llm.createRAGSore({
if (isEnableRAG) pluginToolNames.push("rag-search");
const agentCall = () => {
api.llm.toolAgentChat({
chatSessionId: session.id,
fileInfos: attachFiles,
});
}
api.llm.toolAgentChat({
chatSessionId: session.id,
messages: sendMessages,
config: { ...modelConfig, stream: true },
agentConfig: { ...pluginConfig, useTools: pluginToolNames },
onUpdate(message) {
botMessage.streaming = true;
if (message) {
botMessage.content = message;
}
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
},
onToolUpdate(toolName, toolInput) {
botMessage.streaming = true;
if (toolName && toolInput) {
botMessage.toolMessages!.push({
toolName,
toolInput,
messages: sendMessages,
config: { ...modelConfig, stream: true },
agentConfig: { ...pluginConfig, useTools: pluginToolNames },
onUpdate(message) {
botMessage.streaming = true;
if (message) {
botMessage.content = message;
}
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
}
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
},
onFinish(message) {
botMessage.streaming = false;
if (message) {
botMessage.content = message;
get().onNewMessage(botMessage);
}
ChatControllerPool.remove(session.id, botMessage.id);
},
onError(error) {
const isAborted = error.message.includes("aborted");
botMessage.content +=
"\n\n" +
prettyObject({
error: true,
message: error.message,
},
onToolUpdate(toolName, toolInput) {
botMessage.streaming = true;
if (toolName && toolInput) {
botMessage.toolMessages!.push({
toolName,
toolInput,
});
}
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
botMessage.streaming = false;
userMessage.isError = !isAborted;
botMessage.isError = !isAborted;
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
ChatControllerPool.remove(
session.id,
botMessage.id ?? messageIndex,
);
},
onFinish(message) {
botMessage.streaming = false;
if (message) {
botMessage.content = message;
get().onNewMessage(botMessage);
}
ChatControllerPool.remove(session.id, botMessage.id);
},
onError(error) {
const isAborted = error.message.includes("aborted");
botMessage.content +=
"\n\n" +
prettyObject({
error: true,
message: error.message,
});
botMessage.streaming = false;
userMessage.isError = !isAborted;
botMessage.isError = !isAborted;
get().updateCurrentSession((session) => {
session.messages = session.messages.concat();
});
ChatControllerPool.remove(
session.id,
botMessage.id ?? messageIndex,
);
console.error("[Chat] failed ", error);
},
onController(controller) {
// collect controller for stop/retry
ChatControllerPool.addController(
session.id,
botMessage.id ?? messageIndex,
controller,
);
},
});
console.error("[Chat] failed ", error);
},
onController(controller) {
// collect controller for stop/retry
ChatControllerPool.addController(
session.id,
botMessage.id ?? messageIndex,
controller,
);
},
});
};
if (attachFiles && attachFiles.length > 0) {
await api.llm
.createRAGStore({
chatSessionId: session.id,
fileInfos: attachFiles,
})
.then(() => {
console.log("[RAG]", "Vector db created");
agentCall();
});
} else {
agentCall();
}
} else {
if (modelConfig.model.startsWith("gemini")) {
api = new ClientApi(ModelProvider.GeminiPro);