diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 1d68f59a..6be5dc49 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -338,8 +338,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio // Tokens 统计 token 数量 func (h *ChatHandler) Tokens(c *gin.Context) { var data struct { - Text string `json:"text"` - Model string `json:"model"` + Text string `json:"text"` + Model string `json:"model"` + ChatId string `json:"chat_id"` } if err := c.ShouldBindJSON(&data); err != nil { resp.ERROR(c, types.InvalidArgs) @@ -347,10 +348,10 @@ func (h *ChatHandler) Tokens(c *gin.Context) { } // 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文) - if data.Text == "" { + if data.Text == "" && data.ChatId != "" { var item model.HistoryMessage userId, _ := c.Get(types.LoginUserID) - res := h.db.Where("user_id = ?", userId).Last(&item) + res := h.db.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item) if res.Error != nil { resp.ERROR(c, res.Error.Error()) return diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index ab3c1086..32c809e4 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -634,7 +634,7 @@ const connect = function (chat_id, role_id) { // 获取 token const reply = chatData.value[chatData.value.length - 1] - httpPost("/api/chat/tokens", {text: "", model: getModelValue(modelID.value)}).then(res => { + httpPost("/api/chat/tokens", {text: "", model: getModelValue(modelID.value), chat_id: chat_id}).then(res => { reply['created_at'] = new Date().getTime(); reply['tokens'] = res.data; // 将聊天框的滚动条滑动到最底部