diff --git a/app/client/api.ts b/app/client/api.ts index cecc453ba..671cb1c48 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -59,7 +59,7 @@ export interface ChatOptions { config: LLMConfig; onUpdate?: (message: string, chunk: string) => void; - onFinish: (message: string) => void; + onFinish: (message: string, finishedReason?: string) => void; onError?: (err: Error) => void; onController?: (controller: AbortController) => void; onBeforeTool?: (tool: ChatMessageTool) => void; diff --git a/app/client/controller.ts b/app/client/controller.ts index a2e00173d..ac5bac7a1 100644 --- a/app/client/controller.ts +++ b/app/client/controller.ts @@ -26,6 +26,10 @@ export const ChatControllerPool = { return Object.values(this.controllers).length > 0; }, + getPendingMessageId() { + return Object.keys(this.controllers).map((v) => v.split(",").at(-1)); + }, + remove(sessionId: string, messageId: string) { const key = this.key(sessionId, messageId); delete this.controllers[key]; diff --git a/app/client/platforms/anthropic.ts b/app/client/platforms/anthropic.ts index 7dd39c9cd..eb3ab8f48 100644 --- a/app/client/platforms/anthropic.ts +++ b/app/client/platforms/anthropic.ts @@ -262,7 +262,7 @@ export class ClaudeApi implements LLMApi { runTools[index]["function"]["arguments"] += chunkJson?.delta?.partial_json; } - return chunkJson?.delta?.text; + return { delta: chunkJson?.delta?.text }; }, // processToolMessage, include tool_calls message and tool call results ( diff --git a/app/client/platforms/moonshot.ts b/app/client/platforms/moonshot.ts index cd10d2f6c..3b5590a9f 100644 --- a/app/client/platforms/moonshot.ts +++ b/app/client/platforms/moonshot.ts @@ -163,7 +163,7 @@ export class MoonshotApi implements LLMApi { runTools[index]["function"]["arguments"] += args; } } - return choices[0]?.delta?.content; + return { delta: choices[0]?.delta?.content }; }, // processToolMessage, include tool_calls message and tool call results ( diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 664ff872b..aff9a2ca2 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -266,6 +266,7 @@ export class ChatGPTApi implements LLMApi { content: string; tool_calls: ChatMessageTool[]; }; + finish_reason?: string; }>; const tool_calls = choices[0]?.delta?.tool_calls; if (tool_calls?.length > 0) { @@ -286,7 +287,10 @@ export class ChatGPTApi implements LLMApi { runTools[index]["function"]["arguments"] += args; } } - return choices[0]?.delta?.content; + return { + delta: choices[0]?.delta?.content, + finishReason: choices[0]?.finish_reason, + }; }, // processToolMessage, include tool_calls message and tool call results ( diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 17f8d3a34..b2b35c9ce 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -9,6 +9,7 @@ import React, { RefObject, } from "react"; +import ContinueIcon from "../icons/continue.svg"; import SendWhiteIcon from "../icons/send-white.svg"; import BrainIcon from "../icons/brain.svg"; import RenameIcon from "../icons/rename.svg"; @@ -461,7 +462,16 @@ export function ChatActions(props: { // stop all responses const couldStop = ChatControllerPool.hasPending(); - const stopAll = () => ChatControllerPool.stopAll(); + const stopAll = () => { + const stopList = ChatControllerPool.getPendingMessageId(); + ChatControllerPool.stopAll(); + chatStore.updateCurrentSession( + (session) => + (session.messages = session.messages.map((v) => + stopList.includes(v.id) ? { ...v, finishedReason: "aborted" } : v, + )), + ); + }; // switch model const currentModel = chatStore.currentSession().mask.modelConfig.model; @@ -1045,6 +1055,12 @@ function _Chat() { // stop response const onUserStop = (messageId: string) => { ChatControllerPool.stop(session.id, messageId); + chatStore.updateCurrentSession( + (session) => + (session.messages = session.messages.map((v) => + v.id === messageId ? { ...v, finishedReason: "aborted" } : v, + )), + ); }; useEffect(() => { @@ -1171,6 +1187,18 @@ function _Chat() { inputRef.current?.focus(); }; + const onContinue = (messageID: string) => { + chatStore.updateCurrentSession( + (session) => + (session.messages = session.messages.map((v) => + v.id === messageID ? { ...v, streaming: true } : v, + )), + ); + chatStore + .onContinueBotMessage(messageID) + .finally(() => setIsLoading(false)); + }; + const onPinMessage = (message: ChatMessage) => { chatStore.updateCurrentSession((session) => session.mask.context.push(message), @@ -1724,6 +1752,15 @@ function _Chat() { ) } /> + {["length", "aborted"].includes( + message.finishedReason ?? "", + ) ? ( + } + onClick={() => onContinue(message.id)} + /> + ) : null} )} diff --git a/app/icons/continue.svg b/app/icons/continue.svg new file mode 100644 index 000000000..d88f263f6 --- /dev/null +++ b/app/icons/continue.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/app/store/chat.ts b/app/store/chat.ts index 3bcda7538..59c22f149 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -46,6 +46,7 @@ export type ChatMessage = RequestMessage & { id: string; model?: ModelType; tools?: ChatMessageTool[]; + finishedReason?: string; }; export function createMessage(override: Partial): ChatMessage { @@ -373,8 +374,10 @@ export const useChatStore = createPersistStore( session.messages = session.messages.concat(); }); }, - onFinish(message) { + onFinish(message, finishedReason) { botMessage.streaming = false; + if (finishedReason !== null && finishedReason !== undefined) + botMessage.finishedReason = finishedReason; if (message) { botMessage.content = message; get().onNewMessage(botMessage); @@ -429,6 +432,94 @@ export const useChatStore = createPersistStore( }); }, + async onContinueBotMessage(messageID: string) { + const session = get().currentSession(); + const modelConfig = session.mask.modelConfig; + + // get recent messages + const recentMessages = get().getMessagesWithMemory(messageID); + const messageIndex = get().currentSession().messages.length + 1; + + const botMessage = session.messages.find((v) => v.id === messageID); + + if (!botMessage) { + console.error("[Chat] failed to find bot message"); + return; + } + + const baseContent = botMessage.content; + + const api: ClientApi = getClientApi(modelConfig.providerName); + // make request + api.llm.chat({ + messages: recentMessages, + config: { ...modelConfig, stream: true }, + onUpdate(message) { + botMessage.streaming = true; + if (message) { + botMessage.content = baseContent + message; + } + get().updateCurrentSession((session) => { + session.messages = session.messages.concat(); + }); + }, + onFinish(message, finishedReason) { + botMessage.streaming = false; + if (finishedReason !== null && finishedReason !== undefined) + botMessage.finishedReason = finishedReason; + if (message) { + botMessage.content = baseContent + message; + get().onNewMessage(botMessage); + } + ChatControllerPool.remove(session.id, botMessage.id); + }, + onBeforeTool(tool: ChatMessageTool) { + (botMessage.tools = botMessage?.tools || []).push(tool); + get().updateCurrentSession((session) => { + session.messages = session.messages.concat(); + }); + }, + onAfterTool(tool: ChatMessageTool) { + botMessage?.tools?.forEach((t, i, tools) => { + if (tool.id == t.id) { + tools[i] = { ...tool }; + } + }); + get().updateCurrentSession((session) => { + session.messages = session.messages.concat(); + }); + }, + onError(error) { + const isAborted = error.message?.includes?.("aborted"); + botMessage.content += + "\n\n" + + prettyObject({ + error: true, + message: error.message, + }); + botMessage.streaming = false; + botMessage.isError = !isAborted; + get().updateCurrentSession((session) => { + session.messages = session.messages.concat(); + }); + ChatControllerPool.remove( + session.id, + botMessage.id ?? messageIndex, + ); + + console.error("[Chat] failed ", error); + }, + onController(controller) { + // collect controller for stop/retry + ChatControllerPool.addController( + session.id, + botMessage.id ?? messageIndex, + controller, + ); + }, + }); + }, + getMemoryPrompt() { const session = get().currentSession(); @@ -441,12 +532,17 @@ export const useChatStore = createPersistStore( } }, - getMessagesWithMemory() { + getMessagesWithMemory(messageID?: string) { const session = get().currentSession(); const modelConfig = session.mask.modelConfig; const clearContextIndex = session.clearContextIndex ?? 0; const messages = session.messages.slice(); - const totalMessageCount = session.messages.length; + let messageIdx = session.messages.findIndex((v) => v.id === messageID); + if (messageIdx === -1) messageIdx = session.messages.length; + const totalMessageCount = Math.min( + messageIdx + 1, + session.messages.length, + ); // in-context prompts const contextPrompts = session.mask.context.slice(); diff --git a/app/utils/chat.ts b/app/utils/chat.ts index 7f3bb23c5..8c04df4fc 100644 --- a/app/utils/chat.ts +++ b/app/utils/chat.ts @@ -3,7 +3,7 @@ import { UPLOAD_URL, REQUEST_TIMEOUT_MS, } from "@/app/constant"; -import { RequestMessage } from "@/app/client/api"; +import { ChatOptions, RequestMessage } from "@/app/client/api"; import Locale from "@/app/locales"; import { EventStreamContentType, @@ -160,17 +160,21 @@ export function stream( tools: any[], funcs: Record, controller: AbortController, - parseSSE: (text: string, runTools: any[]) => string | undefined, + parseSSE: ( + text: string, + runTools: any[], + ) => { delta?: string; finishReason?: string }, processToolMessage: ( requestPayload: any, toolCallMessage: any, toolCallResult: any[], ) => void, - options: any, + options: ChatOptions, ) { let responseText = ""; let remainText = ""; let finished = false; + let finishedReason: string | undefined; let running = false; let runTools: any[] = []; @@ -254,14 +258,13 @@ export function stream( chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource }, 60); }); - return; } if (running) { return; } console.debug("[ChatAPI] end"); finished = true; - options.onFinish(responseText + remainText); + options.onFinish(responseText + remainText, finishedReason); } }; @@ -333,7 +336,11 @@ export function stream( try { const chunk = parseSSE(msg.data, runTools); if (chunk) { - remainText += chunk; + if (typeof chunk === "string") remainText += chunk; + else { + if (chunk.delta) remainText += chunk.delta; + finishedReason = chunk.finishReason; + } } } catch (e) { console.error("[Request] parse error", text, msg, e);