mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 18:23:45 +08:00
merge pull request #71
This commit is contained in:
@@ -9,7 +9,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"html/template"
|
||||
"io"
|
||||
"strings"
|
||||
@@ -57,9 +56,6 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
var functionCall = false
|
||||
var functionName string
|
||||
var arguments = make([]string, 0)
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@@ -69,36 +65,17 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
|
||||
var responseBody = types.ApiResponse{}
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil { // 数据解析出错
|
||||
if err != nil { // 数据解析出错
|
||||
logger.Error(err, line)
|
||||
utils.ReplyMessage(ws, ErrorMsg)
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
break
|
||||
}
|
||||
|
||||
if len(responseBody.Choices) == 0 {
|
||||
continue;
|
||||
}
|
||||
fun := responseBody.Choices[0].Delta.FunctionCall
|
||||
if functionCall && fun.Name == "" {
|
||||
arguments = append(arguments, fun.Arguments)
|
||||
continue
|
||||
}
|
||||
|
||||
if !utils.IsEmptyValue(fun) {
|
||||
functionName = fun.Name
|
||||
f := h.App.Functions[functionName]
|
||||
if f != nil {
|
||||
functionCall = true
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
||||
break
|
||||
}
|
||||
|
||||
// 初始化 role
|
||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||
message.Role = responseBody.Choices[0].Delta.Role
|
||||
@@ -124,49 +101,6 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
}
|
||||
}
|
||||
|
||||
if functionCall { // 调用函数完成任务
|
||||
var params map[string]interface{}
|
||||
_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
||||
logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
|
||||
|
||||
// for creating image, check if the user's img_calls > 0
|
||||
if functionName == types.FuncImage && userVo.ImgCalls <= 0 {
|
||||
utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
} else {
|
||||
f := h.App.Functions[functionName]
|
||||
if functionName == types.FuncImage {
|
||||
params["user_id"] = userVo.Id
|
||||
params["role_id"] = role.Id
|
||||
params["chat_id"] = session.ChatId
|
||||
params["icon"] = "/images/avatar/mid_journey.png"
|
||||
params["session_id"] = session.SessionId
|
||||
}
|
||||
data, err := f.Invoke(params)
|
||||
if err != nil {
|
||||
msg := "调用函数出错:" + err.Error()
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: msg,
|
||||
})
|
||||
contents = append(contents, msg)
|
||||
} else {
|
||||
content := data
|
||||
if functionName == types.FuncImage {
|
||||
content = fmt.Sprintf("下面是根据您的描述创作的图片,他们描绘了 【%s】 的场景", params["prompt"])
|
||||
// update user's img_calls
|
||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
}
|
||||
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: content,
|
||||
})
|
||||
contents = append(contents, content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 消息发送成功
|
||||
if len(contents) > 0 {
|
||||
// 更新用户的对话次数
|
||||
@@ -179,7 +113,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if h.App.ChatConfig.EnableContext && functionCall == false {
|
||||
if h.App.ChatConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
@@ -187,11 +121,6 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
|
||||
// 追加聊天记录
|
||||
if h.App.ChatConfig.EnableHistory {
|
||||
useContext := true
|
||||
if functionCall {
|
||||
useContext = false
|
||||
}
|
||||
|
||||
// for prompt
|
||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||
if err != nil {
|
||||
@@ -205,7 +134,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(prompt),
|
||||
Tokens: promptToken,
|
||||
UseContext: useContext,
|
||||
UseContext: true,
|
||||
}
|
||||
historyUserMsg.CreatedAt = promptCreatedAt
|
||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||
@@ -215,15 +144,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
}
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
if functionCall { // prompt + 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||
totalTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
totalTokens += tokens
|
||||
} else {
|
||||
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
||||
}
|
||||
totalTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||
totalTokens += getTotalTokens(req)
|
||||
|
||||
historyReplyMsg := model.HistoryMessage{
|
||||
@@ -234,7 +155,7 @@ func (h *ChatHandler) sendAzureMessage(
|
||||
Icon: role.Icon,
|
||||
Content: message.Content,
|
||||
Tokens: totalTokens,
|
||||
UseContext: useContext,
|
||||
UseContext: true,
|
||||
}
|
||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||
|
||||
@@ -214,20 +214,45 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
case types.Baidu:
|
||||
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
||||
// TODO: 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
|
||||
break
|
||||
case types.OpenAI:
|
||||
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
||||
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
|
||||
// OpenAI 支持函数功能
|
||||
if h.App.SysConfig.EnabledFunction {
|
||||
var functions = make([]types.Function, 0)
|
||||
for _, f := range types.InnerFunctions {
|
||||
functions = append(functions, f)
|
||||
var items []model.Function
|
||||
res := h.db.Where("enabled", true).Find(&items)
|
||||
if res.Error != nil {
|
||||
break
|
||||
}
|
||||
|
||||
var tools = make([]interface{}, 0)
|
||||
for _, v := range items {
|
||||
var parameters map[string]interface{}
|
||||
err = utils.JsonDecode(v.Parameters, ¶meters)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
req.Functions = functions
|
||||
required := parameters["required"]
|
||||
delete(parameters, "required")
|
||||
tools = append(tools, gin.H{
|
||||
"type": "function",
|
||||
"function": gin.H{
|
||||
"name": v.Name,
|
||||
"description": v.Description,
|
||||
"parameters": parameters,
|
||||
"required": required,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
req.Tools = tools
|
||||
req.ToolChoice = "auto"
|
||||
}
|
||||
case types.XunFei:
|
||||
req.Temperature = h.App.ChatConfig.XunFei.Temperature
|
||||
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
|
||||
break
|
||||
default:
|
||||
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
@@ -242,11 +267,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
} else {
|
||||
// calculate the tokens of current request, to prevent to exceeding the max tokens num
|
||||
tokens := req.MaxTokens
|
||||
for _, f := range types.InnerFunctions {
|
||||
tks, _ := utils.CalcTokens(utils.JsonEncode(f), req.Model)
|
||||
tokens += tks
|
||||
}
|
||||
|
||||
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
|
||||
tokens += tks
|
||||
// loading the role context
|
||||
var messages []types.Message
|
||||
err := utils.JsonDecode(role.Context, &messages)
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
req2 "github.com/imroc/req/v3"
|
||||
"html/template"
|
||||
"io"
|
||||
"strings"
|
||||
@@ -56,8 +56,8 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
var functionCall = false
|
||||
var functionName string
|
||||
var function model.Function
|
||||
var toolCall = false
|
||||
var arguments = make([]string, 0)
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
for scanner.Scan() {
|
||||
@@ -75,24 +75,26 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
break
|
||||
}
|
||||
|
||||
fun := responseBody.Choices[0].Delta.FunctionCall
|
||||
if functionCall && fun.Name == "" {
|
||||
arguments = append(arguments, fun.Arguments)
|
||||
continue
|
||||
var fun types.ToolCall
|
||||
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
|
||||
fun = responseBody.Choices[0].Delta.ToolCalls[0]
|
||||
if toolCall && fun.Function.Name == "" {
|
||||
arguments = append(arguments, fun.Function.Arguments)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if !utils.IsEmptyValue(fun) {
|
||||
functionName = fun.Name
|
||||
f := h.App.Functions[functionName]
|
||||
if f != nil {
|
||||
functionCall = true
|
||||
res := h.db.Where("name = ?", fun.Function.Name).First(&function)
|
||||
if res.Error == nil {
|
||||
toolCall = true
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
||||
if responseBody.Choices[0].FinishReason == "tool_calls" { // 函数调用完毕
|
||||
break
|
||||
}
|
||||
|
||||
@@ -121,47 +123,35 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
}
|
||||
}
|
||||
|
||||
if functionCall { // 调用函数完成任务
|
||||
if toolCall { // 调用函数完成任务
|
||||
var params map[string]interface{}
|
||||
_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
||||
logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
|
||||
|
||||
// for creating image, check if the user's img_calls > 0
|
||||
if functionName == types.FuncImage && userVo.ImgCalls <= 0 {
|
||||
utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
|
||||
utils.ReplyMessage(ws, ErrImg)
|
||||
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
|
||||
params["user_id"] = userVo.Id
|
||||
var apiRes types.BizVo
|
||||
r, err := req2.C().R().SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Authorization", function.Token).
|
||||
SetBody(params).
|
||||
SetSuccessResult(&apiRes).Post(function.Action)
|
||||
errMsg := ""
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
} else if r.IsErrorState() {
|
||||
errMsg = r.Err.Error()
|
||||
}
|
||||
if errMsg != "" || apiRes.Code != types.Success {
|
||||
msg := "调用函数工具出错:" + apiRes.Message + errMsg
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: msg,
|
||||
})
|
||||
contents = append(contents, msg)
|
||||
} else {
|
||||
f := h.App.Functions[functionName]
|
||||
// translate prompt
|
||||
if functionName == types.FuncImage {
|
||||
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||
r, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey, h.App.Config.ProxyURL, chatConfig.OpenAI.ApiURL)
|
||||
if err == nil {
|
||||
params["prompt"] = r
|
||||
}
|
||||
}
|
||||
data, err := f.Invoke(params)
|
||||
if err != nil {
|
||||
msg := "调用函数出错:" + err.Error()
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: msg,
|
||||
})
|
||||
contents = append(contents, msg)
|
||||
} else {
|
||||
content := data
|
||||
if functionName == types.FuncImage {
|
||||
content = fmt.Sprintf("下面是根据您的描述创作的图片,他们描绘了 【%s】 的场景。%s", params["prompt"], data)
|
||||
// update user's img_calls
|
||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
}
|
||||
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: content,
|
||||
})
|
||||
contents = append(contents, content)
|
||||
}
|
||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: apiRes.Data,
|
||||
})
|
||||
contents = append(contents, utils.InterfaceToString(apiRes.Data))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,7 +167,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
|
||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||
if h.App.ChatConfig.EnableContext && functionCall == false {
|
||||
if h.App.ChatConfig.EnableContext && toolCall == false {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
@@ -186,7 +176,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
// 追加聊天记录
|
||||
if h.App.ChatConfig.EnableHistory {
|
||||
useContext := true
|
||||
if functionCall {
|
||||
if toolCall {
|
||||
useContext = false
|
||||
}
|
||||
|
||||
@@ -214,8 +204,8 @@ func (h *ChatHandler) sendOpenAiMessage(
|
||||
|
||||
// 计算本次对话消耗的总 token 数量
|
||||
var totalTokens = 0
|
||||
if functionCall { // prompt + 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
||||
if toolCall { // prompt + 函数名 + 参数 token
|
||||
tokens, _ := utils.CalcTokens(function.Name, req.Model)
|
||||
totalTokens += tokens
|
||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||
totalTokens += tokens
|
||||
|
||||
Reference in New Issue
Block a user