mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
feat: function manager refactor is ready
This commit is contained in:
parent
b1ab9975b7
commit
f603bf6be7
@ -159,6 +159,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
|||||||
c.Request.URL.Path == "/api/sd/jobs" ||
|
c.Request.URL.Path == "/api/sd/jobs" ||
|
||||||
c.Request.URL.Path == "/api/upload" ||
|
c.Request.URL.Path == "/api/upload" ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/test/") ||
|
strings.HasPrefix(c.Request.URL.Path, "/test/") ||
|
||||||
|
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
||||||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
|
strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
|
||||||
|
@ -8,7 +8,8 @@ type ApiRequest struct {
|
|||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
Messages []interface{} `json:"messages,omitempty"`
|
Messages []interface{} `json:"messages,omitempty"`
|
||||||
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
|
Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
|
||||||
Functions []Function `json:"functions,omitempty"`
|
Tools []interface{} `json:"tools,omitempty"`
|
||||||
|
ToolChoice string `json:"tool_choice,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
@ -30,7 +31,7 @@ type Delta struct {
|
|||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Content interface{} `json:"content"`
|
Content interface{} `json:"content"`
|
||||||
FunctionCall FunctionCall `json:"function_call,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatSession 聊天会话对象
|
// ChatSession 聊天会话对象
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
type FunctionCall struct {
|
type ToolCall struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Arguments string `json:"arguments"`
|
Arguments string `json:"arguments"`
|
||||||
|
} `json:"function"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Function struct {
|
type Function struct {
|
||||||
|
@ -39,7 +39,6 @@ func (h *FunctionHandler) Save(c *gin.Context) {
|
|||||||
Label: data.Label,
|
Label: data.Label,
|
||||||
Description: data.Description,
|
Description: data.Description,
|
||||||
Parameters: utils.JsonEncode(data.Parameters),
|
Parameters: utils.JsonEncode(data.Parameters),
|
||||||
Required: utils.JsonEncode(data.Required),
|
|
||||||
Action: data.Action,
|
Action: data.Action,
|
||||||
Token: data.Token,
|
Token: data.Token,
|
||||||
Enabled: data.Enabled,
|
Enabled: data.Enabled,
|
||||||
|
@ -9,7 +9,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gorm.io/gorm"
|
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
@ -57,9 +56,6 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
// 循环读取 Chunk 消息
|
// 循环读取 Chunk 消息
|
||||||
var message = types.Message{}
|
var message = types.Message{}
|
||||||
var contents = make([]string, 0)
|
var contents = make([]string, 0)
|
||||||
var functionCall = false
|
|
||||||
var functionName string
|
|
||||||
var arguments = make([]string, 0)
|
|
||||||
scanner := bufio.NewScanner(response.Body)
|
scanner := bufio.NewScanner(response.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
@ -76,27 +72,6 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
fun := responseBody.Choices[0].Delta.FunctionCall
|
|
||||||
if functionCall && fun.Name == "" {
|
|
||||||
arguments = append(arguments, fun.Arguments)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !utils.IsEmptyValue(fun) {
|
|
||||||
functionName = fun.Name
|
|
||||||
f := h.App.Functions[functionName]
|
|
||||||
if f != nil {
|
|
||||||
functionCall = true
|
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化 role
|
// 初始化 role
|
||||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||||
message.Role = responseBody.Choices[0].Delta.Role
|
message.Role = responseBody.Choices[0].Delta.Role
|
||||||
@ -122,49 +97,6 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if functionCall { // 调用函数完成任务
|
|
||||||
var params map[string]interface{}
|
|
||||||
_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
|
||||||
logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
|
|
||||||
|
|
||||||
// for creating image, check if the user's img_calls > 0
|
|
||||||
if functionName == types.FuncImage && userVo.ImgCalls <= 0 {
|
|
||||||
utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
|
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
|
||||||
} else {
|
|
||||||
f := h.App.Functions[functionName]
|
|
||||||
if functionName == types.FuncImage {
|
|
||||||
params["user_id"] = userVo.Id
|
|
||||||
params["role_id"] = role.Id
|
|
||||||
params["chat_id"] = session.ChatId
|
|
||||||
params["icon"] = "/images/avatar/mid_journey.png"
|
|
||||||
params["session_id"] = session.SessionId
|
|
||||||
}
|
|
||||||
data, err := f.Invoke(params)
|
|
||||||
if err != nil {
|
|
||||||
msg := "调用函数出错:" + err.Error()
|
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
|
||||||
Type: types.WsMiddle,
|
|
||||||
Content: msg,
|
|
||||||
})
|
|
||||||
contents = append(contents, msg)
|
|
||||||
} else {
|
|
||||||
content := data
|
|
||||||
if functionName == types.FuncImage {
|
|
||||||
content = fmt.Sprintf("下面是根据您的描述创作的图片,他们描绘了 【%s】 的场景", params["prompt"])
|
|
||||||
// update user's img_calls
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
|
||||||
}
|
|
||||||
|
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
|
||||||
Type: types.WsMiddle,
|
|
||||||
Content: content,
|
|
||||||
})
|
|
||||||
contents = append(contents, content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
// 更新用户的对话次数
|
// 更新用户的对话次数
|
||||||
@ -177,7 +109,7 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
useMsg := types.Message{Role: "user", Content: prompt}
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
if h.App.ChatConfig.EnableContext && functionCall == false {
|
if h.App.ChatConfig.EnableContext {
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
chatCtx = append(chatCtx, message) // 回复消息
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||||
@ -185,11 +117,6 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
|
|
||||||
// 追加聊天记录
|
// 追加聊天记录
|
||||||
if h.App.ChatConfig.EnableHistory {
|
if h.App.ChatConfig.EnableHistory {
|
||||||
useContext := true
|
|
||||||
if functionCall {
|
|
||||||
useContext = false
|
|
||||||
}
|
|
||||||
|
|
||||||
// for prompt
|
// for prompt
|
||||||
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
promptToken, err := utils.CalcTokens(prompt, req.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -203,7 +130,7 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
Icon: userVo.Avatar,
|
Icon: userVo.Avatar,
|
||||||
Content: template.HTMLEscapeString(prompt),
|
Content: template.HTMLEscapeString(prompt),
|
||||||
Tokens: promptToken,
|
Tokens: promptToken,
|
||||||
UseContext: useContext,
|
UseContext: true,
|
||||||
}
|
}
|
||||||
historyUserMsg.CreatedAt = promptCreatedAt
|
historyUserMsg.CreatedAt = promptCreatedAt
|
||||||
historyUserMsg.UpdatedAt = promptCreatedAt
|
historyUserMsg.UpdatedAt = promptCreatedAt
|
||||||
@ -213,15 +140,7 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
// 计算本次对话消耗的总 token 数量
|
||||||
var totalTokens = 0
|
totalTokens, _ := utils.CalcTokens(message.Content, req.Model)
|
||||||
if functionCall { // prompt + 函数名 + 参数 token
|
|
||||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
|
||||||
totalTokens += tokens
|
|
||||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
|
||||||
totalTokens += tokens
|
|
||||||
} else {
|
|
||||||
totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
|
|
||||||
}
|
|
||||||
totalTokens += getTotalTokens(req)
|
totalTokens += getTotalTokens(req)
|
||||||
|
|
||||||
historyReplyMsg := model.HistoryMessage{
|
historyReplyMsg := model.HistoryMessage{
|
||||||
@ -232,7 +151,7 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
Icon: role.Icon,
|
Icon: role.Icon,
|
||||||
Content: message.Content,
|
Content: message.Content,
|
||||||
Tokens: totalTokens,
|
Tokens: totalTokens,
|
||||||
UseContext: useContext,
|
UseContext: true,
|
||||||
}
|
}
|
||||||
historyReplyMsg.CreatedAt = replyCreatedAt
|
historyReplyMsg.CreatedAt = replyCreatedAt
|
||||||
historyReplyMsg.UpdatedAt = replyCreatedAt
|
historyReplyMsg.UpdatedAt = replyCreatedAt
|
||||||
|
@ -214,20 +214,45 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
case types.Baidu:
|
case types.Baidu:
|
||||||
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
||||||
// TODO: 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
|
// TODO: 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
|
||||||
|
break
|
||||||
case types.OpenAI:
|
case types.OpenAI:
|
||||||
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
||||||
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
|
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
|
||||||
// OpenAI 支持函数功能
|
// OpenAI 支持函数功能
|
||||||
if h.App.SysConfig.EnabledFunction {
|
var items []model.Function
|
||||||
var functions = make([]types.Function, 0)
|
res := h.db.Where("enabled", true).Find(&items)
|
||||||
for _, f := range types.InnerFunctions {
|
if res.Error != nil {
|
||||||
functions = append(functions, f)
|
break
|
||||||
}
|
}
|
||||||
req.Functions = functions
|
|
||||||
|
var tools = make([]interface{}, 0)
|
||||||
|
for _, v := range items {
|
||||||
|
var parameters map[string]interface{}
|
||||||
|
err = utils.JsonDecode(v.Parameters, ¶meters)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
required := parameters["required"]
|
||||||
|
delete(parameters, "required")
|
||||||
|
tools = append(tools, gin.H{
|
||||||
|
"type": "function",
|
||||||
|
"function": gin.H{
|
||||||
|
"name": v.Name,
|
||||||
|
"description": v.Description,
|
||||||
|
"parameters": parameters,
|
||||||
|
"required": required,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
req.Tools = tools
|
||||||
|
req.ToolChoice = "auto"
|
||||||
}
|
}
|
||||||
case types.XunFei:
|
case types.XunFei:
|
||||||
req.Temperature = h.App.ChatConfig.XunFei.Temperature
|
req.Temperature = h.App.ChatConfig.XunFei.Temperature
|
||||||
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
|
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
|
||||||
|
break
|
||||||
default:
|
default:
|
||||||
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
|
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
utils.ReplyMessage(ws, ErrImg)
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gorm.io/gorm"
|
req2 "github.com/imroc/req/v3"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
@ -56,8 +56,8 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
// 循环读取 Chunk 消息
|
// 循环读取 Chunk 消息
|
||||||
var message = types.Message{}
|
var message = types.Message{}
|
||||||
var contents = make([]string, 0)
|
var contents = make([]string, 0)
|
||||||
var functionCall = false
|
var function model.Function
|
||||||
var functionName string
|
var toolCall = false
|
||||||
var arguments = make([]string, 0)
|
var arguments = make([]string, 0)
|
||||||
scanner := bufio.NewScanner(response.Body)
|
scanner := bufio.NewScanner(response.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
@ -75,24 +75,26 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
fun := responseBody.Choices[0].Delta.FunctionCall
|
var fun types.ToolCall
|
||||||
if functionCall && fun.Name == "" {
|
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
|
||||||
arguments = append(arguments, fun.Arguments)
|
fun = responseBody.Choices[0].Delta.ToolCalls[0]
|
||||||
|
if toolCall && fun.Function.Name == "" {
|
||||||
|
arguments = append(arguments, fun.Function.Arguments)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !utils.IsEmptyValue(fun) {
|
if !utils.IsEmptyValue(fun) {
|
||||||
functionName = fun.Name
|
res := h.db.Where("name = ?", fun.Function.Name).First(&function)
|
||||||
f := h.App.Functions[functionName]
|
if res.Error == nil {
|
||||||
if f != nil {
|
toolCall = true
|
||||||
functionCall = true
|
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)})
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
|
if responseBody.Choices[0].FinishReason == "tool_calls" { // 函数调用完毕
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,47 +123,35 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if functionCall { // 调用函数完成任务
|
if toolCall { // 调用函数完成任务
|
||||||
var params map[string]interface{}
|
var params map[string]interface{}
|
||||||
_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
|
||||||
logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
|
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
|
||||||
|
params["user_id"] = userVo.Id
|
||||||
// for creating image, check if the user's img_calls > 0
|
var apiRes types.BizVo
|
||||||
if functionName == types.FuncImage && userVo.ImgCalls <= 0 {
|
r, err := req2.C().R().SetHeader("Content-Type", "application/json").
|
||||||
utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
|
SetHeader("Authorization", function.Token).
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
SetBody(params).
|
||||||
} else {
|
SetSuccessResult(&apiRes).Post(function.Action)
|
||||||
f := h.App.Functions[functionName]
|
errMsg := ""
|
||||||
// translate prompt
|
|
||||||
if functionName == types.FuncImage {
|
|
||||||
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
|
||||||
r, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey, h.App.Config.ProxyURL, chatConfig.OpenAI.ApiURL)
|
|
||||||
if err == nil {
|
|
||||||
params["prompt"] = r
|
|
||||||
}
|
|
||||||
}
|
|
||||||
data, err := f.Invoke(params)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := "调用函数出错:" + err.Error()
|
errMsg = err.Error()
|
||||||
|
} else if r.IsErrorState() {
|
||||||
|
errMsg = r.Err.Error()
|
||||||
|
}
|
||||||
|
if errMsg != "" || apiRes.Code != types.Success {
|
||||||
|
msg := "调用函数工具出错:" + apiRes.Message + errMsg
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: msg,
|
Content: msg,
|
||||||
})
|
})
|
||||||
contents = append(contents, msg)
|
contents = append(contents, msg)
|
||||||
} else {
|
} else {
|
||||||
content := data
|
|
||||||
if functionName == types.FuncImage {
|
|
||||||
content = fmt.Sprintf("下面是根据您的描述创作的图片,他们描绘了 【%s】 的场景。%s", params["prompt"], data)
|
|
||||||
// update user's img_calls
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
|
||||||
}
|
|
||||||
|
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: content,
|
Content: apiRes.Data,
|
||||||
})
|
})
|
||||||
contents = append(contents, content)
|
contents = append(contents, utils.InterfaceToString(apiRes.Data))
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -177,7 +167,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
useMsg := types.Message{Role: "user", Content: prompt}
|
useMsg := types.Message{Role: "user", Content: prompt}
|
||||||
|
|
||||||
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
// 更新上下文消息,如果是调用函数则不需要更新上下文
|
||||||
if h.App.ChatConfig.EnableContext && functionCall == false {
|
if h.App.ChatConfig.EnableContext && toolCall == false {
|
||||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||||
chatCtx = append(chatCtx, message) // 回复消息
|
chatCtx = append(chatCtx, message) // 回复消息
|
||||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||||
@ -186,7 +176,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
// 追加聊天记录
|
// 追加聊天记录
|
||||||
if h.App.ChatConfig.EnableHistory {
|
if h.App.ChatConfig.EnableHistory {
|
||||||
useContext := true
|
useContext := true
|
||||||
if functionCall {
|
if toolCall {
|
||||||
useContext = false
|
useContext = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,8 +204,8 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
|
|
||||||
// 计算本次对话消耗的总 token 数量
|
// 计算本次对话消耗的总 token 数量
|
||||||
var totalTokens = 0
|
var totalTokens = 0
|
||||||
if functionCall { // prompt + 函数名 + 参数 token
|
if toolCall { // prompt + 函数名 + 参数 token
|
||||||
tokens, _ := utils.CalcTokens(functionName, req.Model)
|
tokens, _ := utils.CalcTokens(function.Name, req.Model)
|
||||||
totalTokens += tokens
|
totalTokens += tokens
|
||||||
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
|
||||||
totalTokens += tokens
|
totalTokens += tokens
|
||||||
|
235
api/handler/function_handler.go
Normal file
235
api/handler/function_handler.go
Normal file
@ -0,0 +1,235 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core"
|
||||||
|
"chatplus/core/types"
|
||||||
|
"chatplus/service/oss"
|
||||||
|
"chatplus/store/model"
|
||||||
|
"chatplus/utils"
|
||||||
|
"chatplus/utils/resp"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FunctionHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
db *gorm.DB
|
||||||
|
config types.ChatPlusApiConfig
|
||||||
|
uploadManager *oss.UploaderManager
|
||||||
|
proxyURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
|
||||||
|
return &FunctionHandler{
|
||||||
|
BaseHandler: BaseHandler{
|
||||||
|
App: server,
|
||||||
|
},
|
||||||
|
db: db,
|
||||||
|
config: config.ApiConfig,
|
||||||
|
uploadManager: manager,
|
||||||
|
proxyURL: config.ProxyURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type resVo struct {
|
||||||
|
Code types.BizCode `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
|
Items []dataItem `json:"items"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type dataItem struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
Url string `json:"url"`
|
||||||
|
Remark string `json:"remark"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WeiBo 微博热搜
|
||||||
|
func (h *FunctionHandler) WeiBo(c *gin.Context) {
|
||||||
|
if h.config.Token == "" {
|
||||||
|
resp.ERROR(c, "无效的 API Token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL)
|
||||||
|
var res resVo
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("AppId", h.config.AppId).
|
||||||
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
||||||
|
SetSuccessResult(&res).Get(url)
|
||||||
|
if err != nil || r.IsErrorState() {
|
||||||
|
resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != types.Success {
|
||||||
|
resp.ERROR(c, res.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
builder := make([]string, 0)
|
||||||
|
builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
|
||||||
|
for i, v := range res.Data.Items {
|
||||||
|
builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark))
|
||||||
|
}
|
||||||
|
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZaoBao 今日早报
|
||||||
|
func (h *FunctionHandler) ZaoBao(c *gin.Context) {
|
||||||
|
if h.config.Token == "" {
|
||||||
|
resp.ERROR(c, "无效的 API Token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL)
|
||||||
|
var res resVo
|
||||||
|
r, err := req.C().R().
|
||||||
|
SetHeader("AppId", h.config.AppId).
|
||||||
|
SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
|
||||||
|
SetSuccessResult(&res).Get(url)
|
||||||
|
if err != nil || r.IsErrorState() {
|
||||||
|
resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != types.Success {
|
||||||
|
resp.ERROR(c, res.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
builder := make([]string, 0)
|
||||||
|
builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt))
|
||||||
|
for _, v := range res.Data.Items {
|
||||||
|
builder = append(builder, v.Title)
|
||||||
|
}
|
||||||
|
builder = append(builder, fmt.Sprintf("%s", res.Data.Title))
|
||||||
|
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
type imgReq struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
N int `json:"n"`
|
||||||
|
Size string `json:"size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type imgRes struct {
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Data []struct {
|
||||||
|
RevisedPrompt string `json:"revised_prompt"`
|
||||||
|
Url string `json:"url"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrRes struct {
|
||||||
|
Error struct {
|
||||||
|
Code interface{} `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Param interface{} `json:"param"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
} `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dall3 DallE3 AI 绘图
|
||||||
|
func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||||
|
var params map[string]interface{}
|
||||||
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("绘画参数:%+v", params)
|
||||||
|
// check img calls
|
||||||
|
var user model.User
|
||||||
|
tx := h.db.Where("id = ?", params["user_id"]).First(&user)
|
||||||
|
if tx.Error != nil {
|
||||||
|
resp.ERROR(c, "当前用户不存在!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.ImgCalls <= 0 {
|
||||||
|
resp.ERROR(c, "当前用户的绘图次数额度不足!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := utils.InterfaceToString(params["prompt"])
|
||||||
|
// get image generation API KEY
|
||||||
|
var apiKey model.ApiKey
|
||||||
|
tx = h.db.Where("platform = ? AND type = ?", types.OpenAI, "img").Order("last_used_at ASC").First(&apiKey)
|
||||||
|
if tx.Error != nil {
|
||||||
|
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// get image generation api URL
|
||||||
|
var conf model.Config
|
||||||
|
var chatConfig types.ChatConfig
|
||||||
|
tx = h.db.Where("marker", "chat").First(&conf)
|
||||||
|
if tx.Error != nil {
|
||||||
|
resp.ERROR(c, "error with get chat configs:"+tx.Error.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := utils.JsonDecode(conf.Config, &chatConfig)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "error with decode chat config: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// translate prompt
|
||||||
|
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||||
|
pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey.Value, h.App.Config.ProxyURL, chatConfig.OpenAI.ApiURL)
|
||||||
|
if err == nil {
|
||||||
|
prompt = pt
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL := chatConfig.DallApiURL
|
||||||
|
if utils.IsEmptyValue(apiURL) {
|
||||||
|
apiURL = "https://api.openai.com/v1/images/generations"
|
||||||
|
}
|
||||||
|
imgNum := chatConfig.DallImgNum
|
||||||
|
if imgNum <= 0 {
|
||||||
|
imgNum = 1
|
||||||
|
}
|
||||||
|
var res imgRes
|
||||||
|
var errRes ErrRes
|
||||||
|
var request *req.Request
|
||||||
|
if strings.Contains(apiURL, "api.openai.com") {
|
||||||
|
request = req.C().SetProxyURL(h.proxyURL).R()
|
||||||
|
} else {
|
||||||
|
request = req.C().R()
|
||||||
|
}
|
||||||
|
r, err := request.SetHeader("Content-Type", "application/json").
|
||||||
|
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||||
|
SetBody(imgReq{
|
||||||
|
Model: "dall-e-3",
|
||||||
|
Prompt: prompt,
|
||||||
|
N: imgNum,
|
||||||
|
Size: "1024x1024",
|
||||||
|
}).
|
||||||
|
SetErrorResult(&errRes).
|
||||||
|
SetSuccessResult(&res).Post(apiURL)
|
||||||
|
if r.IsErrorState() {
|
||||||
|
resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 存储图片
|
||||||
|
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
|
||||||
|
if err != nil {
|
||||||
|
resp.ERROR(c, "下载图片失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n\n", prompt, imgURL)
|
||||||
|
// update user's img_calls
|
||||||
|
h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||||
|
|
||||||
|
resp.SUCCESS(c, content)
|
||||||
|
}
|
@ -358,6 +358,14 @@ func main() {
|
|||||||
group.GET("token", h.GenToken)
|
group.GET("token", h.GenToken)
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
fx.Provide(handler.NewFunctionHandler),
|
||||||
|
fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
|
||||||
|
group := s.Engine.Group("/api/function/")
|
||||||
|
group.POST("weibo", h.WeiBo)
|
||||||
|
group.POST("zaobao", h.ZaoBao)
|
||||||
|
group.POST("dalle3", h.Dall3)
|
||||||
|
}),
|
||||||
|
|
||||||
fx.Provide(handler.NewTestHandler),
|
fx.Provide(handler.NewTestHandler),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
|
||||||
group := s.Engine.Group("/test/")
|
group := s.Engine.Group("/test/")
|
||||||
|
@ -6,7 +6,6 @@ type Function struct {
|
|||||||
Label string
|
Label string
|
||||||
Description string
|
Description string
|
||||||
Parameters string
|
Parameters string
|
||||||
Required string
|
|
||||||
Action string
|
Action string
|
||||||
Token string
|
Token string
|
||||||
Enabled bool
|
Enabled bool
|
||||||
|
@ -2,7 +2,7 @@ package vo
|
|||||||
|
|
||||||
type Parameters struct {
|
type Parameters struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Required []string `json:"required"`
|
Required []string `json:"required,omitempty"`
|
||||||
Properties map[string]Property `json:"properties"`
|
Properties map[string]Property `json:"properties"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -17,7 +17,6 @@ type Function struct {
|
|||||||
Label string `json:"label"`
|
Label string `json:"label"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Parameters Parameters `json:"parameters"`
|
Parameters Parameters `json:"parameters"`
|
||||||
Required []string `json:"required"`
|
|
||||||
Action string `json:"action"`
|
Action string `json:"action"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
@ -24,3 +24,5 @@ ALTER TABLE `chatgpt_mj_jobs` ADD `use_proxy` TINYINT(1) NOT NULL DEFAULT '0' CO
|
|||||||
ALTER TABLE `chatgpt_mj_jobs` CHANGE `img_url` `img_url` VARCHAR(400) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL COMMENT '图片URL';
|
ALTER TABLE `chatgpt_mj_jobs` CHANGE `img_url` `img_url` VARCHAR(400) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL COMMENT '图片URL';
|
||||||
|
|
||||||
ALTER TABLE `chatgpt_functions` ADD `token` VARCHAR(255) NULL COMMENT 'API授权token' AFTER `action`;
|
ALTER TABLE `chatgpt_functions` ADD `token` VARCHAR(255) NULL COMMENT 'API授权token' AFTER `action`;
|
||||||
|
|
||||||
|
ALTER TABLE `chatgpt_functions` DROP `required`;
|
||||||
|
@ -238,7 +238,7 @@ const save = function () {
|
|||||||
required.push(params.value[i].name)
|
required.push(params.value[i].name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
item.value.parameters = {type: "object", "properties": properties, "required": required}
|
item.value.parameters = {type: "object", "properties": properties, required: required}
|
||||||
httpPost('/api/admin/function/save', item.value).then((res) => {
|
httpPost('/api/admin/function/save', item.value).then((res) => {
|
||||||
ElMessage.success('操作成功')
|
ElMessage.success('操作成功')
|
||||||
console.log(res.data)
|
console.log(res.data)
|
||||||
|
Loading…
Reference in New Issue
Block a user