feat: openai realtime merge

Hk-Gosuto
2024-12-23 15:48:21 +08:00
parent c6156a8d8a
commit 21bf685d12
30 changed files with 2418 additions and 833 deletions

app/store/chat.ts

@@ -1,13 +1,15 @@
-import {
-  trimTopic,
-  getMessageTextContent,
-  isFunctionCallModel,
-} from "../utils";
+import { getMessageTextContent, trimTopic } from "../utils";
 
-import Locale, { getLang } from "../locales";
+import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
+import { nanoid } from "nanoid";
+import type {
+  ClientApi,
+  MultimodalContent,
+  RequestMessage,
+} from "../client/api";
+import { getClientApi } from "../client/api";
+import { ChatControllerPool } from "../client/controller";
 import { showToast } from "../components/ui-lib";
-import { ModelConfig, ModelType, useAppConfig } from "./config";
-import { createEmptyMask, Mask } from "./mask";
 import {
   DEFAULT_INPUT_TEMPLATE,
   DEFAULT_MODELS,
@@ -16,29 +18,24 @@ import {
   StoreKey,
   SUMMARIZE_MODEL,
   GEMINI_SUMMARIZE_MODEL,
-  MYFILES_BROWSER_TOOLS_SYSTEM_PROMPT,
+  ServiceProvider,
 } from "../constant";
+import Locale, { getLang } from "../locales";
 import { isDalle3, safeLocalStorage } from "../utils";
-import { getClientApi } from "../client/api";
-import type {
-  ClientApi,
-  RequestMessage,
-  MultimodalContent,
-} from "../client/api";
-import { ChatControllerPool } from "../client/controller";
 import { prettyObject } from "../utils/format";
+import { createPersistStore } from "../utils/store";
 import { estimateTokenLength } from "../utils/token";
-import { nanoid } from "nanoid";
-import { Plugin, usePluginStore } from "../store/plugin";
+import { ModelConfig, ModelType, useAppConfig } from "./config";
+import { useAccessStore } from "./access";
+import { collectModelsWithDefaultModel } from "../utils/model";
+import { createEmptyMask, Mask } from "./mask";
+import { FileInfo } from "../client/platforms/utils";
+import { usePluginStore } from "./plugin";
 
 export interface ChatToolMessage {
   toolName: string;
   toolInput?: string;
 }
 
-import { createPersistStore } from "../utils/store";
-import { FileInfo } from "../client/platforms/utils";
-import { collectModelsWithDefaultModel } from "../utils/model";
-import { useAccessStore } from "./access";
 const localStorage = safeLocalStorage();
@@ -52,6 +49,7 @@ export type ChatMessageTool = {
   };
   content?: string;
   isError?: boolean;
+  errorMsg?: string;
 };
 
 export type ChatMessage = RequestMessage & {
@@ -61,6 +59,8 @@ export type ChatMessage = RequestMessage & {
   isError?: boolean;
   id: string;
   model?: ModelType;
+  tools?: ChatMessageTool[];
+  audio_url?: string;
 };
 
 export function createMessage(override: Partial<ChatMessage>): ChatMessage {
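
Note: tools and audio_url are the realtime-merge additions to ChatMessage. tools records tool calls made while a reply streamed; audio_url stores the voice output of a realtime session. A rough sketch of how a voice reply might look (only the field names come from the diff; the values are invented):

    // Illustrative sketch; field names from the diff, values invented.
    const voiceReply = {
      id: "example-id",
      role: "assistant" as const,
      content: "Here is the spoken answer, transcribed.",
      // playback source captured from the realtime audio channel
      audio_url: "blob:local-audio-object",
      tools: [] as { id: string; content?: string; isError?: boolean }[],
    };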
@@ -122,9 +122,12 @@ function createEmptySession(): ChatSession {
   };
 }
 
-function getSummarizeModel(currentModel: string) {
+function getSummarizeModel(
+  currentModel: string,
+  providerName: string,
+): string[] {
   // if it is using gpt-* models, force to use 4o-mini to summarize
-  if (currentModel.startsWith("gpt")) {
+  if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
     const configStore = useAppConfig.getState();
     const accessStore = useAccessStore.getState();
     const allModel = collectModelsWithDefaultModel(
@@ -135,12 +138,17 @@ function getSummarizeModel(currentModel: string) {
     const summarizeModel = allModel.find(
       (m) => m.name === SUMMARIZE_MODEL && m.available,
     );
-    return summarizeModel?.name ?? currentModel;
+    if (summarizeModel) {
+      return [
+        summarizeModel.name,
+        summarizeModel.provider?.providerName as string,
+      ];
+    }
   }
   if (currentModel.startsWith("gemini")) {
-    return GEMINI_SUMMARIZE_MODEL;
+    return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
   }
-  return currentModel;
+  return [currentModel, providerName];
 }
 
 function countMessages(msgs: ChatMessage[]) {
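
Note: getSummarizeModel now returns a [model, providerName] pair instead of a bare model name, so the summarize request can be routed to the right provider. A standalone sketch of the new contract, with the model lookup stubbed out (constants and provider strings here are illustrative):

    // Standalone sketch of the tuple contract; constants stubbed for illustration.
    const SUMMARIZE_MODEL = "gpt-4o-mini";
    const GEMINI_SUMMARIZE_MODEL = "gemini-pro";

    function pickSummarizeModel(
      currentModel: string,
      providerName: string,
    ): [string, string] {
      if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
        // the real code looks the model up in the configured list and reads its provider
        return [SUMMARIZE_MODEL, "OpenAI"];
      }
      if (currentModel.startsWith("gemini")) {
        return [GEMINI_SUMMARIZE_MODEL, "Google"];
      }
      // fall through: summarize with the chat model itself
      return [currentModel, providerName];
    }

    // pickSummarizeModel("chatgpt-4o-latest", "OpenAI") => ["gpt-4o-mini", "OpenAI"]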
@@ -197,6 +205,7 @@ function fillTemplateWith(input: string, modelConfig: ModelConfig) {
 const DEFAULT_CHAT_STATE = {
   sessions: [createEmptySession()],
   currentSessionIndex: 0,
+  lastInput: "",
 };
 
 export const useChatStore = createPersistStore(
@@ -210,6 +219,28 @@ export const useChatStore = createPersistStore(
   }
 
   const methods = {
+    forkSession() {
+      // get the current session
+      const currentSession = get().currentSession();
+      if (!currentSession) return;
+
+      const newSession = createEmptySession();
+
+      newSession.topic = currentSession.topic;
+      newSession.messages = [...currentSession.messages];
+      newSession.mask = {
+        ...currentSession.mask,
+        modelConfig: {
+          ...currentSession.mask.modelConfig,
+        },
+      };
+
+      set((state) => ({
+        currentSessionIndex: 0,
+        sessions: [newSession, ...state.sessions],
+      }));
+    },
+
     clearSessions() {
       set(() => ({
         sessions: [createEmptySession()],
@@ -335,13 +366,13 @@
       return session;
     },
 
-    onNewMessage(message: ChatMessage) {
-      get().updateCurrentSession((session) => {
+    onNewMessage(message: ChatMessage, targetSession: ChatSession) {
+      get().updateTargetSession(targetSession, (session) => {
         session.messages = session.messages.concat();
         session.lastUpdate = Date.now();
       });
-      get().updateStat(message);
-      get().summarizeSession();
+      get().updateStat(message, targetSession);
+      get().summarizeSession(false, targetSession);
     },
 
     async onUserInput(
@@ -359,44 +390,39 @@
       if (attachImages && attachImages.length > 0) {
         mContent = [
-          {
-            type: "text",
-            text: userContent,
-          },
+          ...(userContent
+            ? [{ type: "text" as const, text: userContent }]
+            : []),
+          ...attachImages.map((url) => ({
+            type: "image_url" as const,
+            image_url: { url },
+          })),
         ];
-        mContent = mContent.concat(
-          attachImages.map((url) => {
-            return {
-              type: "image_url",
-              image_url: {
-                url: url,
-              },
-            };
-          }),
-        );
       }
 
       // add file link
       if (attachFiles && attachFiles.length > 0) {
         mContent += ` [${attachFiles[0].originalFilename}](${attachFiles[0].filePath})`;
       }
 
       let userMessage: ChatMessage = createMessage({
         role: "user",
         content: mContent,
         fileInfos: attachFiles,
       });
 
       const botMessage: ChatMessage = createMessage({
         role: "assistant",
         streaming: true,
         model: modelConfig.model,
         toolMessages: [],
       });
 
-      const api: ClientApi = getClientApi(modelConfig.providerName);
       const isEnableRAG =
         session.attachFiles && session.attachFiles.length > 0;
 
       // get recent messages
       const recentMessages = get().getMessagesWithMemory();
       const sendMessages = recentMessages.concat(userMessage);
-      const messageIndex = get().currentSession().messages.length + 1;
+      const messageIndex = session.messages.length + 1;
 
       const config = useAppConfig.getState();
       const pluginConfig = useAppConfig.getState().pluginConfig;
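
Note: the rewritten block builds the multimodal payload in a single array literal. The text part is included only when the user actually typed something, and the image parts are spread in behind it. A self-contained sketch of the same shape (types simplified from the diff):

    // Simplified sketch of the new payload construction.
    type MultimodalPart =
      | { type: "text"; text: string }
      | { type: "image_url"; image_url: { url: string } };

    function buildContent(userText: string, imageUrls: string[]): MultimodalPart[] {
      return [
        // skip the text part entirely when the input box was empty
        ...(userText ? [{ type: "text" as const, text: userText }] : []),
        ...imageUrls.map((url) => ({
          type: "image_url" as const,
          image_url: { url },
        })),
      ];
    }

    // buildContent("", ["https://example.test/cat.png"]) => only an image part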
@@ -410,148 +436,86 @@
           m.enable,
         );
 
       // save user's and bot's message
-      get().updateCurrentSession((session) => {
+      get().updateTargetSession(session, (session) => {
         const savedUserMessage = {
           ...userMessage,
           content: mContent,
         };
-        session.messages.push(savedUserMessage);
-        session.messages.push(botMessage);
+        session.messages = session.messages.concat([
+          savedUserMessage,
+          botMessage,
+        ]);
       });
 
-      if (
-        config.pluginConfig.enable &&
-        session.mask.usePlugins &&
-        (allPlugins.length > 0 || isEnableRAG) &&
-        isFunctionCallModel(modelConfig.model)
-      ) {
-        console.log("[ToolAgent] start");
-        let pluginToolNames = allPlugins.map((m) => m.toolName);
-        if (isEnableRAG) {
-          // other plugins will affect rag
-          // clear existing plugins here
-          pluginToolNames = [];
-          pluginToolNames.push("myfiles_browser");
-        }
-        const agentCall = () => {
-          api.llm.toolAgentChat({
-            chatSessionId: session.id,
-            messages: sendMessages,
-            config: { ...modelConfig, stream: true },
-            agentConfig: { ...pluginConfig, useTools: pluginToolNames },
-            onUpdate(message) {
-              botMessage.streaming = true;
-              if (message) {
-                botMessage.content = message;
-              }
-              get().updateCurrentSession((session) => {
-                session.messages = session.messages.concat();
-              });
-            },
-            onToolUpdate(toolName, toolInput) {
-              botMessage.streaming = true;
-              if (toolName && toolInput) {
-                botMessage.toolMessages!.push({
-                  toolName,
-                  toolInput,
-                });
-              }
-              get().updateCurrentSession((session) => {
-                session.messages = session.messages.concat();
-              });
-            },
-            onFinish(message) {
-              botMessage.streaming = false;
-              if (message) {
-                botMessage.content = message;
-                get().onNewMessage(botMessage);
-              }
-              ChatControllerPool.remove(session.id, botMessage.id);
-            },
-            onError(error) {
-              const isAborted = error.message.includes("aborted");
-              botMessage.content +=
-                "\n\n" +
-                prettyObject({
-                  error: true,
-                  message: error.message,
-                });
-              botMessage.streaming = false;
-              userMessage.isError = !isAborted;
-              botMessage.isError = !isAborted;
-              get().updateCurrentSession((session) => {
-                session.messages = session.messages.concat();
-              });
-              ChatControllerPool.remove(
-                session.id,
-                botMessage.id ?? messageIndex,
-              );
-              console.error("[Chat] failed ", error);
-            },
-            onController(controller) {
-              // collect controller for stop/retry
-              ChatControllerPool.addController(
-                session.id,
-                botMessage.id ?? messageIndex,
-                controller,
-              );
-            },
-          });
-        };
-        agentCall();
-      } else {
-        // make request
-        api.llm.chat({
-          messages: sendMessages,
-          config: { ...modelConfig, stream: true },
-          onUpdate(message) {
-            botMessage.streaming = true;
-            if (message) {
-              botMessage.content = message;
-            }
-            get().updateCurrentSession((session) => {
-              session.messages = session.messages.concat();
-            });
-          },
-          onFinish(message) {
-            botMessage.streaming = false;
-            if (message) {
-              botMessage.content = message;
-              get().onNewMessage(botMessage);
-            }
-            ChatControllerPool.remove(session.id, botMessage.id);
-          },
-          onError(error) {
-            const isAborted = error.message.includes("aborted");
-            botMessage.content +=
-              "\n\n" +
-              prettyObject({
-                error: true,
-                message: error.message,
-              });
-            botMessage.streaming = false;
-            userMessage.isError = !isAborted;
-            botMessage.isError = !isAborted;
-            get().updateCurrentSession((session) => {
-              session.messages = session.messages.concat();
-            });
-            ChatControllerPool.remove(
-              session.id,
-              botMessage.id ?? messageIndex,
-            );
-            console.error("[Chat] failed ", error);
-          },
-          onController(controller) {
-            // collect controller for stop/retry
-            ChatControllerPool.addController(
-              session.id,
-              botMessage.id ?? messageIndex,
-              controller,
-            );
-          },
-        });
-      }
+      const api: ClientApi = getClientApi(modelConfig.providerName);
+      // make request
+      api.llm.chat({
+        messages: sendMessages,
+        config: { ...modelConfig, stream: true },
+        onUpdate(message) {
+          botMessage.streaming = true;
+          if (message) {
+            botMessage.content = message;
+          }
+          get().updateTargetSession(session, (session) => {
+            session.messages = session.messages.concat();
+          });
+        },
+        onFinish(message) {
+          botMessage.streaming = false;
+          if (message) {
+            botMessage.content = message;
+            botMessage.date = new Date().toLocaleString();
+            get().onNewMessage(botMessage, session);
+          }
+          ChatControllerPool.remove(session.id, botMessage.id);
+        },
+        onBeforeTool(tool: ChatMessageTool) {
+          (botMessage.tools = botMessage?.tools || []).push(tool);
+          get().updateTargetSession(session, (session) => {
+            session.messages = session.messages.concat();
+          });
+        },
+        onAfterTool(tool: ChatMessageTool) {
+          botMessage?.tools?.forEach((t, i, tools) => {
+            if (tool.id == t.id) {
+              tools[i] = { ...tool };
+            }
+          });
+          get().updateTargetSession(session, (session) => {
+            session.messages = session.messages.concat();
+          });
+        },
+        onError(error) {
+          const isAborted = error.message?.includes?.("aborted");
+          botMessage.content +=
+            "\n\n" +
+            prettyObject({
+              error: true,
+              message: error.message,
+            });
+          botMessage.streaming = false;
+          userMessage.isError = !isAborted;
+          botMessage.isError = !isAborted;
+          get().updateTargetSession(session, (session) => {
+            session.messages = session.messages.concat();
+          });
+          ChatControllerPool.remove(
+            session.id,
+            botMessage.id ?? messageIndex,
+          );
+          console.error("[Chat] failed ", error);
+        },
+        onController(controller) {
+          // collect controller for stop/retry
+          ChatControllerPool.addController(
+            session.id,
+            botMessage.id ?? messageIndex,
+            controller,
+          );
+        },
+      });
     },
 
     getMemoryPrompt() {
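
Note: the fork's toolAgentChat branch is gone from onUserInput; every request now goes through api.llm.chat, and tool activity is surfaced through onBeforeTool/onAfterTool instead of onToolUpdate. Each callback also writes through updateTargetSession with the session captured at send time, so a late-arriving chunk cannot land in whichever session the user switched to. A minimal sketch of that callback contract (interface simplified; only the callback names follow the diff):

    // Minimal sketch of the streaming callback contract used above.
    interface ToolCall { id: string; content?: string; isError?: boolean }

    interface ChatCallbacks {
      onUpdate(partial: string): void;    // streamed text so far
      onFinish(full: string): void;       // final text; triggers onNewMessage
      onBeforeTool(tool: ToolCall): void; // a tool call is about to run
      onAfterTool(tool: ToolCall): void;  // same tool, now carrying its result
      onError(err: Error): void;          // abort or failure
    }

    // A fake driver showing the call order a client implementation would produce.
    function fakeStream(cb: ChatCallbacks) {
      cb.onUpdate("Thinking");
      cb.onBeforeTool({ id: "t1" });
      cb.onAfterTool({ id: "t1", content: "42" });
      cb.onUpdate("Thinking... the answer is 42.");
      cb.onFinish("The answer is 42.");
    }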
@@ -579,26 +543,17 @@
       // system prompts, to get close to OpenAI Web ChatGPT
       const shouldInjectSystemPrompts =
         modelConfig.enableInjectSystemPrompts &&
-        session.mask.modelConfig.model.startsWith("gpt-");
+        (session.mask.modelConfig.model.startsWith("gpt-") ||
+          session.mask.modelConfig.model.startsWith("chatgpt-"));
 
       var systemPrompts: ChatMessage[] = [];
-      var template = DEFAULT_SYSTEM_TEMPLATE;
-      if (session.attachFiles && session.attachFiles.length > 0) {
-        template += MYFILES_BROWSER_TOOLS_SYSTEM_PROMPT;
-        session.attachFiles.forEach((file) => {
-          template += `filename: \`${file.originalFilename}\`
-partialDocument: \`\`\`
-${file.partial}
-\`\`\``;
-        });
-      }
       systemPrompts = shouldInjectSystemPrompts
         ? [
             createMessage({
               role: "system",
               content: fillTemplateWith("", {
                 ...modelConfig,
-                template: template,
+                template: DEFAULT_SYSTEM_TEMPLATE,
               }),
             }),
           ]
@@ -674,23 +629,33 @@ ${file.partial}
       set(() => ({ sessions }));
     },
 
-    resetSession() {
-      get().updateCurrentSession((session) => {
+    resetSession(session: ChatSession) {
+      get().updateTargetSession(session, (session) => {
         session.messages = [];
         session.memoryPrompt = "";
       });
     },
 
-    summarizeSession() {
+    summarizeSession(
+      refreshTitle: boolean = false,
+      targetSession: ChatSession,
+    ) {
       const config = useAppConfig.getState();
-      const session = get().currentSession();
+      const session = targetSession;
       const modelConfig = session.mask.modelConfig;
 
       // skip summarize when using dalle3?
       if (isDalle3(modelConfig.model)) {
         return;
       }
 
-      const api: ClientApi = getClientApi(modelConfig.providerName);
+      // if not config compressModel, then using getSummarizeModel
+      const [model, providerName] = modelConfig.compressModel
+        ? [modelConfig.compressModel, modelConfig.compressProviderName]
+        : getSummarizeModel(
+            session.mask.modelConfig.model,
+            session.mask.modelConfig.providerName,
+          );
+
+      const api: ClientApi = getClientApi(providerName as ServiceProvider);
 
       // remove error messages if any
       const messages = session.messages;
@@ -698,29 +663,43 @@ ${file.partial}
       // should summarize topic after chating more than 50 words
       const SUMMARIZE_MIN_LEN = 50;
       if (
-        !process.env.NEXT_PUBLIC_DISABLE_AUTOGENERATETITLE &&
-        config.enableAutoGenerateTitle &&
-        session.topic === DEFAULT_TOPIC &&
-        countMessages(messages) >= SUMMARIZE_MIN_LEN
+        (!process.env.NEXT_PUBLIC_DISABLE_AUTOGENERATETITLE &&
+          config.enableAutoGenerateTitle &&
+          session.topic === DEFAULT_TOPIC &&
+          countMessages(messages) >= SUMMARIZE_MIN_LEN) ||
+        refreshTitle
       ) {
-        const topicMessages = messages.concat(
-          createMessage({
-            role: "user",
-            content: Locale.Store.Prompt.Topic,
-          }),
-        );
+        const startIndex = Math.max(
+          0,
+          messages.length - modelConfig.historyMessageCount,
+        );
+        const topicMessages = messages
+          .slice(
+            startIndex < messages.length ? startIndex : messages.length - 1,
+            messages.length,
+          )
+          .concat(
+            createMessage({
+              role: "user",
+              content: Locale.Store.Prompt.Topic,
+            }),
+          );
         api.llm.chat({
           messages: topicMessages,
           config: {
-            model: getSummarizeModel(session.mask.modelConfig.model),
+            model,
             stream: false,
+            providerName,
           },
-          onFinish(message) {
-            get().updateCurrentSession(
-              (session) =>
-                (session.topic =
-                  message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
-            );
+          onFinish(message, responseRes) {
+            if (responseRes?.status === 200) {
+              get().updateTargetSession(
+                session,
+                (session) =>
+                  (session.topic =
+                    message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
+              );
+            }
           },
         });
       }
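
Note: topic generation no longer feeds the whole history to the model. It takes at most the last historyMessageCount messages (falling back to the final message when the window would be empty) and can also be forced through the new refreshTitle flag. The windowing arithmetic, extracted as a sketch:

    // Sketch of the slice window used for title generation.
    function topicWindow<T>(messages: T[], historyMessageCount: number): T[] {
      const startIndex = Math.max(0, messages.length - historyMessageCount);
      return messages.slice(
        startIndex < messages.length ? startIndex : messages.length - 1,
        messages.length,
      );
    }

    // 10 messages, window of 4 => only the last 4 are summarized into a topic
    // topicWindow([1,2,3,4,5,6,7,8,9,10], 4) => [7, 8, 9, 10]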
@@ -734,7 +713,7 @@ ${file.partial}
       const historyMsgLength = countMessages(toBeSummarizedMsgs);
 
-      if (historyMsgLength > modelConfig?.max_tokens ?? 4000) {
+      if (historyMsgLength > (modelConfig?.max_tokens || 4000)) {
         const n = toBeSummarizedMsgs.length;
         toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
           Math.max(0, n - modelConfig.historyMessageCount),
@@ -775,17 +754,20 @@ ${file.partial}
         config: {
           ...modelcfg,
           stream: true,
-          model: getSummarizeModel(session.mask.modelConfig.model),
+          model,
+          providerName,
         },
         onUpdate(message) {
          session.memoryPrompt = message;
         },
-        onFinish(message) {
-          // console.log("[Memory] ", message);
-          get().updateCurrentSession((session) => {
-            session.lastSummarizeIndex = lastSummarizeIndex;
-            session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
-          });
+        onFinish(message, responseRes) {
+          if (responseRes?.status === 200) {
+            console.log("[Memory] ", message);
+            get().updateTargetSession(session, (session) => {
+              session.lastSummarizeIndex = lastSummarizeIndex;
+              session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
+            });
+          }
         },
         onError(err) {
           console.error("[Summarize] ", err);
@@ -794,31 +776,39 @@ ${file.partial}
       }
     },
 
-    updateStat(message: ChatMessage) {
-      get().updateCurrentSession((session) => {
+    updateStat(message: ChatMessage, session: ChatSession) {
+      get().updateTargetSession(session, (session) => {
         session.stat.charCount += message.content.length;
         // TODO: should update chat count and word count
       });
     },
 
-    updateCurrentSession(updater: (session: ChatSession) => void) {
+    updateTargetSession(
+      targetSession: ChatSession,
+      updater: (session: ChatSession) => void,
+    ) {
       const sessions = get().sessions;
-      const index = get().currentSessionIndex;
+      const index = sessions.findIndex((s) => s.id === targetSession.id);
+      if (index < 0) return;
       updater(sessions[index]);
       set(() => ({ sessions }));
     },
 
-    clearAllData() {
+    async clearAllData() {
+      await indexedDBStorage.clear();
       localStorage.clear();
       location.reload();
     },
+
+    setLastInput(lastInput: string) {
+      set({
+        lastInput,
+      });
+    },
   };
 
   return methods;
 },
 {
   name: StoreKey.Chat,
-  version: 3.1,
+  version: 3.3,
 
   migrate(persistedState, version) {
     const state = persistedState as any;
     const newState = JSON.parse(
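
Note: updateTargetSession is the session-safe replacement for updateCurrentSession. It resolves the session by id at write time rather than trusting currentSessionIndex, which matters once async callbacks can outlive a session switch. The lookup pattern, reduced to a sketch:

    // Sketch of the id-based lookup replacing the index-based one.
    interface Session { id: string; topic: string }

    function updateTarget(
      sessions: Session[],
      target: Session,
      updater: (s: Session) => void,
    ): void {
      const index = sessions.findIndex((s) => s.id === target.id);
      if (index < 0) return; // session was deleted meanwhile; drop the update
      updater(sessions[index]);
    }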
@@ -865,6 +855,24 @@ ${file.partial}
         });
       }
 
+      // add default summarize model for every session
+      if (version < 3.2) {
+        newState.sessions.forEach((s) => {
+          const config = useAppConfig.getState();
+          s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
+          s.mask.modelConfig.compressProviderName =
+            config.modelConfig.compressProviderName;
+        });
+      }
+
+      // revert default summarize model for every session
+      if (version < 3.3) {
+        newState.sessions.forEach((s) => {
+          const config = useAppConfig.getState();
+          s.mask.modelConfig.compressModel = "";
+          s.mask.modelConfig.compressProviderName = "";
+        });
+      }
+
       return newState as any;
     },
   },
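
Note: the store version jumps 3.1 -> 3.3 and ships two stacked migrations: 3.2 once copied the global compressModel defaults into every session mask, and 3.3 reverts that by blanking the per-session values again (an empty compressModel falls back to getSummarizeModel at runtime). The cumulative, version-gated migration pattern, sketched (the seeded value is an assumption):

    // Sketch of stacked, version-gated migrations as used by the persist store.
    type Persisted = { sessions: { compressModel: string }[] };

    function migrate(state: Persisted, version: number): Persisted {
      if (version < 3.2) {
        // historical step: seed each session with a global default (assumed value)
        state.sessions.forEach((s) => (s.compressModel = "gpt-4o-mini"));
      }
      if (version < 3.3) {
        // newer step: undo the seeding; empty means "pick at call time"
        state.sessions.forEach((s) => (s.compressModel = ""));
      }
      return state;
    }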

app/store/config.ts

@@ -17,6 +17,7 @@ import {
   ServiceProvider,
 } from "../constant";
 import { createPersistStore } from "../utils/store";
+import type { Voice } from "rt-client";
 
 export type ModelType = (typeof DEFAULT_MODELS)[number]["name"];
 export type TTSModelType = (typeof DEFAULT_TTS_MODELS)[number];
@@ -105,6 +106,19 @@ export const DEFAULT_CONFIG = {
     enable: false,
     engine: DEFAULT_STT_ENGINE,
   },
+
+  realtimeConfig: {
+    enable: false,
+    provider: "OpenAI" as ServiceProvider,
+    model: "gpt-4o-realtime-preview-2024-10-01",
+    apiKey: "",
+    azure: {
+      endpoint: "",
+      deployment: "",
+    },
+    temperature: 0.9,
+    voice: "alloy" as Voice,
+  },
 };
 
 export type ChatConfig = typeof DEFAULT_CONFIG;
@@ -113,6 +127,7 @@ export type ModelConfig = ChatConfig["modelConfig"];
 export type PluginConfig = ChatConfig["pluginConfig"];
 export type TTSConfig = ChatConfig["ttsConfig"];
 export type STTConfig = ChatConfig["sttConfig"];
+export type RealtimeConfig = ChatConfig["realtimeConfig"];
 
 export function limitNumber(
   x: number,
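
Note: realtimeConfig is the core of this merge. Defaults for the realtime provider, model, voice (typed by rt-client's Voice), and an optional Azure endpoint/deployment now live in DEFAULT_CONFIG, and RealtimeConfig is derived from it. A sketch of consuming it; the shape is copied from the diff, but the Azure URL format is an assumption:

    // Illustrative consumer of the new realtime defaults (shape from the diff).
    type RealtimeSketch = {
      enable: boolean;
      provider: "OpenAI" | "Azure";
      model: string;
      apiKey: string;
      azure: { endpoint: string; deployment: string };
      temperature: number;
      voice: string; // "alloy" etc., constrained by rt-client's Voice type
    };

    function realtimeEndpoint(cfg: RealtimeSketch): string | null {
      if (!cfg.enable) return null;
      return cfg.provider === "Azure"
        ? `${cfg.azure.endpoint}/openai/realtime?deployment=${cfg.azure.deployment}` // assumed URL shape
        : "wss://api.openai.com/v1/realtime?model=" + cfg.model;
    }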

app/store/plugin.ts

@@ -17,6 +17,14 @@ export type Plugin = {
   builtin: boolean;
   enable: boolean;
   onlyNodeRuntime: boolean;
+
+  title: string;
+  version: string;
+  content: string;
+  authType?: string;
+  authLocation?: string;
+  authHeader?: string;
+  authToken?: string;
 };
 
 export const DEFAULT_PLUGIN_STATE = {
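
Note: the Plugin type picks up upstream's OpenAPI-plugin metadata (title, version, content) and optional auth settings. One plausible way the auth fields combine into a request header; this is illustrative, the real consumer lives elsewhere in the codebase:

    // Illustrative use of the new auth fields; the real wiring is elsewhere.
    type PluginAuth = {
      authType?: string;     // e.g. "basic" | "bearer" | "custom"
      authLocation?: string; // e.g. "header" | "query" | "body"
      authHeader?: string;   // header name, defaulting to Authorization
      authToken?: string;
    };

    function authHeaderFor(p: PluginAuth): Record<string, string> {
      if (p.authLocation !== "header" || !p.authToken) return {};
      const name = p.authHeader || "Authorization";
      const value =
        p.authType === "basic" ? `Basic ${p.authToken}`
        : p.authType === "bearer" ? `Bearer ${p.authToken}`
        : p.authToken;
      return { [name]: value };
    }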