refactor chat message body struct

This commit is contained in:
RockYang
2024-09-14 07:11:45 +08:00
parent 96b8121210
commit e371310d02
11 changed files with 161 additions and 95 deletions

View File

@@ -73,13 +73,12 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
roleId := h.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
modelId := h.GetInt(c, "model_id", 0)
tools := c.Query("tools")
client := types.NewWsClient(ws)
var chatRole model.ChatRole
res := h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
utils.ReplyErrorMessage(client, "当前聊天角色不存在或者未启用,对话已关闭!!!")
c.Abort()
return
}
@@ -91,7 +90,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
var chatModel model.ChatModel
res = h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭!!!")
utils.ReplyErrorMessage(client, "当前AI模型暂未启用对话已关闭!!!")
c.Abort()
return
}
@@ -100,7 +99,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
SessionId: sessionId,
ClientIP: c.ClientIP(),
UserId: h.GetLoginUserId(c),
Tools: tools,
}
// use old chat data override the chat model and role ID
@@ -137,20 +135,16 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
return
}
var message types.WsMessage
var message types.InputMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 Chat 心跳消息:", message.Content)
continue
}
logger.Info("Receive a message: ", message.Content)
logger.Infof("Receive a message:%+v", message)
session.Tools = message.Tools
session.Stream = message.Stream
ctx, cancel := context.WithCancel(context.Background())
h.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
@@ -159,7 +153,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
logger.Error(err)
utils.ReplyMessage(client, err.Error())
} else {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
logger.Infof("回答完毕: %v", message.Content)
}
@@ -208,16 +202,21 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
var req = types.ApiRequest{
Model: session.Model.Value,
Stream: true,
Model: session.Model.Value,
Temperature: session.Model.Temperature,
}
// 兼容 GPT-O1 模型
if strings.HasPrefix(session.Model.Value, "o1-") {
req.MaxCompletionTokens = session.Model.MaxTokens
req.Stream = false
} else {
req.MaxTokens = session.Model.MaxTokens
req.Stream = session.Stream
}
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
if session.Tools != "" {
toolIds := strings.Split(session.Tools, ",")
if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
var items []model.Function
res = h.DB.Where("enabled", true).Where("id IN ?", toolIds).Find(&items)
res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
if res.Error == nil {
var tools = make([]types.Tool, 0)
for _, v := range items {
@@ -279,7 +278,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
for i := len(messages) - 1; i >= 0; i-- {
v := messages[i]
tks, _ := utils.CalcTokens(v.Content, req.Model)
tks, _ = utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= session.Model.MaxContext {
break
@@ -500,10 +499,17 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
}
}
type Usage struct {
Prompt string
Content string
PromptTokens int
CompletionTokens int
TotalTokens int
}
func (h *ChatHandler) saveChatHistory(
req types.ApiRequest,
prompt string,
contents []string,
usage Usage,
message types.Message,
chatCtx []types.Message,
session *types.ChatSession,
@@ -514,8 +520,8 @@ func (h *ChatHandler) saveChatHistory(
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
message.Content = usage.Content
useMsg := types.Message{Role: "user", Content: usage.Prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext {
@@ -526,42 +532,52 @@ func (h *ChatHandler) saveChatHistory(
// 追加聊天记录
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
var promptTokens, replyTokens, totalTokens int
if usage.PromptTokens > 0 {
promptTokens = usage.PromptTokens
} else {
promptTokens, _ = utils.CalcTokens(usage.Content, req.Model)
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(prompt),
Tokens: promptToken,
UseContext: true,
Model: req.Model,
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(usage.Prompt),
Tokens: promptTokens,
TotalTokens: promptTokens,
UseContext: true,
Model: req.Model,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
err = h.DB.Save(&historyUserMsg).Error
err := h.DB.Save(&historyUserMsg).Error
if err != nil {
logger.Error("failed to save prompt history message: ", err)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyTokens + getTotalTokens(req)
if usage.CompletionTokens > 0 {
replyTokens = usage.CompletionTokens
totalTokens = usage.TotalTokens
} else {
replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
totalTokens = replyTokens + getTotalTokens(req)
}
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
Model: req.Model,
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: replyTokens,
TotalTokens: totalTokens,
UseContext: true,
Model: req.Model,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
@@ -572,7 +588,7 @@ func (h *ChatHandler) saveChatHistory(
// 更新用户算力
if session.Model.Power > 0 {
h.subUserPower(userVo, session, promptToken, replyTokens)
h.subUserPower(userVo, session, promptTokens, replyTokens)
}
// 保存当前会话
var chatItem model.ChatItem
@@ -582,10 +598,10 @@ func (h *ChatHandler) saveChatHistory(
chatItem.UserId = userVo.Id
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
if utf8.RuneCountInString(usage.Prompt) > 30 {
chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
chatItem.Title = usage.Prompt
}
chatItem.Model = req.Model
err = h.DB.Create(&chatItem).Error