From 67d83041d763bb69ab1a762fe43d3ae31bf2799e Mon Sep 17 00:00:00 2001 From: RockYang Date: Wed, 4 Sep 2024 14:53:21 +0800 Subject: [PATCH] user can select function tools by themself --- api/core/types/chat.go | 1 + api/handler/chatimpl/chat_handler.go | 59 +++++++++++++++------------- api/handler/function_handler.go | 29 +++++++++++++- api/main.go | 1 + web/src/assets/css/chat-plus.styl | 7 ++++ web/src/views/ChatPlus.vue | 55 ++++++++++++++++++++++---- 6 files changed, 116 insertions(+), 36 deletions(-) diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 9464ec8b..42c86a2b 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -57,6 +57,7 @@ type ChatSession struct { ClientIP string `json:"client_ip"` // 客户端 IP ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段 Model ChatModel `json:"model"` // GPT 模型 + Tools string `json:"tools"` // 函数 } type ChatModel struct { diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 164c7fd9..58b9870a 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -73,6 +73,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { roleId := h.GetInt(c, "role_id", 0) chatId := c.Query("chat_id") modelId := h.GetInt(c, "model_id", 0) + tools := c.Query("tools") client := types.NewWsClient(ws) var chatRole model.ChatRole @@ -99,6 +100,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { SessionId: sessionId, ClientIP: c.ClientIP(), UserId: h.GetLoginUserId(c), + Tools: tools, } // use old chat data override the chat model and role ID @@ -211,34 +213,37 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio } req.Temperature = session.Model.Temperature req.MaxTokens = session.Model.MaxTokens - // OpenAI 支持函数功能 - var items []model.Function - res = h.DB.Where("enabled", true).Find(&items) - if res.Error == nil { - var tools = make([]types.Tool, 0) - for _, v := range items { - var parameters map[string]interface{} - err = utils.JsonDecode(v.Parameters, ¶meters) - if err != nil { - continue - } - tool := types.Tool{ - Type: "function", - Function: types.Function{ - Name: v.Name, - Description: v.Description, - Parameters: parameters, - }, - } - if v, ok := parameters["required"]; v == nil || !ok { - tool.Function.Parameters["required"] = []string{} - } - tools = append(tools, tool) - } - if len(tools) > 0 { - req.Tools = tools - req.ToolChoice = "auto" + if session.Tools != "" { + toolIds := strings.Split(session.Tools, ",") + var items []model.Function + res = h.DB.Where("enabled", true).Where("id IN ?", toolIds).Find(&items) + if res.Error == nil { + var tools = make([]types.Tool, 0) + for _, v := range items { + var parameters map[string]interface{} + err = utils.JsonDecode(v.Parameters, ¶meters) + if err != nil { + continue + } + tool := types.Tool{ + Type: "function", + Function: types.Function{ + Name: v.Name, + Description: v.Description, + Parameters: parameters, + }, + } + if v, ok := parameters["required"]; v == nil || !ok { + tool.Function.Parameters["required"] = []string{} + } + tools = append(tools, tool) + } + + if len(tools) > 0 { + req.Tools = tools + req.ToolChoice = "auto" + } } } diff --git a/api/handler/function_handler.go b/api/handler/function_handler.go index 6917efde..f1838d4d 100644 --- a/api/handler/function_handler.go +++ b/api/handler/function_handler.go @@ -8,15 +8,16 @@ package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( + "errors" + "fmt" "geekai/core" "geekai/core/types" "geekai/service/dalle" "geekai/service/oss" "geekai/store/model" + "geekai/store/vo" "geekai/utils" "geekai/utils/resp" - "errors" - "fmt" "strings" "time" @@ -224,3 +225,27 @@ func (h *FunctionHandler) Dall3(c *gin.Context) { resp.SUCCESS(c, content) } + +// List 获取所有的工具函数列表 +func (h *FunctionHandler) List(c *gin.Context) { + var items []model.Function + err := h.DB.Where("enabled", true).Find(&items).Error + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + tools := make([]vo.Function, 0) + for _, v := range items { + var f vo.Function + err = utils.CopyObject(v, &f) + if err != nil { + continue + } + f.Action = "" + f.Token = "" + tools = append(tools, f) + } + + resp.SUCCESS(c, tools) +} diff --git a/api/main.go b/api/main.go index 82cf02f7..274f67b2 100644 --- a/api/main.go +++ b/api/main.go @@ -433,6 +433,7 @@ func main() { group.POST("weibo", h.WeiBo) group.POST("zaobao", h.ZaoBao) group.POST("dalle3", h.Dall3) + group.GET("list", h.List) }), fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) { group := s.Engine.Group("/api/admin/chat/") diff --git a/web/src/assets/css/chat-plus.styl b/web/src/assets/css/chat-plus.styl index b1dbce43..9edc62a3 100644 --- a/web/src/assets/css/chat-plus.styl +++ b/web/src/assets/css/chat-plus.styl @@ -427,4 +427,11 @@ $borderColor = #4676d0; .el-image { width 360px; } +} + +.tools-dropdown { + width auto + .el-icon { + margin-left 5px; + } } \ No newline at end of file diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index b500026c..b6012381 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -100,11 +100,27 @@ + + + + + - - - - + + + +
@@ -202,7 +218,7 @@ import {nextTick, onMounted, onUnmounted, ref, watch} from 'vue' import ChatPrompt from "@/components/ChatPrompt.vue"; import ChatReply from "@/components/ChatReply.vue"; -import {Delete, Edit, More, Plus, Promotion, Search, Share, VideoPause} from '@element-plus/icons-vue' +import {Delete, Edit, InfoFilled, More, Plus, Promotion, Search, Share, VideoPause} from '@element-plus/icons-vue' import 'highlight.js/styles/a11y-dark.css' import { isMobile, @@ -222,6 +238,7 @@ import FileSelect from "@/components/FileSelect.vue"; import FileList from "@/components/FileList.vue"; import ChatSetting from "@/components/ChatSetting.vue"; import BackTop from "@/components/BackTop.vue"; +import {showMessageError} from "@/utils/dialog"; const title = ref('GeekAI-智能助手'); const models = ref([]) @@ -253,6 +270,9 @@ const listStyle = ref(store.chatListStyle) watch(() => store.chatListStyle, (newValue) => { listStyle.value = newValue }); +const tools = ref([]) +const toolSelected = ref([]) +const loadHistory = ref(false) // 初始化 ChatID chatId.value = router.currentRoute.value.params.id @@ -294,6 +314,13 @@ httpGet("/api/config/get?key=notice").then(res => { ElMessage.error("获取系统配置失败:" + e.message) }) +// 获取工具函数 +httpGet("/api/function/list").then(res => { + tools.value = res.data +}).catch(e => { + showMessageError("获取工具函数失败:" + e.message) +}) + onMounted(() => { resizeElement(); initData() @@ -464,9 +491,19 @@ const newChat = () => { }; showStopGenerate.value = false; router.push(`/chat/${chatId.value}`) + loadHistory.value = true connect() } +// 切换工具 +const changeTool = () => { + if (!isLogin.value) { + return; + } + loadHistory.value = false + socket.value.close() +} + // 切换会话 const loadChat = function (chat) { @@ -485,6 +522,7 @@ const loadChat = function (chat) { chatId.value = chat.chat_id; showStopGenerate.value = false; router.push(`/chat/${chatId.value}`) + loadHistory.value = true socket.value.close() } @@ -578,10 +616,13 @@ const connect = function () { } loading.value = true - const _socket = new WebSocket(host + `/api/chat/new?session_id=${sessionId.value}&role_id=${roleId.value}&chat_id=${chatId.value}&model_id=${modelID.value}&token=${getUserToken()}`); + const toolIds = toolSelected.value.join(',') + const _socket = new WebSocket(host + `/api/chat/new?session_id=${sessionId.value}&role_id=${roleId.value}&chat_id=${chatId.value}&model_id=${modelID.value}&token=${getUserToken()}&tools=${toolIds}`); _socket.addEventListener('open', () => { enableInput() - loadChatHistory(chatId.value) + if (loadHistory.value) { + loadChatHistory(chatId.value) + } loading.value = false });