mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-12 04:03:42 +08:00
merge pull request #71
This commit is contained in:
@@ -9,7 +9,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"html/template"
|
||||
"io"
|
||||
"strings"
|
||||
@@ -57,9 +56,6 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
var functionCall = false
|
||||
var functionName string
|
||||
var arguments = make([]string, 0)
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@@ -69,36 +65,17 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
|
||||
var responseBody = types.ApiResponse{}
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil { // 数据解析出错
|
||||
if err != nil { // 数据解析出错
|
||||
logger.Error(err, line)
|
||||
utils.ReplyMessage(ws, ErrorMsg)
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
break
|
||||
}
|
||||
|
||||
if len(responseBody.Choices) == 0 {
|
||||
continue;
|
||||
}
|
||||
fun := responseBody.Choices[0].Delta.FunctionCall
|
||||
if functionCall && fun.Name == "" {
|
||||
arguments = append(arguments, fun.Arguments)
|
||||
continue
|
||||
}
|
||||
|
||||
if !utils.IsEmptyValue(fun) {
|
||||
functionName = fun.Name
|
||||
f := h.App.Functions[functionName]
|
||||
if f != nil {
|
||||
functionCall = true
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
||||
break
|
||||
}
|
||||
|
||||
// 初始化 role
|
||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||
message.Role = responseBody.Choices[0].Delta.Role
|
||||
@@ -124,49 +101,6 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
}
|
||||
}
|
||||
|
||||
if functionCall { // 调用函数完成任务
|
||||
var params map[string]interface{}
|
||||
_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
||||
logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
|
||||
|
||||
// for creating image, check if the user's img_calls > 0
|
||||
if functionName == types.FuncImage && userVo.ImgCalls <= 0 {
|
||||
utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
} else {
|
||||
f := h.App.Functions[functionName]
|
||||
if functionName == types.FuncImage {
|
||||
params["user_id"] = userVo.Id
|
||||
params["role_id"] = role.Id
|
||||
params["chat_id"] = session.ChatId
|
||||
params["icon"] = "/images/avatar/mid_journey.png"
|
||||
params["session_id"] = session.SessionId
|
||||
}
|
||||
data, err := f.Invoke(params)
|
||||
if err != nil {
|
||||
msg := "调用函数出错:" + err.Error()
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: msg,
|
||||
})
|
||||
contents = append(contents, msg)
|
||||
} else {
|
||||
content := data
|
||||
if functionName == types.FuncImage {
|
||||
content = fmt.Sprintf("下面是根据您的描述创作的图片,他们描绘了 【%s】 的场景", params["prompt"])
|
||||
// update user's img_calls
|
||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
}
|
||||
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: content,
|
||||
})
|
||||
contents = append(contents, content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
// 更新用户的对话次数
|
||||
@@ -179,7 +113,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if h.App.ChatConfig.EnableContext && functionCall == false {
|
||||
if h.App.ChatConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
@@ -187,11 +121,6 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
|
||||
// 追加聊天记录
|
||||
if h.App.ChatConfig.EnableHistory {
|
||||
useContext := true
|
||||
if functionCall {
|
||||
useContext = false
|
||||
}
|
||||
|
||||
// for prompt
|
||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
@@ -205,7 +134,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(prompt),
|
||||
Tokens: promptToken,
|
||||
UseContext: useContext,
|
||||
UseContext: true,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
@@ -215,15 +144,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
if functionCall { // prompt + 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||
totalTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
totalTokens += tokens
|
||||
} else {
|
||||
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
}
|
||||
totalTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens += getTotalTokens(req)
|
||||
|
||||
historyReplyMsg := model.HistoryMessage{
|
||||
@@ -234,7 +155,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: totalTokens,
|
||||
UseContext: useContext,
|
||||
UseContext: true,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
|
||||
Reference in New Issue
Block a user