mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 10:13:44 +08:00
opt: optimize the styles of chat page; caculate all tokens of context as chat history's token
This commit is contained in:
@@ -220,17 +220,17 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
logger.Error("failed to save prompt history message: ", res.Error)
|
||||
}
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var replyToken = 0
|
||||
if functionCall { // 函数名 + 参数 token
|
||||
var totalTokens = 0
|
||||
if functionCall { // prompt + 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||
replyToken += tokens
|
||||
totalTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
replyToken += tokens
|
||||
totalTokens += tokens
|
||||
} else {
|
||||
replyToken, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
}
|
||||
totalTokens += getTotalTokens(req)
|
||||
|
||||
historyReplyMsg := model.HistoryMessage{
|
||||
UserId: userVo.Id,
|
||||
@@ -239,7 +239,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: replyToken,
|
||||
Tokens: totalTokens,
|
||||
UseContext: useContext,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
@@ -249,13 +249,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
if functionCall { // prompt + 函数名 + 参数 token
|
||||
totalTokens = promptToken + replyToken
|
||||
} else {
|
||||
totalTokens = replyToken + getTotalTokens(req)
|
||||
}
|
||||
// 更新用户信息
|
||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
|
||||
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
|
||||
}
|
||||
|
||||
@@ -291,6 +291,19 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
|
||||
if data.Text == "" {
|
||||
var item model.HistoryMessage
|
||||
userId, _ := c.Get(types.LoginUserID)
|
||||
res := h.db.Where("user_id = ?", userId).Last(&item)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c, item.Tokens)
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := utils.CalcTokens(data.Text, data.Model)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
|
||||
@@ -146,9 +146,8 @@ func (h *ChatHandler) sendChatGLMMessage(
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var replyToken = 0
|
||||
replyToken, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
|
||||
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens := replyToken + getTotalTokens(req)
|
||||
historyReplyMsg := model.HistoryMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
@@ -156,7 +155,7 @@ func (h *ChatHandler) sendChatGLMMessage(
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: replyToken,
|
||||
Tokens: totalTokens,
|
||||
UseContext: true,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
@@ -165,10 +164,7 @@ func (h *ChatHandler) sendChatGLMMessage(
|
||||
if res.Error != nil {
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
totalTokens = replyToken + getTotalTokens(req)
|
||||
// 更新用户信息
|
||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
|
||||
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
|
||||
}
|
||||
|
||||
@@ -95,6 +95,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||
}
|
||||
|
||||
data.Key = utils.Sha256(data.Prompt)
|
||||
wsClient := h.App.MjTaskClients.Get(data.Key)
|
||||
//logger.Info(data.Prompt, ",", key)
|
||||
if data.Status == Finished {
|
||||
var task types.MjTask
|
||||
@@ -104,10 +105,18 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||
resp.SUCCESS(c)
|
||||
return
|
||||
}
|
||||
if wsClient != nil && data.ReferenceId != "" {
|
||||
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
|
||||
utils.ReplyMessage(wsClient, content)
|
||||
}
|
||||
// download image
|
||||
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
if wsClient != nil && data.ReferenceId != "" {
|
||||
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
|
||||
utils.ReplyMessage(wsClient, content)
|
||||
}
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -144,8 +153,6 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 推送消息到客户端
|
||||
wsClient := h.App.MjTaskClients.Get(data.Key)
|
||||
if wsClient == nil { // 客户端断线,则丢弃
|
||||
logger.Errorf("Client is offline: %+v", data)
|
||||
resp.SUCCESS(c, "Client is offline")
|
||||
|
||||
@@ -220,17 +220,17 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
logger.Error("failed to save prompt history message: ", res.Error)
|
||||
}
|
||||
|
||||
// for reply
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var replyToken = 0
|
||||
if functionCall { // 函数名 + 参数 token
|
||||
var totalTokens = 0
|
||||
if functionCall { // prompt + 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||
replyToken += tokens
|
||||
totalTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
replyToken += tokens
|
||||
totalTokens += tokens
|
||||
} else {
|
||||
replyToken, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
}
|
||||
totalTokens += getTotalTokens(req)
|
||||
|
||||
historyReplyMsg := model.HistoryMessage{
|
||||
UserId: userVo.Id,
|
||||
@@ -239,7 +239,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: replyToken,
|
||||
Tokens: totalTokens,
|
||||
UseContext: useContext,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
@@ -249,13 +249,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
logger.Error("failed to save reply history message: ", res.Error)
|
||||
}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
if functionCall { // prompt + 函数名 + 参数 token
|
||||
totalTokens = promptToken + replyToken
|
||||
} else {
|
||||
totalTokens = replyToken + getTotalTokens(req)
|
||||
}
|
||||
// 更新用户信息
|
||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
|
||||
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user