fix: function call 兼容中转 API

This commit is contained in:
RockYang 2024-01-07 22:32:59 +08:00
parent 7000168fd4
commit 485bdbc56a
4 changed files with 80 additions and 21 deletions

View File

@ -9,7 +9,9 @@ type ApiRequest struct {
Messages []interface{} `json:"messages,omitempty"` Messages []interface{} `json:"messages,omitempty"`
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
Tools []interface{} `json:"tools,omitempty"` Tools []interface{} `json:"tools,omitempty"`
ToolChoice string `json:"tool_choice,omitempty"` Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台
ToolChoice string `json:"tool_choice,omitempty"`
} }
type Message struct { type Message struct {
@ -28,10 +30,14 @@ type ChoiceItem struct {
} }
type Delta struct { type Delta struct {
Role string `json:"role"` Role string `json:"role"`
Name string `json:"name"` Name string `json:"name"`
Content interface{} `json:"content"` Content interface{} `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FunctionCall struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
} `json:"function_call,omitempty"`
} }
// ChatSession 聊天会话对象 // ChatSession 聊天会话对象

View File

@ -229,6 +229,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
} }
var tools = make([]interface{}, 0) var tools = make([]interface{}, 0)
var functions = make([]interface{}, 0)
for _, v := range items { for _, v := range items {
var parameters map[string]interface{} var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters) err = utils.JsonDecode(v.Parameters, &parameters)
@ -246,12 +247,22 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
"required": required, "required": required,
}, },
}) })
functions = append(functions, gin.H{
"name": v.Name,
"description": v.Description,
"parameters": parameters,
"required": required,
})
} }
if len(tools) > 0 { //if len(tools) > 0 {
req.Tools = tools // req.Tools = tools
req.ToolChoice = "auto" // req.ToolChoice = "auto"
//}
if len(functions) > 0 {
req.Functions = functions
} }
case types.XunFei: case types.XunFei:
req.Temperature = h.App.ChatConfig.XunFei.Temperature req.Temperature = h.App.ChatConfig.XunFei.Temperature
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
@ -438,6 +449,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token) apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
} }
logger.Debugf(utils.JsonEncode(req))
// 创建 HttpClient 请求对象 // 创建 HttpClient 请求对象
var client *http.Client var client *http.Client
requestBody, err := json.Marshal(req) requestBody, err := json.Marshal(req)

View File

@ -76,17 +76,27 @@ func (h *ChatHandler) sendOpenAiMessage(
break break
} }
var fun types.ToolCall var tool types.ToolCall
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 { if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
fun = responseBody.Choices[0].Delta.ToolCalls[0] tool = responseBody.Choices[0].Delta.ToolCalls[0]
if toolCall && fun.Function.Name == "" { if toolCall && tool.Function.Name == "" {
arguments = append(arguments, fun.Function.Arguments) arguments = append(arguments, tool.Function.Arguments)
continue continue
} }
} }
if !utils.IsEmptyValue(fun) { // 兼容 Function Call
res := h.db.Where("name = ?", fun.Function.Name).First(&function) fun := responseBody.Choices[0].Delta.FunctionCall
if fun.Name != "" {
tool = *new(types.ToolCall)
tool.Function.Name = fun.Name
} else if toolCall {
arguments = append(arguments, fun.Arguments)
continue
}
if !utils.IsEmptyValue(tool) {
res := h.db.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil { if res.Error == nil {
toolCall = true toolCall = true
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
@ -95,7 +105,8 @@ func (h *ChatHandler) sendOpenAiMessage(
continue continue
} }
if responseBody.Choices[0].FinishReason == "tool_calls" { // 函数调用完毕 if responseBody.Choices[0].FinishReason == "tool_calls" ||
responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
break break
} }

View File

@ -58,6 +58,10 @@ WeChatBot = false
BotToken = "" BotToken = ""
GuildId = "" GuildId = ""
ChanelId = "" ChanelId = ""
UseCDN = false #是否使用反向代理访问设置为true下面的设置才会生效
DiscordAPI = "https://mj.r9it.com:8001" # discord API 反代地址
DiscordCDN = "https://mj.r9it.com:8002" # mj 图片反代地址
DiscordGateway = "wss://mj.r9it.com:8003" # discord 机器人反代地址
[[MjConfigs]] [[MjConfigs]]
Enabled = false Enabled = false
@ -65,6 +69,16 @@ WeChatBot = false
BotToken = "" BotToken = ""
GuildId = "" GuildId = ""
ChanelId = "" ChanelId = ""
UseCDN = false #是否使用反向代理访问设置为true下面的设置才会生效
DiscordAPI = "https://mj.r9it.com:8001" # discord API 反代地址
DiscordCDN = "https://mj.r9it.com:8002" # mj 图片反代地址
DiscordGateway = "wss://mj.r9it.com:8003" # discord 机器人反代地址
[[SdConfigs]]
Enabled = false
ApiURL = ""
ApiKey = ""
Txt2ImgJsonPath = "res/sd/text2img.json"
[[SdConfigs]] [[SdConfigs]]
Enabled = false Enabled = false
@ -80,7 +94,7 @@ WeChatBot = false
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP如果你没有启用支付服务则该服务也无需启动 [XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP如果你没有启用支付服务则该服务也无需启动
Enabled = false # 是否启用 XXL JOB 服务 Enabled = false # 是否启用 XXL JOB 服务
ServerAddr = "http://chatgpt-plus-xxl-job:8080/xxl-job-admin" # xxl-job-admin 管理地址 ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
ExecutorIp = "172.22.11.47" # 执行器 IP 地址 ExecutorIp = "172.22.11.47" # 执行器 IP 地址
ExecutorPort = "9999" # 执行器服务端口 ExecutorPort = "9999" # 执行器服务端口
AccessToken = "xxl-job-api-token" # 执行器 API 通信 token AccessToken = "xxl-job-api-token" # 执行器 API 通信 token
@ -95,12 +109,27 @@ WeChatBot = false
PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书 PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书 AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书 RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
NotifyURL = "http://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址 NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址
[HuPiPayConfig] # 虎皮椒支付配置 [HuPiPayConfig]
Enabled = false Enabled = false
Name = "wechat" Name = "wechat"
AppId = "" AppId = "201906161477"
AppSecret = "" AppSecret = "7f403199d510fb2c6f0b9f2311800e7c"
PayURL = "https://api.xunhupay.com/payment/do.html" PayURL = "https://api.xunhupay.com/payment/do.html"
NotifyURL = "http://ai.r9it.com/api/payment/hupipay/notify" NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
[SmtpConfig] # 注意阿里云服务器禁用了25号端口所以如果需要使用邮件功能请别用阿里云服务器
Host = "smtp.163.com"
Port = 25
AppName = "极客学长"
From = "test@163.com" # 发件邮箱人地址
Password = "" #邮箱 stmp 服务授权码
[JPayConfig] # PayJs 支付配置
Enabled = false
Name = "wechat" # 请不要改动
AppId = "" # 商户 ID
PrivateKey = "" # 秘钥
ApiURL = "https://payjs.cn/api/native"
NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的