From 131efd6ba5d4a795e5c3cf737beb8d20d48690bf Mon Sep 17 00:00:00 2001 From: RockYang Date: Sat, 14 Sep 2024 07:11:45 +0800 Subject: [PATCH] refactor chat message body struct --- CHANGELOG.md | 2 + api/core/types/chat.go | 19 ++-- api/core/types/web.go | 11 ++- api/handler/chatimpl/chat_handler.go | 122 ++++++++++++++----------- api/handler/chatimpl/openai_handler.go | 52 +++++++++-- api/handler/dalle_handler.go | 2 +- api/handler/markmap_handler.go | 10 +- api/store/model/chat_history.go | 21 +++-- api/store/vo/chat_role.go | 2 +- api/utils/net.go | 11 ++- database/update-v4.1.4.sql | 4 +- 11 files changed, 161 insertions(+), 95 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 857a200c..c119bd0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ * 功能优化:用户文件列表组件增加分页功能支持 * Bug修复:修复用户注册失败Bug,注册操作只弹出一次行为验证码 * 功能优化:首次登录不需要验证码,直接登录,登录失败之后才弹出验证码 +* 功能新增:给 AI 应用(角色)增加分类 +* 功能优化:允许用户在聊天页面设置是否使用流式输出或者一次性输出,兼容 GPT-O1 模型。 ## v4.1.3 * 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信 diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 42c86a2b..95a55397 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -9,14 +9,14 @@ package types // ApiRequest API 请求实体 type ApiRequest struct { - Model string `json:"model,omitempty"` // 兼容百度文心一言 - Temperature float32 `json:"temperature"` - MaxTokens int `json:"max_tokens,omitempty"` // 兼容百度文心一言 - Stream bool `json:"stream"` - Messages []interface{} `json:"messages,omitempty"` - Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM - Tools []Tool `json:"tools,omitempty"` - Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台 + Model string `json:"model,omitempty"` + Temperature float32 `json:"temperature"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型 + Stream bool `json:"stream,omitempty"` + Messages []interface{} `json:"messages,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台 ToolChoice string `json:"tool_choice,omitempty"` @@ -57,7 +57,8 @@ type ChatSession struct { ClientIP string `json:"client_ip"` // 客户端 IP ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段 Model ChatModel `json:"model"` // GPT 模型 - Tools string `json:"tools"` // 函数 + Tools []int `json:"tools"` // 工具函数列表 + Stream bool `json:"stream"` // 是否采用流式输出 } type ChatModel struct { diff --git a/api/core/types/web.go b/api/core/types/web.go index 408d9a58..8ca9b90f 100644 --- a/api/core/types/web.go +++ b/api/core/types/web.go @@ -17,8 +17,8 @@ type BizVo struct { Data interface{} `json:"data,omitempty"` } -// WsMessage Websocket message -type WsMessage struct { +// ReplyMessage 对话回复消息结构 +type ReplyMessage struct { Type WsMsgType `json:"type"` // 消息类别,start, end, img Content interface{} `json:"content"` } @@ -32,6 +32,13 @@ const ( WsErr = WsMsgType("error") ) +// InputMessage 对话输入消息结构 +type InputMessage struct { + Content string `json:"content"` + Tools []int `json:"tools"` // 允许调用工具列表 + Stream bool `json:"stream"` // 是否采用流式输出 +} + type BizCode int const ( diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 730043e9..561a7918 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -73,13 +73,12 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { roleId := h.GetInt(c, "role_id", 0) chatId := c.Query("chat_id") modelId := h.GetInt(c, "model_id", 0) - tools := c.Query("tools") client := types.NewWsClient(ws) var chatRole model.ChatRole res := h.DB.First(&chatRole, roleId) if res.Error != nil || !chatRole.Enable { - utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") + utils.ReplyErrorMessage(client, "当前聊天角色不存在或者未启用,对话已关闭!!!") c.Abort() return } @@ -91,7 +90,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { var chatModel model.ChatModel res = h.DB.First(&chatModel, modelId) if res.Error != nil || chatModel.Enabled == false { - utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!") + utils.ReplyErrorMessage(client, "当前AI模型暂未启用,对话已关闭!!!") c.Abort() return } @@ -100,7 +99,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { SessionId: sessionId, ClientIP: c.ClientIP(), UserId: h.GetLoginUserId(c), - Tools: tools, } // use old chat data override the chat model and role ID @@ -137,20 +135,16 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { return } - var message types.WsMessage + var message types.InputMessage err = utils.JsonDecode(string(msg), &message) if err != nil { continue } - // 心跳消息 - if message.Type == "heartbeat" { - logger.Debug("收到 Chat 心跳消息:", message.Content) - continue - } - - logger.Info("Receive a message: ", message.Content) + logger.Infof("Receive a message:%+v", message) + session.Tools = message.Tools + session.Stream = message.Stream ctx, cancel := context.WithCancel(context.Background()) h.ReqCancelFunc.Put(sessionId, cancel) // 回复消息 @@ -159,7 +153,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { logger.Error(err) utils.ReplyMessage(client, err.Error()) } else { - utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) + utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd}) logger.Infof("回答完毕: %v", message.Content) } @@ -208,16 +202,21 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio } var req = types.ApiRequest{ - Model: session.Model.Value, - Stream: true, + Model: session.Model.Value, + Temperature: session.Model.Temperature, + } + // 兼容 GPT-O1 模型 + if strings.HasPrefix(session.Model.Value, "o1-") { + req.MaxCompletionTokens = session.Model.MaxTokens + req.Stream = false + } else { + req.MaxTokens = session.Model.MaxTokens + req.Stream = session.Stream } - req.Temperature = session.Model.Temperature - req.MaxTokens = session.Model.MaxTokens - if session.Tools != "" { - toolIds := strings.Split(session.Tools, ",") + if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") { var items []model.Function - res = h.DB.Where("enabled", true).Where("id IN ?", toolIds).Find(&items) + res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items) if res.Error == nil { var tools = make([]types.Tool, 0) for _, v := range items { @@ -279,7 +278,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio for i := len(messages) - 1; i >= 0; i-- { v := messages[i] - tks, _ := utils.CalcTokens(v.Content, req.Model) + tks, _ = utils.CalcTokens(v.Content, req.Model) // 上下文 token 超出了模型的最大上下文长度 if tokens+tks >= session.Model.MaxContext { break @@ -500,10 +499,17 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p } } +type Usage struct { + Prompt string + Content string + PromptTokens int + CompletionTokens int + TotalTokens int +} + func (h *ChatHandler) saveChatHistory( req types.ApiRequest, - prompt string, - contents []string, + usage Usage, message types.Message, chatCtx []types.Message, session *types.ChatSession, @@ -514,8 +520,8 @@ func (h *ChatHandler) saveChatHistory( if message.Role == "" { message.Role = "assistant" } - message.Content = strings.Join(contents, "") - useMsg := types.Message{Role: "user", Content: prompt} + message.Content = usage.Content + useMsg := types.Message{Role: "user", Content: usage.Prompt} // 更新上下文消息,如果是调用函数则不需要更新上下文 if h.App.SysConfig.EnableContext { @@ -526,42 +532,52 @@ func (h *ChatHandler) saveChatHistory( // 追加聊天记录 // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) + var promptTokens, replyTokens, totalTokens int + if usage.PromptTokens > 0 { + promptTokens = usage.PromptTokens + } else { + promptTokens, _ = utils.CalcTokens(usage.Content, req.Model) } + historyUserMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: template.HTMLEscapeString(prompt), - Tokens: promptToken, - UseContext: true, - Model: req.Model, + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.PromptMsg, + Icon: userVo.Avatar, + Content: template.HTMLEscapeString(usage.Prompt), + Tokens: promptTokens, + TotalTokens: promptTokens, + UseContext: true, + Model: req.Model, } historyUserMsg.CreatedAt = promptCreatedAt historyUserMsg.UpdatedAt = promptCreatedAt - err = h.DB.Save(&historyUserMsg).Error + err := h.DB.Save(&historyUserMsg).Error if err != nil { logger.Error("failed to save prompt history message: ", err) } // for reply // 计算本次对话消耗的总 token 数量 - replyTokens, _ := utils.CalcTokens(message.Content, req.Model) - totalTokens := replyTokens + getTotalTokens(req) + if usage.CompletionTokens > 0 { + replyTokens = usage.CompletionTokens + totalTokens = usage.TotalTokens + } else { + replyTokens, _ = utils.CalcTokens(message.Content, req.Model) + totalTokens = replyTokens + getTotalTokens(req) + } historyReplyMsg := model.ChatMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: totalTokens, - UseContext: true, - Model: req.Model, + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.ReplyMsg, + Icon: role.Icon, + Content: message.Content, + Tokens: replyTokens, + TotalTokens: totalTokens, + UseContext: true, + Model: req.Model, } historyReplyMsg.CreatedAt = replyCreatedAt historyReplyMsg.UpdatedAt = replyCreatedAt @@ -572,7 +588,7 @@ func (h *ChatHandler) saveChatHistory( // 更新用户算力 if session.Model.Power > 0 { - h.subUserPower(userVo, session, promptToken, replyTokens) + h.subUserPower(userVo, session, promptTokens, replyTokens) } // 保存当前会话 var chatItem model.ChatItem @@ -582,10 +598,10 @@ func (h *ChatHandler) saveChatHistory( chatItem.UserId = userVo.Id chatItem.RoleId = role.Id chatItem.ModelId = session.Model.Id - if utf8.RuneCountInString(prompt) > 30 { - chatItem.Title = string([]rune(prompt)[:30]) + "..." + if utf8.RuneCountInString(usage.Prompt) > 30 { + chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..." } else { - chatItem.Title = prompt + chatItem.Title = usage.Prompt } chatItem.Model = req.Model err = h.DB.Create(&chatItem).Error diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/chatimpl/openai_handler.go index 775c8275..ccefe74f 100644 --- a/api/handler/chatimpl/openai_handler.go +++ b/api/handler/chatimpl/openai_handler.go @@ -23,6 +23,28 @@ import ( "time" ) +type respVo struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []struct { + Index int `json:"index"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + // OPenAI 消息发送实现 func (h *ChatHandler) sendOpenAiMessage( chatCtx []types.Message, @@ -49,6 +71,10 @@ func (h *ChatHandler) sendOpenAiMessage( defer response.Body.Close() } + if response.StatusCode != 200 { + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, body) + } contentType := response.Header.Get("Content-Type") if strings.Contains(contentType, "text/event-stream") { replyCreatedAt := time.Now() // 记录回复时间 @@ -106,8 +132,8 @@ func (h *ChatHandler) sendOpenAiMessage( if res.Error == nil { toolCall = true callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg}) + utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart}) + utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: callMsg}) contents = append(contents, callMsg) } continue @@ -125,10 +151,10 @@ func (h *ChatHandler) sendOpenAiMessage( content := responseBody.Choices[0].Delta.Content contents = append(contents, utils.InterfaceToString(content)) if isNew { - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart}) isNew = false } - utils.ReplyChunkMessage(ws, types.WsMessage{ + utils.ReplyChunkMessage(ws, types.ReplyMessage{ Type: types.WsMiddle, Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), }) @@ -161,13 +187,13 @@ func (h *ChatHandler) sendOpenAiMessage( } if errMsg != "" || apiRes.Code != types.Success { msg := "调用函数工具出错:" + apiRes.Message + errMsg - utils.ReplyChunkMessage(ws, types.WsMessage{ + utils.ReplyChunkMessage(ws, types.ReplyMessage{ Type: types.WsMiddle, Content: msg, }) contents = append(contents, msg) } else { - utils.ReplyChunkMessage(ws, types.WsMessage{ + utils.ReplyChunkMessage(ws, types.ReplyMessage{ Type: types.WsMiddle, Content: apiRes.Data, }) @@ -177,11 +203,17 @@ func (h *ChatHandler) sendOpenAiMessage( // 消息发送成功 if len(contents) > 0 { - h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) + usage := Usage{ + Prompt: prompt, + Content: strings.Join(contents, ""), + PromptTokens: 0, + CompletionTokens: 0, + TotalTokens: 0, + } + h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) } - } else { - body, _ := io.ReadAll(response.Body) - return fmt.Errorf("请求 OpenAI API 失败:%s", body) + } else { // 非流式输出 + } return nil diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go index bcf44ba8..80b993ee 100644 --- a/api/handler/dalle_handler.go +++ b/api/handler/dalle_handler.go @@ -73,7 +73,7 @@ func (h *DallJobHandler) Client(c *gin.Context) { return } - var message types.WsMessage + var message types.ReplyMessage err = utils.JsonDecode(string(msg), &message) if err != nil { continue diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index b4147deb..9c624961 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -64,7 +64,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) { return } - var message types.WsMessage + var message types.ReplyMessage err = utils.JsonDecode(string(msg), &message) if err != nil { continue @@ -85,7 +85,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) { err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId) if err != nil { logger.Error(err) - utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()}) + utils.ReplyErrorMessage(client, err.Error()) } } @@ -170,16 +170,16 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode } if isNew { - utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart}) + utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsStart}) isNew = false } - utils.ReplyChunkMessage(client, types.WsMessage{ + utils.ReplyChunkMessage(client, types.ReplyMessage{ Type: types.WsMiddle, Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), }) } // end for - utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) + utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd}) } else { body, _ := io.ReadAll(response.Body) diff --git a/api/store/model/chat_history.go b/api/store/model/chat_history.go index 36abeb4e..876c427f 100644 --- a/api/store/model/chat_history.go +++ b/api/store/model/chat_history.go @@ -4,16 +4,17 @@ import "gorm.io/gorm" type ChatMessage struct { BaseModel - ChatId string // 会话 ID - UserId uint // 用户 ID - RoleId uint // 角色 ID - Model string // AI模型 - Type string - Icon string - Tokens int - Content string - UseContext bool // 是否可以作为聊天上下文 - DeletedAt gorm.DeletedAt + ChatId string // 会话 ID + UserId uint // 用户 ID + RoleId uint // 角色 ID + Model string // AI模型 + Type string + Icon string + Tokens int + TotalTokens int // 总 token 消耗 + Content string + UseContext bool // 是否可以作为聊天上下文 + DeletedAt gorm.DeletedAt } func (ChatMessage) TableName() string { diff --git a/api/store/vo/chat_role.go b/api/store/vo/chat_role.go index ad82d949..9ab49cf6 100644 --- a/api/store/vo/chat_role.go +++ b/api/store/vo/chat_role.go @@ -5,7 +5,7 @@ import "geekai/core/types" type ChatRole struct { BaseVo Key string `json:"key"` // 角色唯一标识 - Tid uint `json:"tid"` + Tid int `json:"tid"` Name string `json:"name"` // 角色名称 Context []types.Message `json:"context"` // 角色语料信息 HelloMsg string `json:"hello_msg"` // 打招呼的消息 diff --git a/api/utils/net.go b/api/utils/net.go index 5f02922c..127f0f51 100644 --- a/api/utils/net.go +++ b/api/utils/net.go @@ -33,9 +33,14 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) { // ReplyMessage 回复客户端一条完整的消息 func ReplyMessage(ws *types.WsClient, message interface{}) { - ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message}) - ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd}) + ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart}) + ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: message}) + ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd}) +} + +// ReplyErrorMessage 向客户端发送错误消息 +func ReplyErrorMessage(ws *types.WsClient, message interface{}) { + ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message}) } func DownloadImage(imageURL string, proxy string) ([]byte, error) { diff --git a/database/update-v4.1.4.sql b/database/update-v4.1.4.sql index 86d47b63..28d93e5c 100644 --- a/database/update-v4.1.4.sql +++ b/database/update-v4.1.4.sql @@ -9,4 +9,6 @@ CREATE TABLE `chatgpt_app_types` ( ALTER TABLE `chatgpt_app_types`ADD PRIMARY KEY (`id`); ALTER TABLE `chatgpt_app_types` MODIFY `id` int NOT NULL AUTO_INCREMENT; -ALTER TABLE `chatgpt_chat_roles` ADD `tid` INT NOT NULL COMMENT '分类ID' AFTER `name`; \ No newline at end of file +ALTER TABLE `chatgpt_chat_roles` ADD `tid` INT NOT NULL COMMENT '分类ID' AFTER `name`; + +ALTER TABLE `chatgpt_chat_history` ADD `total_tokens` INT NOT NULL COMMENT '消耗总Token长度' AFTER `tokens`; \ No newline at end of file