From 0385e60ce1b3c1ba5ea229402f15ea79da9ba4a1 Mon Sep 17 00:00:00 2001
From: RockYang
Date: Sat, 14 Sep 2024 17:06:13 +0800
Subject: [PATCH] refactor AI chat message struct, allow users to set whether
 the AI responds in stream, compatible with the GPT-o1 model

---
 api/core/app_server.go                 |   1 +
 api/core/types/chat.go                 |   1 +
 api/core/types/web.go                  |   7 +-
 api/handler/chat_app_type_handler.go   |   3 +-
 api/handler/chatimpl/chat_handler.go   |  25 +--
 api/handler/chatimpl/openai_handler.go |  54 ++++--
 api/handler/markmap_handler.go         |   9 +-
 api/main.go                            |   5 +
 api/utils/net.go                       |   7 +-
 web/src/components/ChatPrompt.vue      |  16 +-
 web/src/components/ChatSetting.vue     |   5 +-
 web/src/store/sharedata.js             |   7 +-
 web/src/utils/libs.js                  |   6 +-
 web/src/views/ChatPlus.vue             |  43 +++--
 web/src/views/MarkMap.vue              |   5 +-
 web/src/views/admin/Login.vue          |  54 +++---
 web/src/views/mobile/ChatList.vue      |   2 +-
 web/src/views/mobile/ChatSession.vue   | 240 ++++++++++++-------------
 18 files changed, 245 insertions(+), 245 deletions(-)

diff --git a/api/core/app_server.go b/api/core/app_server.go
index e911af72..9990ddba 100644
--- a/api/core/app_server.go
+++ b/api/core/app_server.go
@@ -205,6 +205,7 @@ func needLogin(c *gin.Context) bool {
 		c.Request.URL.Path == "/api/chat/detail" ||
 		c.Request.URL.Path == "/api/chat/list" ||
 		c.Request.URL.Path == "/api/app/list" ||
+		c.Request.URL.Path == "/api/app/type/list" ||
 		c.Request.URL.Path == "/api/app/list/user" ||
 		c.Request.URL.Path == "/api/model/list" ||
 		c.Request.URL.Path == "/api/mj/imgWall" ||
diff --git a/api/core/types/chat.go b/api/core/types/chat.go
index 95a55397..63f18622 100644
--- a/api/core/types/chat.go
+++ b/api/core/types/chat.go
@@ -57,6 +57,7 @@ type ChatSession struct {
 	ClientIP string    `json:"client_ip"` // 客户端 IP
 	ChatId   string    `json:"chat_id"`   // 客户端聊天会话 ID, 多会话模式专用字段
 	Model    ChatModel `json:"model"`     // GPT 模型
+	Start    int64     `json:"start"`     // 开始请求时间戳
 	Tools    []int     `json:"tools"`     // 工具函数列表
 	Stream   bool      `json:"stream"`    // 是否采用流式输出
 }
diff --git a/api/core/types/web.go b/api/core/types/web.go
index 8ca9b90f..eb9683bf 100644
--- a/api/core/types/web.go
+++ b/api/core/types/web.go
@@ -26,10 +26,9 @@ type ReplyMessage struct {
 type WsMsgType string
 
 const (
-	WsStart  = WsMsgType("start")
-	WsMiddle = WsMsgType("middle")
-	WsEnd    = WsMsgType("end")
-	WsErr    = WsMsgType("error")
+	WsContent = WsMsgType("content") // 输出内容
+	WsEnd     = WsMsgType("end")
+	WsErr     = WsMsgType("error")
 )
 
 // InputMessage 对话输入消息结构
diff --git a/api/handler/chat_app_type_handler.go b/api/handler/chat_app_type_handler.go
index fc5e277e..c18f858a 100644
--- a/api/handler/chat_app_type_handler.go
+++ b/api/handler/chat_app_type_handler.go
@@ -6,6 +6,7 @@ import (
 	"geekai/store/vo"
 	"geekai/utils"
 	"geekai/utils/resp"
+
 	"github.com/gin-gonic/gin"
 	"gorm.io/gorm"
 )
@@ -22,7 +23,7 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
 func (h *ChatAppTypeHandler) List(c *gin.Context) {
 	var items []model.AppType
 	var appTypes = make([]vo.AppType, 0)
-	err := h.DB.Order("sort_num ASC").Find(&items).Error
+	err := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items).Error
 	if err != nil {
 		resp.ERROR(c, err.Error())
 		return
diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go
index 561a7918..8ea002fa 100644
--- a/api/handler/chatimpl/chat_handler.go
+++ b/api/handler/chatimpl/chat_handler.go
@@ -202,15 +202,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 	}
 
 	var req = types.ApiRequest{
-		Model:       session.Model.Value,
-		Temperature: session.Model.Temperature,
+		Model: session.Model.Value,
 	}
 	// 兼容 GPT-O1 模型
 	if strings.HasPrefix(session.Model.Value, "o1-") {
-		req.MaxCompletionTokens = session.Model.MaxTokens
+		utils.ReplyContent(ws, "AI 正在思考...\n")
 		req.Stream = false
+		session.Start = time.Now().Unix()
 	} else {
 		req.MaxTokens = session.Model.MaxTokens
+		req.Temperature = session.Model.Temperature
 		req.Stream = session.Stream
 	}
@@ -449,7 +450,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
 	if err != nil {
 		return nil, err
 	}
-	logger.Debugf(utils.JsonEncode(req))
+	logger.Debugf("对话请求消息体:%+v", req)
 	apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
 
 	// 创建 HttpClient 请求对象
@@ -499,14 +500,6 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
 	}
 }
 
-type Usage struct {
-	Prompt           string
-	Content          string
-	PromptTokens     int
-	CompletionTokens int
-	TotalTokens      int
-}
-
 func (h *ChatHandler) saveChatHistory(
 	req types.ApiRequest,
 	usage Usage,
@@ -517,12 +510,8 @@ func (h *ChatHandler) saveChatHistory(
 	userVo vo.User,
 	promptCreatedAt time.Time,
 	replyCreatedAt time.Time) {
-	if message.Role == "" {
-		message.Role = "assistant"
-	}
-	message.Content = usage.Content
 
-	useMsg := types.Message{Role: "user", Content: usage.Prompt}
+	useMsg := types.Message{Role: "user", Content: usage.Prompt}
 	// 更新上下文消息,如果是调用函数则不需要更新上下文
 	if h.App.SysConfig.EnableContext {
 		chatCtx = append(chatCtx, useMsg) // 提问消息
@@ -573,7 +562,7 @@ func (h *ChatHandler) saveChatHistory(
 			RoleId:      role.Id,
 			Type:        types.ReplyMsg,
 			Icon:        role.Icon,
-			Content:     message.Content,
+			Content:     usage.Content,
 			Tokens:      replyTokens,
 			TotalTokens: totalTokens,
 			UseContext:  true,
diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/chatimpl/openai_handler.go
index ccefe74f..c0e03a06 100644
--- a/api/handler/chatimpl/openai_handler.go
+++ b/api/handler/chatimpl/openai_handler.go
@@ -23,7 +23,15 @@ import (
 	"time"
 )
 
-type respVo struct {
+type Usage struct {
+	Prompt           string `json:"prompt,omitempty"`
+	Content          string `json:"content,omitempty"`
+	PromptTokens     int    `json:"prompt_tokens"`
+	CompletionTokens int    `json:"completion_tokens"`
+	TotalTokens      int    `json:"total_tokens"`
+}
+
+type OpenAIResVo struct {
 	Id      string `json:"id"`
 	Object  string `json:"object"`
 	Created int    `json:"created"`
@@ -38,11 +46,7 @@ type respVo struct {
 			Logprobs     interface{} `json:"logprobs"`
 			FinishReason string      `json:"finish_reason"`
 		} `json:"choices"`
-	Usage struct {
-		PromptTokens     int `json:"prompt_tokens"`
-		CompletionTokens int `json:"completion_tokens"`
-		TotalTokens      int `json:"total_tokens"`
-	} `json:"usage"`
+	Usage Usage `json:"usage"`
 }
 
 // OPenAI 消息发送实现
@@ -73,19 +77,19 @@ func (h *ChatHandler) sendOpenAiMessage(
 
 	if response.StatusCode != 200 {
 		body, _ := io.ReadAll(response.Body)
-		return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, body)
+		return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
 	}
 
+	contentType := response.Header.Get("Content-Type")
 	if strings.Contains(contentType, "text/event-stream") {
 		replyCreatedAt := time.Now() // 记录回复时间
 		// 循环读取 Chunk 消息
-		var message = types.Message{}
+		var message = types.Message{Role: "assistant"}
 		var contents = make([]string, 0)
 		var function model.Function
 		var toolCall = false
 		var arguments = make([]string, 0)
 		scanner := bufio.NewScanner(response.Body)
-		var isNew = true
 		for scanner.Scan() {
 			line := scanner.Text()
 			if !strings.Contains(line, "data:") || len(line) < 30 {
@@ -132,8 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 				if res.Error == nil {
 					toolCall = true
 					callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
-					utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
-					utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: callMsg})
+					utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: callMsg})
 					contents = append(contents, callMsg)
 				}
 				continue
@@ -150,12 +153,8 @@ func (h *ChatHandler) sendOpenAiMessage(
 			} else {
 				content := responseBody.Choices[0].Delta.Content
 				contents = append(contents, utils.InterfaceToString(content))
-				if isNew {
-					utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
-					isNew = false
-				}
 				utils.ReplyChunkMessage(ws, types.ReplyMessage{
-					Type:    types.WsMiddle,
+					Type:    types.WsContent,
 					Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 				})
 			}
@@ -188,13 +187,13 @@ func (h *ChatHandler) sendOpenAiMessage(
 			if errMsg != "" || apiRes.Code != types.Success {
 				msg := "调用函数工具出错:" + apiRes.Message + errMsg
 				utils.ReplyChunkMessage(ws, types.ReplyMessage{
-					Type:    types.WsMiddle,
+					Type:    types.WsContent,
 					Content: msg,
 				})
 				contents = append(contents, msg)
 			} else {
 				utils.ReplyChunkMessage(ws, types.ReplyMessage{
-					Type:    types.WsMiddle,
+					Type:    types.WsContent,
 					Content: apiRes.Data,
 				})
 				contents = append(contents, utils.InterfaceToString(apiRes.Data))
@@ -210,10 +209,27 @@ func (h *ChatHandler) sendOpenAiMessage(
 			CompletionTokens: 0,
 			TotalTokens:      0,
 		}
+		message.Content = usage.Content
 		h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 	}
 	} else { // 非流式输出
-
+		var respVo OpenAIResVo
+		body, err := io.ReadAll(response.Body)
+		if err != nil {
+			return fmt.Errorf("读取响应失败:%v", body)
+		}
+		err = json.Unmarshal(body, &respVo)
+		if err != nil {
+			return fmt.Errorf("解析响应失败:%v", body)
+		}
+		content := respVo.Choices[0].Message.Content
+		if strings.HasPrefix(req.Model, "o1-") {
+			content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
+		}
+		utils.ReplyMessage(ws, content)
+		respVo.Usage.Prompt = prompt
+		respVo.Usage.Content = content
+		h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
 	}
 
 	return nil
diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go
index 9c624961..afdda797 100644
--- a/api/handler/markmap_handler.go
+++ b/api/handler/markmap_handler.go
@@ -86,6 +86,8 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
 		if err != nil {
 			logger.Error(err)
 			utils.ReplyErrorMessage(client, err.Error())
+		} else {
+			utils.ReplyMessage(client, types.ReplyMessage{Type: types.WsEnd})
 		}
 	}
 }
@@ -148,7 +150,6 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 	if strings.Contains(contentType, "text/event-stream") {
 		// 循环读取 Chunk 消息
 		scanner := bufio.NewScanner(response.Body)
-		var isNew = true
 		for scanner.Scan() {
 			line := scanner.Text()
 			if !strings.Contains(line, "data:") || len(line) < 30 {
@@ -169,12 +170,8 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 				break
 			}
 
-			if isNew {
-				utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsStart})
-				isNew = false
-			}
 			utils.ReplyChunkMessage(client, types.ReplyMessage{
-				Type:    types.WsMiddle,
+				Type:    types.WsContent,
 				Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 			})
 		} // end for
diff --git a/api/main.go b/api/main.go
index dd725536..544abc1e 100644
--- a/api/main.go
+++ b/api/main.go
@@ -512,6 +512,11 @@ func main() {
 			group.POST("enable", h.Enable)
 			group.POST("sort", h.Sort)
 		}),
+		fx.Provide(handler.NewChatAppTypeHandler),
+		fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
+			group := s.Engine.Group("/api/app/type")
+			group.GET("list", h.List)
+		}),
 		fx.Provide(handler.NewTestHandler),
 		fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
 			group := s.Engine.Group("/api/test")
diff --git a/api/utils/net.go b/api/utils/net.go
index 127f0f51..74c1cb4b 100644
--- a/api/utils/net.go
+++ b/api/utils/net.go
@@ -33,11 +33,14 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) {
 
 // ReplyMessage 回复客户端一条完整的消息
 func ReplyMessage(ws *types.WsClient, message interface{}) {
-	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
-	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: message})
+	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
 	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd})
 }
 
+func ReplyContent(ws *types.WsClient, message interface{}) {
+	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
+}
+
 // ReplyErrorMessage 向客户端发送错误消息
 func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
 	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message})
diff --git a/web/src/components/ChatPrompt.vue b/web/src/components/ChatPrompt.vue
index 518ce2f4..5c540c80 100644
--- a/web/src/components/ChatPrompt.vue
+++ b/web/src/components/ChatPrompt.vue
@@ -132,12 +132,13 @@ const content =ref(processPrompt(props.data.content))
 const files = ref([])
 
 onMounted(() => {
-  // if (!finalTokens.value) {
-  //   httpPost("/api/chat/tokens", {text: props.data.content, model: props.data.model}).then(res => {
-  //     finalTokens.value = res.data;
-  //   }).catch(() => {
-  //   })
-  // }
+  processFiles()
+})
+
+const processFiles = () => {
+  if (!props.data.content) {
+    return
+  }
 
   const linkRegex = /(https?:\/\/\S+)/g;
   const links = props.data.content.match(linkRegex);
@@ -159,8 +160,7 @@
   }
 
   content.value = md.render(content.value.trim())
-})
-
+}
 const isExternalImg = (link, files) => {
   return isImage(link) && !files.find(file => file.url === link)
 }
diff --git a/web/src/components/ChatSetting.vue b/web/src/components/ChatSetting.vue
index 7672645b..4d7b4547 100644
--- a/web/src/components/ChatSetting.vue
+++ b/web/src/components/ChatSetting.vue
@@ -15,7 +15,9 @@
         对话样式
-
+
+
+
@@ -28,6 +30,7 @@
 const store = useSharedStore();
 const data = ref({
   style: store.chatListStyle,
+  stream: store.chatStream,
 })
 // eslint-disable-next-line no-undef
 const props = defineProps({
diff --git a/web/src/store/sharedata.js b/web/src/store/sharedata.js
index 9207da04..c43f7a13 100644
--- a/web/src/store/sharedata.js
+++ b/web/src/store/sharedata.js
@@ -4,7 +4,8 @@ import Storage from 'good-storage'
 export const useSharedStore = defineStore('shared', {
     state: () => ({
         showLoginDialog: false,
-        chatListStyle: Storage.get("chat_list_style","chat")
+        chatListStyle: Storage.get("chat_list_style","chat"),
+        chatStream: Storage.get("chat_stream",true),
     }),
     getters: {},
     actions: {
@@ -14,6 +15,10 @@ export const useSharedStore = defineStore('shared', {
         setChatListStyle(value) {
             this.chatListStyle = value;
             Storage.set("chat_list_style", value);
+        },
+        setChatStream(value) {
+            this.chatStream = value;
+            Storage.set("chat_stream", value);
         }
     }
 });
diff --git a/web/src/utils/libs.js b/web/src/utils/libs.js
index 5e55525f..06b65d5c 100644
--- a/web/src/utils/libs.js
+++ b/web/src/utils/libs.js
@@ -9,8 +9,6 @@
  * Util lib functions
  */
 import {showConfirmDialog} from "vant";
-import {httpDownload} from "@/utils/http";
-import {showMessageError} from "@/utils/dialog";
 
 // generate a random string
 export function randString(length) {
@@ -183,6 +181,10 @@ export function isImage(url) {
 }
 
 export function processContent(content) {
+    if (!content) {
+        return ""
+    }
+
     // 如果是图片链接地址,则直接替换成图片标签
     const linkRegex = /(https?:\/\/\S+)/g;
     const links = content.match(linkRegex);
diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue
index 389c2b1a..ecf61b76 100644
--- a/web/src/views/ChatPlus.vue
+++ b/web/src/views/ChatPlus.vue
@@ -106,7 +106,7 @@
-
+
@@ -271,6 +271,12 @@ watch(() => store.chatListStyle, (newValue) => {
 const tools = ref([])
 const toolSelected = ref([])
 const loadHistory = ref(false)
+const stream = ref(store.chatStream)
+
+watch(() => store.chatStream, (newValue) => {
+  stream.value = newValue
+});
+
 
 // 初始化角色ID参数
 if (router.currentRoute.value.query.role_id) {
@@ -491,16 +497,6 @@ const newChat = () => {
   connect()
 }
 
-// 切换工具
-const changeTool = () => {
-  if (!isLogin.value) {
-    return;
-  }
-  loadHistory.value = false
-  socket.value.close()
-}
-
-
 // 切换会话
 const loadChat = function (chat) {
   if (!isLogin.value) {
@@ -598,6 +594,7 @@ const lineBuffer = ref(''); // 输出缓冲行
 const socket = ref(null);
 const canSend = ref(true);
 const sessionId = ref("")
+const isNewMsg = ref(true)
 const connect = function () {
   const chatRole = getRoleById(roleId.value);
   // 初始化 WebSocket 对象
@@ -612,8 +609,7 @@
   }
 
   loading.value = true
-  const toolIds = toolSelected.value.join(',')
-  const _socket = new WebSocket(host + `/api/chat/new?session_id=${sessionId.value}&role_id=${roleId.value}&chat_id=${chatId.value}&model_id=${modelID.value}&token=${getUserToken()}&tools=${toolIds}`);
+  const _socket = new WebSocket(host + `/api/chat/new?session_id=${sessionId.value}&role_id=${roleId.value}&chat_id=${chatId.value}&model_id=${modelID.value}&token=${getUserToken()}`);
   _socket.addEventListener('open', () => {
     enableInput()
     if (loadHistory.value) {
@@ -629,15 +625,22 @@
     reader.readAsText(event.data, "UTF-8");
     reader.onload = () => {
       const data = JSON.parse(String(reader.result));
-      if (data.type === 'start') {
+      if (data.type === 'error') {
+        ElMessage.error(data.message)
+        return
+      }
+
+      if (isNewMsg.value && data.type !== 'end') {
         const prePrompt = chatData.value[chatData.value.length-1]?.content
         chatData.value.push({
           type: "reply",
           id: randString(32),
           icon: chatRole['icon'],
           prompt:prePrompt,
-          content: "",
+          content: data.content,
         });
+        isNewMsg.value = false
+        lineBuffer.value = data.content;
       } else if (data.type === 'end') { // 消息接收完毕
         // 追加当前会话到会话列表
         if (newChatItem.value !== null) {
@@ -663,6 +666,7 @@
           nextTick(() => {
             document.getElementById('chat-box').scrollTo(0, document.getElementById('chat-box').scrollHeight)
           })
+          isNewMsg.value = true
         }).catch(() => {
         })
 
@@ -688,6 +692,7 @@
 
   _socket.addEventListener('close', () => {
     disableInput(true)
+    loadHistory.value = false
     connect()
   });
 
@@ -775,7 +780,7 @@ const sendMessage = function () {
 
   showHello.value = false
   disableInput(false)
-  socket.value.send(JSON.stringify({type: "chat", content: content}));
+  socket.value.send(JSON.stringify({tools: toolSelected.value, content: content, stream: stream.value}));
   tmpChatTitle.value = content
   prompt.value = ''
   files.value = []
@@ -813,7 +818,7 @@ const loadChatHistory = function (chatId) {
   chatData.value = []
   httpGet('/api/chat/history?chat_id=' + chatId).then(res => {
     const data = res.data
-    if (!data || data.length === 0) { // 加载打招呼信息
+    if ((!data || data.length === 0) && chatData.value.length === 0) { // 加载打招呼信息
       const _role = getRoleById(roleId.value)
       chatData.value.push({
         chat_id: chatId,
@@ -852,7 +857,7 @@ const stopGenerate = function () {
 // 重新生成
 const reGenerate = function (prompt) {
   disableInput(false)
-  const text = '重新生成下面问题的答案:' + prompt;
+  const text = '重新回答下述问题:' + prompt;
   // 追加消息
   chatData.value.push({
     type: "prompt",
@@ -860,7 +865,7 @@ const reGenerate = function (prompt) {
     icon: loginUser.value.avatar,
     content: text
   });
-  socket.value.send(JSON.stringify({type: "chat", content: prompt}));
+  socket.value.send(JSON.stringify({tools: toolSelected.value, content: text, stream: stream.value}));
 }
 
 const chatName = ref('')
diff --git a/web/src/views/MarkMap.vue b/web/src/views/MarkMap.vue
index 97c1f5aa..d2bef90c 100644
--- a/web/src/views/MarkMap.vue
+++ b/web/src/views/MarkMap.vue
@@ -231,10 +231,7 @@ const connect = (userId) => {
     reader.onload = () => {
       const data = JSON.parse(String(reader.result))
       switch (data.type) {
-        case "start":
-          text.value = ""
-          break
-        case "middle":
+        case "content":
           text.value += data.content
           html.value = md.render(processContent(text.value))
           break
diff --git a/web/src/views/admin/Login.vue b/web/src/views/admin/Login.vue
index e002b02d..aa0ee49d 100644
--- a/web/src/views/admin/Login.vue
+++ b/web/src/views/admin/Login.vue
@@ -2,46 +2,38 @@