mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	merge pull request #71
This commit is contained in:
		@@ -3,7 +3,6 @@ package core
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/fun"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
@@ -39,10 +38,9 @@ type AppServer struct {
 | 
			
		||||
	ChatSession   *types.LMap[string, *types.ChatSession] //map[sessionId]UserId
 | 
			
		||||
	ChatClients   *types.LMap[string, *types.WsClient]    // map[sessionId]Websocket 连接集合
 | 
			
		||||
	ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
 | 
			
		||||
	Functions     map[string]fun.Function
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewServer(appConfig *types.AppConfig, functions map[string]fun.Function) *AppServer {
 | 
			
		||||
func NewServer(appConfig *types.AppConfig) *AppServer {
 | 
			
		||||
	gin.SetMode(gin.ReleaseMode)
 | 
			
		||||
	gin.DefaultWriter = io.Discard
 | 
			
		||||
	return &AppServer{
 | 
			
		||||
@@ -53,7 +51,6 @@ func NewServer(appConfig *types.AppConfig, functions map[string]fun.Function) *A
 | 
			
		||||
		ChatSession:   types.NewLMap[string, *types.ChatSession](),
 | 
			
		||||
		ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
 | 
			
		||||
		Functions:     functions,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -159,6 +156,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
 | 
			
		||||
			c.Request.URL.Path == "/api/sd/jobs" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/upload" ||
 | 
			
		||||
			strings.HasPrefix(c.Request.URL.Path, "/test/") ||
 | 
			
		||||
			strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
 | 
			
		||||
			strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
 | 
			
		||||
			strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
 | 
			
		||||
			strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
 | 
			
		||||
 
 | 
			
		||||
@@ -8,7 +8,8 @@ type ApiRequest struct {
 | 
			
		||||
	Stream      bool          `json:"stream"`
 | 
			
		||||
	Messages    []interface{} `json:"messages,omitempty"`
 | 
			
		||||
	Prompt      []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
 | 
			
		||||
	Functions   []Function    `json:"functions,omitempty"`
 | 
			
		||||
	Tools       []interface{} `json:"tools,omitempty"`
 | 
			
		||||
	ToolChoice  string        `json:"tool_choice,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
@@ -27,10 +28,10 @@ type ChoiceItem struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Delta struct {
 | 
			
		||||
	Role         string       `json:"role"`
 | 
			
		||||
	Name         string       `json:"name"`
 | 
			
		||||
	Content      interface{}  `json:"content"`
 | 
			
		||||
	FunctionCall FunctionCall `json:"function_call,omitempty"`
 | 
			
		||||
	Role      string      `json:"role"`
 | 
			
		||||
	Name      string      `json:"name"`
 | 
			
		||||
	Content   interface{} `json:"content"`
 | 
			
		||||
	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ChatSession 聊天会话对象
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,11 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
type FunctionCall struct {
 | 
			
		||||
	Name      string `json:"name"`
 | 
			
		||||
	Arguments string `json:"arguments"`
 | 
			
		||||
type ToolCall struct {
 | 
			
		||||
	Type     string `json:"type"`
 | 
			
		||||
	Function struct {
 | 
			
		||||
		Name      string `json:"name"`
 | 
			
		||||
		Arguments string `json:"arguments"`
 | 
			
		||||
	} `json:"function"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Function struct {
 | 
			
		||||
@@ -21,47 +24,3 @@ type Property struct {
 | 
			
		||||
	Type        string `json:"type"`
 | 
			
		||||
	Description string `json:"description"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	FuncZaoBao   = "zao_bao"    // 每日早报
 | 
			
		||||
	FuncHeadLine = "headline"   // 今日头条
 | 
			
		||||
	FuncWeibo    = "weibo_hot"  // 微博热搜
 | 
			
		||||
	FuncImage    = "draw_image" // AI 绘画
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var InnerFunctions = []Function{
 | 
			
		||||
	{
 | 
			
		||||
		Name:        FuncZaoBao,
 | 
			
		||||
		Description: "每日早报,获取当天新闻事件列表",
 | 
			
		||||
		Parameters: Parameters{
 | 
			
		||||
 | 
			
		||||
			Type:       "object",
 | 
			
		||||
			Properties: map[string]Property{},
 | 
			
		||||
			Required:   []string{},
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
	{
 | 
			
		||||
		Name:        FuncWeibo,
 | 
			
		||||
		Description: "新浪微博热搜榜,微博当日热搜榜单",
 | 
			
		||||
		Parameters: Parameters{
 | 
			
		||||
			Type:       "object",
 | 
			
		||||
			Properties: map[string]Property{},
 | 
			
		||||
			Required:   []string{},
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		Name:        FuncImage,
 | 
			
		||||
		Description: "AI 绘画工具,根据输入的绘图描述用 AI 工具进行绘画",
 | 
			
		||||
		Parameters: Parameters{
 | 
			
		||||
			Type: "object",
 | 
			
		||||
			Properties: map[string]Property{
 | 
			
		||||
				"prompt": {
 | 
			
		||||
					Type:        "string",
 | 
			
		||||
					Description: "提示词,请自动将该参数翻译成英文。",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			Required: []string{},
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -39,7 +39,6 @@ func (h *FunctionHandler) Save(c *gin.Context) {
 | 
			
		||||
		Label:       data.Label,
 | 
			
		||||
		Description: data.Description,
 | 
			
		||||
		Parameters:  utils.JsonEncode(data.Parameters),
 | 
			
		||||
		Required:    utils.JsonEncode(data.Required),
 | 
			
		||||
		Action:      data.Action,
 | 
			
		||||
		Token:       data.Token,
 | 
			
		||||
		Enabled:     data.Enabled,
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,6 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -57,9 +56,6 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		var functionCall = false
 | 
			
		||||
		var functionName string
 | 
			
		||||
		var arguments = make([]string, 0)
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			line := scanner.Text()
 | 
			
		||||
@@ -69,36 +65,17 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
 | 
			
		||||
			var responseBody = types.ApiResponse{}
 | 
			
		||||
			err = json.Unmarshal([]byte(line[6:]), &responseBody)
 | 
			
		||||
			if err != nil  { // 数据解析出错
 | 
			
		||||
			if err != nil { // 数据解析出错
 | 
			
		||||
				logger.Error(err, line)
 | 
			
		||||
				utils.ReplyMessage(ws, ErrorMsg)
 | 
			
		||||
				utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if len(responseBody.Choices) == 0 {
 | 
			
		||||
				continue;
 | 
			
		||||
			}
 | 
			
		||||
			fun := responseBody.Choices[0].Delta.FunctionCall
 | 
			
		||||
			if functionCall && fun.Name == "" {
 | 
			
		||||
				arguments = append(arguments, fun.Arguments)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !utils.IsEmptyValue(fun) {
 | 
			
		||||
				functionName = fun.Name
 | 
			
		||||
				f := h.App.Functions[functionName]
 | 
			
		||||
				if f != nil {
 | 
			
		||||
					functionCall = true
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 初始化 role
 | 
			
		||||
			if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
 | 
			
		||||
				message.Role = responseBody.Choices[0].Delta.Role
 | 
			
		||||
@@ -124,49 +101,6 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if functionCall { // 调用函数完成任务
 | 
			
		||||
			var params map[string]interface{}
 | 
			
		||||
			_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
 | 
			
		||||
			logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
 | 
			
		||||
 | 
			
		||||
			// for creating image, check if the user's img_calls > 0
 | 
			
		||||
			if functionName == types.FuncImage && userVo.ImgCalls <= 0 {
 | 
			
		||||
				utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
 | 
			
		||||
				utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
			} else {
 | 
			
		||||
				f := h.App.Functions[functionName]
 | 
			
		||||
				if functionName == types.FuncImage {
 | 
			
		||||
					params["user_id"] = userVo.Id
 | 
			
		||||
					params["role_id"] = role.Id
 | 
			
		||||
					params["chat_id"] = session.ChatId
 | 
			
		||||
					params["icon"] = "/images/avatar/mid_journey.png"
 | 
			
		||||
					params["session_id"] = session.SessionId
 | 
			
		||||
				}
 | 
			
		||||
				data, err := f.Invoke(params)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					msg := "调用函数出错:" + err.Error()
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
						Type:    types.WsMiddle,
 | 
			
		||||
						Content: msg,
 | 
			
		||||
					})
 | 
			
		||||
					contents = append(contents, msg)
 | 
			
		||||
				} else {
 | 
			
		||||
					content := data
 | 
			
		||||
					if functionName == types.FuncImage {
 | 
			
		||||
						content = fmt.Sprintf("下面是根据您的描述创作的图片,他们描绘了 【%s】 的场景", params["prompt"])
 | 
			
		||||
						// update user's img_calls
 | 
			
		||||
						h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
						Type:    types.WsMiddle,
 | 
			
		||||
						Content: content,
 | 
			
		||||
					})
 | 
			
		||||
					contents = append(contents, content)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			// 更新用户的对话次数
 | 
			
		||||
@@ -179,7 +113,7 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.ChatConfig.EnableContext && functionCall == false {
 | 
			
		||||
			if h.App.ChatConfig.EnableContext {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
@@ -187,11 +121,6 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			if h.App.ChatConfig.EnableHistory {
 | 
			
		||||
				useContext := true
 | 
			
		||||
				if functionCall {
 | 
			
		||||
					useContext = false
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// for prompt
 | 
			
		||||
				promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
@@ -205,7 +134,7 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
					Icon:       userVo.Avatar,
 | 
			
		||||
					Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
					Tokens:     promptToken,
 | 
			
		||||
					UseContext: useContext,
 | 
			
		||||
					UseContext: true,
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
				historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
@@ -215,15 +144,7 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 计算本次对话消耗的总 token 数量
 | 
			
		||||
				var totalTokens = 0
 | 
			
		||||
				if functionCall { // prompt + 函数名 + 参数 token
 | 
			
		||||
					tokens, _ := utils.CalcTokens(functionName, req.Model)
 | 
			
		||||
					totalTokens += tokens
 | 
			
		||||
					tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
 | 
			
		||||
					totalTokens += tokens
 | 
			
		||||
				} else {
 | 
			
		||||
					totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
				}
 | 
			
		||||
				totalTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
				totalTokens += getTotalTokens(req)
 | 
			
		||||
 | 
			
		||||
				historyReplyMsg := model.HistoryMessage{
 | 
			
		||||
@@ -234,7 +155,7 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
					Icon:       role.Icon,
 | 
			
		||||
					Content:    message.Content,
 | 
			
		||||
					Tokens:     totalTokens,
 | 
			
		||||
					UseContext: useContext,
 | 
			
		||||
					UseContext: true,
 | 
			
		||||
				}
 | 
			
		||||
				historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
				historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
 
 | 
			
		||||
@@ -214,20 +214,45 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
	case types.Baidu:
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.OpenAI.Temperature
 | 
			
		||||
		// TODO: 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
 | 
			
		||||
		break
 | 
			
		||||
	case types.OpenAI:
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.OpenAI.Temperature
 | 
			
		||||
		req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
 | 
			
		||||
		// OpenAI 支持函数功能
 | 
			
		||||
		if h.App.SysConfig.EnabledFunction {
 | 
			
		||||
			var functions = make([]types.Function, 0)
 | 
			
		||||
			for _, f := range types.InnerFunctions {
 | 
			
		||||
				functions = append(functions, f)
 | 
			
		||||
		var items []model.Function
 | 
			
		||||
		res := h.db.Where("enabled", true).Find(&items)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tools = make([]interface{}, 0)
 | 
			
		||||
		for _, v := range items {
 | 
			
		||||
			var parameters map[string]interface{}
 | 
			
		||||
			err = utils.JsonDecode(v.Parameters, ¶meters)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			req.Functions = functions
 | 
			
		||||
			required := parameters["required"]
 | 
			
		||||
			delete(parameters, "required")
 | 
			
		||||
			tools = append(tools, gin.H{
 | 
			
		||||
				"type": "function",
 | 
			
		||||
				"function": gin.H{
 | 
			
		||||
					"name":        v.Name,
 | 
			
		||||
					"description": v.Description,
 | 
			
		||||
					"parameters":  parameters,
 | 
			
		||||
					"required":    required,
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if len(tools) > 0 {
 | 
			
		||||
			req.Tools = tools
 | 
			
		||||
			req.ToolChoice = "auto"
 | 
			
		||||
		}
 | 
			
		||||
	case types.XunFei:
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.XunFei.Temperature
 | 
			
		||||
		req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
 | 
			
		||||
		break
 | 
			
		||||
	default:
 | 
			
		||||
		utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
@@ -242,11 +267,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		} else {
 | 
			
		||||
			// calculate the tokens of current request, to prevent to exceeding the max tokens num
 | 
			
		||||
			tokens := req.MaxTokens
 | 
			
		||||
			for _, f := range types.InnerFunctions {
 | 
			
		||||
				tks, _ := utils.CalcTokens(utils.JsonEncode(f), req.Model)
 | 
			
		||||
				tokens += tks
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
 | 
			
		||||
			tokens += tks
 | 
			
		||||
			// loading the role context
 | 
			
		||||
			var messages []types.Message
 | 
			
		||||
			err := utils.JsonDecode(role.Context, &messages)
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	req2 "github.com/imroc/req/v3"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -56,8 +56,8 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		var functionCall = false
 | 
			
		||||
		var functionName string
 | 
			
		||||
		var function model.Function
 | 
			
		||||
		var toolCall = false
 | 
			
		||||
		var arguments = make([]string, 0)
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
@@ -75,24 +75,26 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			fun := responseBody.Choices[0].Delta.FunctionCall
 | 
			
		||||
			if functionCall && fun.Name == "" {
 | 
			
		||||
				arguments = append(arguments, fun.Arguments)
 | 
			
		||||
				continue
 | 
			
		||||
			var fun types.ToolCall
 | 
			
		||||
			if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
 | 
			
		||||
				fun = responseBody.Choices[0].Delta.ToolCalls[0]
 | 
			
		||||
				if toolCall && fun.Function.Name == "" {
 | 
			
		||||
					arguments = append(arguments, fun.Function.Arguments)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !utils.IsEmptyValue(fun) {
 | 
			
		||||
				functionName = fun.Name
 | 
			
		||||
				f := h.App.Functions[functionName]
 | 
			
		||||
				if f != nil {
 | 
			
		||||
					functionCall = true
 | 
			
		||||
				res := h.db.Where("name = ?", fun.Function.Name).First(&function)
 | 
			
		||||
				if res.Error == nil {
 | 
			
		||||
					toolCall = true
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)})
 | 
			
		||||
				}
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
 | 
			
		||||
			if responseBody.Choices[0].FinishReason == "tool_calls" { // 函数调用完毕
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
@@ -121,47 +123,35 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if functionCall { // 调用函数完成任务
 | 
			
		||||
		if toolCall { // 调用函数完成任务
 | 
			
		||||
			var params map[string]interface{}
 | 
			
		||||
			_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
 | 
			
		||||
			logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
 | 
			
		||||
 | 
			
		||||
			// for creating image, check if the user's img_calls > 0
 | 
			
		||||
			if functionName == types.FuncImage && userVo.ImgCalls <= 0 {
 | 
			
		||||
				utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
 | 
			
		||||
				utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
			logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
 | 
			
		||||
			params["user_id"] = userVo.Id
 | 
			
		||||
			var apiRes types.BizVo
 | 
			
		||||
			r, err := req2.C().R().SetHeader("Content-Type", "application/json").
 | 
			
		||||
				SetHeader("Authorization", function.Token).
 | 
			
		||||
				SetBody(params).
 | 
			
		||||
				SetSuccessResult(&apiRes).Post(function.Action)
 | 
			
		||||
			errMsg := ""
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				errMsg = err.Error()
 | 
			
		||||
			} else if r.IsErrorState() {
 | 
			
		||||
				errMsg = r.Err.Error()
 | 
			
		||||
			}
 | 
			
		||||
			if errMsg != "" || apiRes.Code != types.Success {
 | 
			
		||||
				msg := "调用函数工具出错:" + apiRes.Message + errMsg
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: msg,
 | 
			
		||||
				})
 | 
			
		||||
				contents = append(contents, msg)
 | 
			
		||||
			} else {
 | 
			
		||||
				f := h.App.Functions[functionName]
 | 
			
		||||
				// translate prompt
 | 
			
		||||
				if functionName == types.FuncImage {
 | 
			
		||||
					const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
 | 
			
		||||
					r, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey, h.App.Config.ProxyURL, chatConfig.OpenAI.ApiURL)
 | 
			
		||||
					if err == nil {
 | 
			
		||||
						params["prompt"] = r
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				data, err := f.Invoke(params)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					msg := "调用函数出错:" + err.Error()
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
						Type:    types.WsMiddle,
 | 
			
		||||
						Content: msg,
 | 
			
		||||
					})
 | 
			
		||||
					contents = append(contents, msg)
 | 
			
		||||
				} else {
 | 
			
		||||
					content := data
 | 
			
		||||
					if functionName == types.FuncImage {
 | 
			
		||||
						content = fmt.Sprintf("下面是根据您的描述创作的图片,他们描绘了 【%s】 的场景。%s", params["prompt"], data)
 | 
			
		||||
						// update user's img_calls
 | 
			
		||||
						h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
						Type:    types.WsMiddle,
 | 
			
		||||
						Content: content,
 | 
			
		||||
					})
 | 
			
		||||
					contents = append(contents, content)
 | 
			
		||||
				}
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: apiRes.Data,
 | 
			
		||||
				})
 | 
			
		||||
				contents = append(contents, utils.InterfaceToString(apiRes.Data))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@@ -177,7 +167,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.ChatConfig.EnableContext && functionCall == false {
 | 
			
		||||
			if h.App.ChatConfig.EnableContext && toolCall == false {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
@@ -186,7 +176,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			if h.App.ChatConfig.EnableHistory {
 | 
			
		||||
				useContext := true
 | 
			
		||||
				if functionCall {
 | 
			
		||||
				if toolCall {
 | 
			
		||||
					useContext = false
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
@@ -214,8 +204,8 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
 | 
			
		||||
				// 计算本次对话消耗的总 token 数量
 | 
			
		||||
				var totalTokens = 0
 | 
			
		||||
				if functionCall { // prompt + 函数名 + 参数 token
 | 
			
		||||
					tokens, _ := utils.CalcTokens(functionName, req.Model)
 | 
			
		||||
				if toolCall { // prompt + 函数名 + 参数 token
 | 
			
		||||
					tokens, _ := utils.CalcTokens(function.Name, req.Model)
 | 
			
		||||
					totalTokens += tokens
 | 
			
		||||
					tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
 | 
			
		||||
					totalTokens += tokens
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										235
									
								
								api/handler/function_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										235
									
								
								api/handler/function_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,235 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type FunctionHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	config        types.ChatPlusApiConfig
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	proxyURL      string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
 | 
			
		||||
	return &FunctionHandler{
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: server,
 | 
			
		||||
		},
 | 
			
		||||
		db:            db,
 | 
			
		||||
		config:        config.ApiConfig,
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
		proxyURL:      config.ProxyURL,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type resVo struct {
 | 
			
		||||
	Code    types.BizCode `json:"code"`
 | 
			
		||||
	Message string        `json:"message"`
 | 
			
		||||
	Data    struct {
 | 
			
		||||
		Title     string     `json:"title"`
 | 
			
		||||
		UpdatedAt string     `json:"updated_at"`
 | 
			
		||||
		Items     []dataItem `json:"items"`
 | 
			
		||||
	} `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type dataItem struct {
 | 
			
		||||
	Title  string `json:"title"`
 | 
			
		||||
	Url    string `json:"url"`
 | 
			
		||||
	Remark string `json:"remark"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WeiBo 微博热搜
 | 
			
		||||
func (h *FunctionHandler) WeiBo(c *gin.Context) {
 | 
			
		||||
	if h.config.Token == "" {
 | 
			
		||||
		resp.ERROR(c, "无效的 API Token")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL)
 | 
			
		||||
	var res resVo
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("AppId", h.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
 | 
			
		||||
		SetSuccessResult(&res).Get(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		resp.ERROR(c, res.Message)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder := make([]string, 0)
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
 | 
			
		||||
	for i, v := range res.Data.Items {
 | 
			
		||||
		builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark))
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, strings.Join(builder, "\n\n"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ZaoBao 今日早报
 | 
			
		||||
func (h *FunctionHandler) ZaoBao(c *gin.Context) {
 | 
			
		||||
	if h.config.Token == "" {
 | 
			
		||||
		resp.ERROR(c, "无效的 API Token")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL)
 | 
			
		||||
	var res resVo
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("AppId", h.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
 | 
			
		||||
		SetSuccessResult(&res).Get(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		resp.ERROR(c, res.Message)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder := make([]string, 0)
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt))
 | 
			
		||||
	for _, v := range res.Data.Items {
 | 
			
		||||
		builder = append(builder, v.Title)
 | 
			
		||||
	}
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("%s", res.Data.Title))
 | 
			
		||||
	resp.SUCCESS(c, strings.Join(builder, "\n\n"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type imgReq struct {
 | 
			
		||||
	Model  string `json:"model"`
 | 
			
		||||
	Prompt string `json:"prompt"`
 | 
			
		||||
	N      int    `json:"n"`
 | 
			
		||||
	Size   string `json:"size"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type imgRes struct {
 | 
			
		||||
	Created int64 `json:"created"`
 | 
			
		||||
	Data    []struct {
 | 
			
		||||
		RevisedPrompt string `json:"revised_prompt"`
 | 
			
		||||
		Url           string `json:"url"`
 | 
			
		||||
	} `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ErrRes struct {
 | 
			
		||||
	Error struct {
 | 
			
		||||
		Code    interface{} `json:"code"`
 | 
			
		||||
		Message string      `json:"message"`
 | 
			
		||||
		Param   interface{} `json:"param"`
 | 
			
		||||
		Type    string      `json:"type"`
 | 
			
		||||
	} `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Dall3 DallE3 AI 绘图
 | 
			
		||||
func (h *FunctionHandler) Dall3(c *gin.Context) {
 | 
			
		||||
	var params map[string]interface{}
 | 
			
		||||
	if err := c.ShouldBindJSON(¶ms); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debugf("绘画参数:%+v", params)
 | 
			
		||||
	// check img calls
 | 
			
		||||
	var user model.User
 | 
			
		||||
	tx := h.db.Where("id = ?", params["user_id"]).First(&user)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "当前用户不存在!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.ImgCalls <= 0 {
 | 
			
		||||
		resp.ERROR(c, "当前用户的绘图次数额度不足!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	prompt := utils.InterfaceToString(params["prompt"])
 | 
			
		||||
	// get image generation API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	tx = h.db.Where("platform = ? AND type = ?", types.OpenAI, "img").Order("last_used_at ASC").First(&apiKey)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// get image generation api URL
 | 
			
		||||
	var conf model.Config
 | 
			
		||||
	var chatConfig types.ChatConfig
 | 
			
		||||
	tx = h.db.Where("marker", "chat").First(&conf)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "error with get chat configs:"+tx.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := utils.JsonDecode(conf.Config, &chatConfig)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "error with decode chat config: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// translate prompt
 | 
			
		||||
	const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
 | 
			
		||||
	pt, err := utils.OpenAIRequest(fmt.Sprintf(translatePromptTemplate, params["prompt"]), apiKey.Value, h.App.Config.ProxyURL, chatConfig.OpenAI.ApiURL)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		prompt = pt
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := chatConfig.DallApiURL
 | 
			
		||||
	if utils.IsEmptyValue(apiURL) {
 | 
			
		||||
		apiURL = "https://api.openai.com/v1/images/generations"
 | 
			
		||||
	}
 | 
			
		||||
	imgNum := chatConfig.DallImgNum
 | 
			
		||||
	if imgNum <= 0 {
 | 
			
		||||
		imgNum = 1
 | 
			
		||||
	}
 | 
			
		||||
	var res imgRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	var request *req.Request
 | 
			
		||||
	if strings.Contains(apiURL, "api.openai.com") {
 | 
			
		||||
		request = req.C().SetProxyURL(h.proxyURL).R()
 | 
			
		||||
	} else {
 | 
			
		||||
		request = req.C().R()
 | 
			
		||||
	}
 | 
			
		||||
	r, err := request.SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetBody(imgReq{
 | 
			
		||||
			Model:  "dall-e-3",
 | 
			
		||||
			Prompt: prompt,
 | 
			
		||||
			N:      imgNum,
 | 
			
		||||
			Size:   "1024x1024",
 | 
			
		||||
		}).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		SetSuccessResult(&res).Post(apiURL)
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 存储图片
 | 
			
		||||
	imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "下载图片失败: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n\n", prompt, imgURL)
 | 
			
		||||
	// update user's img_calls
 | 
			
		||||
	h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, content)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										12
									
								
								api/main.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								api/main.go
									
									
									
									
									
								
							@@ -8,7 +8,6 @@ import (
 | 
			
		||||
	"chatplus/handler/chatimpl"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/service/fun"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/service/payment"
 | 
			
		||||
@@ -115,9 +114,6 @@ func main() {
 | 
			
		||||
			return xdb.NewWithBuffer(cBuff)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 创建函数
 | 
			
		||||
		fx.Provide(fun.NewFunctions),
 | 
			
		||||
 | 
			
		||||
		// 创建控制器
 | 
			
		||||
		fx.Provide(handler.NewChatRoleHandler),
 | 
			
		||||
		fx.Provide(handler.NewUserHandler),
 | 
			
		||||
@@ -358,6 +354,14 @@ func main() {
 | 
			
		||||
			group.GET("token", h.GenToken)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(handler.NewFunctionHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/function/")
 | 
			
		||||
			group.POST("weibo", h.WeiBo)
 | 
			
		||||
			group.POST("zaobao", h.ZaoBao)
 | 
			
		||||
			group.POST("dalle3", h.Dall3)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(handler.NewTestHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
 | 
			
		||||
			group := s.Engine.Group("/test/")
 | 
			
		||||
 
 | 
			
		||||
@@ -1,116 +0,0 @@
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// AI 绘画函数
 | 
			
		||||
 | 
			
		||||
type FuncImage struct {
 | 
			
		||||
	name          string
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	proxyURL      string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewImageFunc(db *gorm.DB, manager *oss.UploaderManager, config *types.AppConfig) FuncImage {
 | 
			
		||||
	return FuncImage{
 | 
			
		||||
		db:            db,
 | 
			
		||||
		name:          "DALL-E3 绘画",
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
		proxyURL:      config.ProxyURL,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type imgReq struct {
 | 
			
		||||
	Model  string `json:"model"`
 | 
			
		||||
	Prompt string `json:"prompt"`
 | 
			
		||||
	N      int    `json:"n"`
 | 
			
		||||
	Size   string `json:"size"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type imgRes struct {
 | 
			
		||||
	Created int64 `json:"created"`
 | 
			
		||||
	Data    []struct {
 | 
			
		||||
		RevisedPrompt string `json:"revised_prompt"`
 | 
			
		||||
		Url           string `json:"url"`
 | 
			
		||||
	} `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ErrRes struct {
 | 
			
		||||
	Error struct {
 | 
			
		||||
		Code    interface{} `json:"code"`
 | 
			
		||||
		Message string      `json:"message"`
 | 
			
		||||
		Param   interface{} `json:"param"`
 | 
			
		||||
		Type    string      `json:"type"`
 | 
			
		||||
	} `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncImage) Invoke(params map[string]interface{}) (string, error) {
 | 
			
		||||
	logger.Infof("绘画参数:%+v", params)
 | 
			
		||||
	prompt := utils.InterfaceToString(params["prompt"])
 | 
			
		||||
	// get image generation API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	tx := f.db.Where("platform = ? AND type = ?", types.OpenAI, "img").Order("last_used_at ASC").First(&apiKey)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with get generation API KEY: %v", tx.Error)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// get image generation api URL
 | 
			
		||||
	var conf model.Config
 | 
			
		||||
	var chatConfig types.ChatConfig
 | 
			
		||||
	tx = f.db.Where("marker", "chat").First(&conf)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with get chat configs: %v", tx.Error)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := utils.JsonDecode(conf.Config, &chatConfig)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with decode chat config: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := chatConfig.DallApiURL
 | 
			
		||||
	if utils.IsEmptyValue(apiURL) {
 | 
			
		||||
		apiURL = "https://api.openai.com/v1/images/generations"
 | 
			
		||||
	}
 | 
			
		||||
	imgNum := chatConfig.DallImgNum
 | 
			
		||||
	if imgNum <= 0 {
 | 
			
		||||
		imgNum = 1
 | 
			
		||||
	}
 | 
			
		||||
	var res imgRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().SetProxyURL(f.proxyURL).R().SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetBody(imgReq{
 | 
			
		||||
			Model:  "dall-e-3",
 | 
			
		||||
			Prompt: prompt,
 | 
			
		||||
			N:      imgNum,
 | 
			
		||||
			Size:   "1024x1024",
 | 
			
		||||
		}).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		SetSuccessResult(&res).Post(apiURL)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
	// 存储图片
 | 
			
		||||
	imgURL, err := f.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("下载图片失败: %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//logger.Info(imgURL)
 | 
			
		||||
	return fmt.Sprintf("\n\n\n", imgURL), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncImage) Name() string {
 | 
			
		||||
	return f.name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Function = &FuncImage{}
 | 
			
		||||
@@ -1,40 +0,0 @@
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Function interface {
 | 
			
		||||
	Invoke(map[string]interface{}) (string, error)
 | 
			
		||||
	Name() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type resVo struct {
 | 
			
		||||
	Code    types.BizCode `json:"code"`
 | 
			
		||||
	Message string        `json:"message"`
 | 
			
		||||
	Data    struct {
 | 
			
		||||
		Title     string     `json:"title"`
 | 
			
		||||
		UpdatedAt string     `json:"updated_at"`
 | 
			
		||||
		Items     []dataItem `json:"items"`
 | 
			
		||||
	} `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type dataItem struct {
 | 
			
		||||
	Title  string `json:"title"`
 | 
			
		||||
	Url    string `json:"url"`
 | 
			
		||||
	Remark string `json:"remark"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewFunctions(config *types.AppConfig, db *gorm.DB, manager *oss.UploaderManager) map[string]Function {
 | 
			
		||||
	return map[string]Function{
 | 
			
		||||
		types.FuncZaoBao:   NewZaoBao(config.ApiConfig),
 | 
			
		||||
		types.FuncWeibo:    NewWeiboHot(config.ApiConfig),
 | 
			
		||||
		types.FuncHeadLine: NewHeadLines(config.ApiConfig),
 | 
			
		||||
		types.FuncImage:    NewImageFunc(db, manager, config),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,58 +0,0 @@
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 今日头条函数实现
 | 
			
		||||
 | 
			
		||||
type FuncHeadlines struct {
 | 
			
		||||
	name   string
 | 
			
		||||
	config types.ChatPlusApiConfig
 | 
			
		||||
	client *req.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewHeadLines(config types.ChatPlusApiConfig) FuncHeadlines {
 | 
			
		||||
	return FuncHeadlines{
 | 
			
		||||
		name:   "今日头条",
 | 
			
		||||
		config: config,
 | 
			
		||||
		client: req.C().SetTimeout(10 * time.Second)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncHeadlines) Invoke(map[string]interface{}) (string, error) {
 | 
			
		||||
	if f.config.Token == "" {
 | 
			
		||||
		return "", errors.New("无效的 API Token")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := fmt.Sprintf("%s/api/headline/fetch", f.config.ApiURL)
 | 
			
		||||
	var res resVo
 | 
			
		||||
	r, err := f.client.R().
 | 
			
		||||
		SetHeader("AppId", f.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
 | 
			
		||||
		SetSuccessResult(&res).Get(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return "", fmt.Errorf("%v%v", err, r.Err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return "", errors.New(res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder := make([]string, 0)
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
 | 
			
		||||
	for i, v := range res.Data.Items {
 | 
			
		||||
		builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [%s]", i+1, v.Title, v.Url, v.Remark))
 | 
			
		||||
	}
 | 
			
		||||
	return strings.Join(builder, "\n\n"), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncHeadlines) Name() string {
 | 
			
		||||
	return f.name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Function = &FuncHeadlines{}
 | 
			
		||||
@@ -1,58 +0,0 @@
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 微博热搜函数实现
 | 
			
		||||
 | 
			
		||||
type FuncWeiboHot struct {
 | 
			
		||||
	name   string
 | 
			
		||||
	config types.ChatPlusApiConfig
 | 
			
		||||
	client *req.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWeiboHot(config types.ChatPlusApiConfig) FuncWeiboHot {
 | 
			
		||||
	return FuncWeiboHot{
 | 
			
		||||
		name:   "微博热搜",
 | 
			
		||||
		config: config,
 | 
			
		||||
		client: req.C().SetTimeout(10 * time.Second)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncWeiboHot) Invoke(map[string]interface{}) (string, error) {
 | 
			
		||||
	if f.config.Token == "" {
 | 
			
		||||
		return "", errors.New("无效的 API Token")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := fmt.Sprintf("%s/api/weibo/fetch", f.config.ApiURL)
 | 
			
		||||
	var res resVo
 | 
			
		||||
	r, err := f.client.R().
 | 
			
		||||
		SetHeader("AppId", f.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
 | 
			
		||||
		SetSuccessResult(&res).Get(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return "", fmt.Errorf("%v%v", err, r.Err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return "", errors.New(res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder := make([]string, 0)
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
 | 
			
		||||
	for i, v := range res.Data.Items {
 | 
			
		||||
		builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark))
 | 
			
		||||
	}
 | 
			
		||||
	return strings.Join(builder, "\n\n"), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncWeiboHot) Name() string {
 | 
			
		||||
	return f.name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Function = &FuncWeiboHot{}
 | 
			
		||||
@@ -1,59 +0,0 @@
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 每日早报函数实现
 | 
			
		||||
 | 
			
		||||
type FuncZaoBao struct {
 | 
			
		||||
	name   string
 | 
			
		||||
	config types.ChatPlusApiConfig
 | 
			
		||||
	client *req.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewZaoBao(config types.ChatPlusApiConfig) FuncZaoBao {
 | 
			
		||||
	return FuncZaoBao{
 | 
			
		||||
		name:   "每日早报",
 | 
			
		||||
		config: config,
 | 
			
		||||
		client: req.C().SetTimeout(10 * time.Second)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncZaoBao) Invoke(map[string]interface{}) (string, error) {
 | 
			
		||||
	if f.config.Token == "" {
 | 
			
		||||
		return "", errors.New("无效的 API Token")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := fmt.Sprintf("%s/api/zaobao/fetch", f.config.ApiURL)
 | 
			
		||||
	var res resVo
 | 
			
		||||
	r, err := f.client.R().
 | 
			
		||||
		SetHeader("AppId", f.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
 | 
			
		||||
		SetSuccessResult(&res).Get(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return "", fmt.Errorf("%v%v", err, r.Err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return "", errors.New(res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder := make([]string, 0)
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt))
 | 
			
		||||
	for _, v := range res.Data.Items {
 | 
			
		||||
		builder = append(builder, v.Title)
 | 
			
		||||
	}
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("%s", res.Data.Title))
 | 
			
		||||
	return strings.Join(builder, "\n\n"), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncZaoBao) Name() string {
 | 
			
		||||
	return f.name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Function = &FuncZaoBao{}
 | 
			
		||||
@@ -104,7 +104,7 @@ func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// ignore messages for self
 | 
			
		||||
	if m.Author.ID == s.State.User.ID {
 | 
			
		||||
	if m.Author == nil || m.Author.ID == s.State.User.ID {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -136,7 +136,7 @@ func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// ignore messages for self
 | 
			
		||||
	if m.Author.ID == s.State.User.ID {
 | 
			
		||||
	if m.Author == nil || m.Author.ID == s.State.User.ID {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,6 @@ type Function struct {
 | 
			
		||||
	Label       string
 | 
			
		||||
	Description string
 | 
			
		||||
	Parameters  string
 | 
			
		||||
	Required    string
 | 
			
		||||
	Action      string
 | 
			
		||||
	Token       string
 | 
			
		||||
	Enabled     bool
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,7 @@ package vo
 | 
			
		||||
 | 
			
		||||
type Parameters struct {
 | 
			
		||||
	Type       string              `json:"type"`
 | 
			
		||||
	Required   []string            `json:"required"`
 | 
			
		||||
	Required   []string            `json:"required,omitempty"`
 | 
			
		||||
	Properties map[string]Property `json:"properties"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -17,7 +17,6 @@ type Function struct {
 | 
			
		||||
	Label       string     `json:"label"`
 | 
			
		||||
	Description string     `json:"description"`
 | 
			
		||||
	Parameters  Parameters `json:"parameters"`
 | 
			
		||||
	Required    []string   `json:"required"`
 | 
			
		||||
	Action      string     `json:"action"`
 | 
			
		||||
	Token       string     `json:"token"`
 | 
			
		||||
	Enabled     bool       `json:"enabled"`
 | 
			
		||||
 
 | 
			
		||||
@@ -24,3 +24,5 @@ ALTER TABLE `chatgpt_mj_jobs` ADD `use_proxy` TINYINT(1) NOT NULL DEFAULT '0' CO
 | 
			
		||||
ALTER TABLE `chatgpt_mj_jobs` CHANGE `img_url` `img_url` VARCHAR(400) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL COMMENT '图片URL';
 | 
			
		||||
 | 
			
		||||
ALTER TABLE `chatgpt_functions` ADD `token` VARCHAR(255) NULL COMMENT 'API授权token' AFTER `action`;
 | 
			
		||||
 | 
			
		||||
ALTER TABLE `chatgpt_functions` DROP `required`;
 | 
			
		||||
 
 | 
			
		||||
@@ -220,7 +220,7 @@ const rowEdit = function (index, row) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const addRow = function () {
 | 
			
		||||
  item.value = {enabled:true}
 | 
			
		||||
  item.value = {enabled: true}
 | 
			
		||||
  params.value = []
 | 
			
		||||
  showDialog.value = true
 | 
			
		||||
}
 | 
			
		||||
@@ -238,7 +238,7 @@ const save = function () {
 | 
			
		||||
          required.push(params.value[i].name)
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      item.value.parameters = {type: "object", "properties": properties, "required": required}
 | 
			
		||||
      item.value.parameters = {type: "object", "properties": properties, required: required}
 | 
			
		||||
      httpPost('/api/admin/function/save', item.value).then((res) => {
 | 
			
		||||
        ElMessage.success('操作成功')
 | 
			
		||||
        console.log(res.data)
 | 
			
		||||
@@ -274,7 +274,7 @@ const removeParam = function (index) {
 | 
			
		||||
  params.value.splice(index, 1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const functionSet = (filed,row) => {
 | 
			
		||||
const functionSet = (filed, row) => {
 | 
			
		||||
  httpPost('/api/admin/function/set', {id: row.id, filed: filed, value: row[filed]}).then(() => {
 | 
			
		||||
    ElMessage.success("操作成功!")
 | 
			
		||||
  }).catch(e => {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user