diff --git a/api/core/types/chat.go b/api/core/types/chat.go
index 9464ec8b..42c86a2b 100644
--- a/api/core/types/chat.go
+++ b/api/core/types/chat.go
@@ -57,6 +57,7 @@ type ChatSession struct {
ClientIP string `json:"client_ip"` // 客户端 IP
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
Model ChatModel `json:"model"` // GPT 模型
+ Tools string `json:"tools"` // 函数
}
type ChatModel struct {
diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go
index 164c7fd9..58b9870a 100644
--- a/api/handler/chatimpl/chat_handler.go
+++ b/api/handler/chatimpl/chat_handler.go
@@ -73,6 +73,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
roleId := h.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
modelId := h.GetInt(c, "model_id", 0)
+ tools := c.Query("tools")
client := types.NewWsClient(ws)
var chatRole model.ChatRole
@@ -99,6 +100,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
SessionId: sessionId,
ClientIP: c.ClientIP(),
UserId: h.GetLoginUserId(c),
+ Tools: tools,
}
// use old chat data override the chat model and role ID
@@ -211,34 +213,37 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
- // OpenAI 支持函数功能
- var items []model.Function
- res = h.DB.Where("enabled", true).Find(&items)
- if res.Error == nil {
- var tools = make([]types.Tool, 0)
- for _, v := range items {
- var parameters map[string]interface{}
- err = utils.JsonDecode(v.Parameters, ¶meters)
- if err != nil {
- continue
- }
- tool := types.Tool{
- Type: "function",
- Function: types.Function{
- Name: v.Name,
- Description: v.Description,
- Parameters: parameters,
- },
- }
- if v, ok := parameters["required"]; v == nil || !ok {
- tool.Function.Parameters["required"] = []string{}
- }
- tools = append(tools, tool)
- }
- if len(tools) > 0 {
- req.Tools = tools
- req.ToolChoice = "auto"
+ if session.Tools != "" {
+ toolIds := strings.Split(session.Tools, ",")
+ var items []model.Function
+ res = h.DB.Where("enabled", true).Where("id IN ?", toolIds).Find(&items)
+ if res.Error == nil {
+ var tools = make([]types.Tool, 0)
+ for _, v := range items {
+ var parameters map[string]interface{}
+ err = utils.JsonDecode(v.Parameters, ¶meters)
+ if err != nil {
+ continue
+ }
+ tool := types.Tool{
+ Type: "function",
+ Function: types.Function{
+ Name: v.Name,
+ Description: v.Description,
+ Parameters: parameters,
+ },
+ }
+ if v, ok := parameters["required"]; v == nil || !ok {
+ tool.Function.Parameters["required"] = []string{}
+ }
+ tools = append(tools, tool)
+ }
+
+ if len(tools) > 0 {
+ req.Tools = tools
+ req.ToolChoice = "auto"
+ }
}
}
diff --git a/api/handler/function_handler.go b/api/handler/function_handler.go
index 6917efde..f1838d4d 100644
--- a/api/handler/function_handler.go
+++ b/api/handler/function_handler.go
@@ -8,15 +8,16 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
+ "errors"
+ "fmt"
"geekai/core"
"geekai/core/types"
"geekai/service/dalle"
"geekai/service/oss"
"geekai/store/model"
+ "geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
- "errors"
- "fmt"
"strings"
"time"
@@ -224,3 +225,27 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
resp.SUCCESS(c, content)
}
+
+// List 获取所有的工具函数列表
+func (h *FunctionHandler) List(c *gin.Context) {
+ var items []model.Function
+ err := h.DB.Where("enabled", true).Find(&items).Error
+ if err != nil {
+ resp.ERROR(c, err.Error())
+ return
+ }
+
+ tools := make([]vo.Function, 0)
+ for _, v := range items {
+ var f vo.Function
+ err = utils.CopyObject(v, &f)
+ if err != nil {
+ continue
+ }
+ f.Action = ""
+ f.Token = ""
+ tools = append(tools, f)
+ }
+
+ resp.SUCCESS(c, tools)
+}
diff --git a/api/main.go b/api/main.go
index 82cf02f7..274f67b2 100644
--- a/api/main.go
+++ b/api/main.go
@@ -433,6 +433,7 @@ func main() {
group.POST("weibo", h.WeiBo)
group.POST("zaobao", h.ZaoBao)
group.POST("dalle3", h.Dall3)
+ group.GET("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
group := s.Engine.Group("/api/admin/chat/")
diff --git a/web/src/assets/css/chat-plus.styl b/web/src/assets/css/chat-plus.styl
index b1dbce43..9edc62a3 100644
--- a/web/src/assets/css/chat-plus.styl
+++ b/web/src/assets/css/chat-plus.styl
@@ -427,4 +427,11 @@ $borderColor = #4676d0;
.el-image {
width 360px;
}
+}
+
+.tools-dropdown {
+ width auto
+ .el-icon {
+ margin-left 5px;
+ }
}
\ No newline at end of file
diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue
index b500026c..b6012381 100644
--- a/web/src/views/ChatPlus.vue
+++ b/web/src/views/ChatPlus.vue
@@ -100,11 +100,27 @@
+