From 43bfac99b6160b78c6bec678c230595e803a3653 Mon Sep 17 00:00:00 2001
From: RockYang
Date: Mon, 11 Mar 2024 14:09:19 +0800
Subject: [PATCH] feat: replace Functions param with Tools param for OpenAI chat API

---
 api/core/app_server.go                  |  4 +-
 api/core/types/locked_map.go            |  2 +-
 api/handler/chatimpl/azure_handler.go   |  2 +-
 api/handler/chatimpl/baidu_handler.go   |  2 +-
 api/handler/chatimpl/chat_handler.go    | 78 ++++++++++++-------------
 api/handler/chatimpl/chatglm_handler.go |  2 +-
 api/handler/chatimpl/openai_handler.go  |  8 ++-
 api/handler/chatimpl/qwen_handler.go    |  2 +-
 api/handler/chatimpl/xunfei_handler.go  |  2 +-
 9 files changed, 52 insertions(+), 50 deletions(-)

diff --git a/api/core/app_server.go b/api/core/app_server.go
index c166c0dc..c770d413 100644
--- a/api/core/app_server.go
+++ b/api/core/app_server.go
@@ -28,7 +28,7 @@ type AppServer struct {
 	Debug        bool
 	Config       *types.AppConfig
 	Engine       *gin.Engine
-	ChatContexts *types.LMap[string, []interface{}]   // chat context map: [chatId] => []Message
+	ChatContexts *types.LMap[string, []types.Message] // chat context map: [chatId] => []Message
 
 	ChatConfig *types.ChatConfig   // chat config cache
 	SysConfig  *types.SystemConfig // system config cache
@@ -47,7 +47,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
 		Debug:         false,
 		Config:        appConfig,
 		Engine:        gin.Default(),
-		ChatContexts:  types.NewLMap[string, []interface{}](),
+		ChatContexts:  types.NewLMap[string, []types.Message](),
 		ChatSession:   types.NewLMap[string, *types.ChatSession](),
 		ChatClients:   types.NewLMap[string, *types.WsClient](),
 		ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
diff --git a/api/core/types/locked_map.go b/api/core/types/locked_map.go
index ede72f34..26ed6f46 100644
--- a/api/core/types/locked_map.go
+++ b/api/core/types/locked_map.go
@@ -9,7 +9,7 @@ type MKey interface {
 	string | int | uint
 }
 type MValue interface {
-	*WsClient | *ChatSession | context.CancelFunc | []interface{}
+	*WsClient | *ChatSession | context.CancelFunc | []Message
 }
 type LMap[K MKey, T MValue] struct {
 	lock sync.RWMutex
diff --git a/api/handler/chatimpl/azure_handler.go b/api/handler/chatimpl/azure_handler.go
index 460ac9d7..82a96d18 100644
--- a/api/handler/chatimpl/azure_handler.go
+++ b/api/handler/chatimpl/azure_handler.go
@@ -19,7 +19,7 @@ import (
 
 // Message sending implementation for Microsoft Azure models
 func (h *ChatHandler) sendAzureMessage(
-	chatCtx []interface{},
+	chatCtx []types.Message,
 	req types.ApiRequest,
 	userVo vo.User,
 	ctx context.Context,
diff --git a/api/handler/chatimpl/baidu_handler.go b/api/handler/chatimpl/baidu_handler.go
index 85f786da..bf50feeb 100644
--- a/api/handler/chatimpl/baidu_handler.go
+++ b/api/handler/chatimpl/baidu_handler.go
@@ -36,7 +36,7 @@ type baiduResp struct {
 
 // Message sending implementation for Baidu Wenxin Yiyan (ERNIE Bot)
 func (h *ChatHandler) sendBaiduMessage(
-	chatCtx []interface{},
+	chatCtx []types.Message,
 	req types.ApiRequest,
 	userVo vo.User,
 	ctx context.Context,
diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go
index 58fcaa43..b7b1bc14 100644
--- a/api/handler/chatimpl/chat_handler.go
+++ b/api/handler/chatimpl/chat_handler.go
@@ -224,6 +224,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 		utils.ReplyMessage(ws, ErrImg)
 		return nil
 	}
+
+	// Check whether the prompt exceeds the maximum context length allowed by the current model
+	promptTokens, _ := utils.CalcTokens(prompt, session.Model.Value)
+	if promptTokens > types.GetModelMaxToken(session.Model.Value) {
+		utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
+		return nil
+	}
+
 	var req = types.ApiRequest{
 		Model:  session.Model.Value,
 		Stream: true,
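The hunks above narrow the shared chat-context storage from []interface{} to []types.Message, so the compiler can now check what goes into the map and readers no longer need type assertions. The sketch below illustrates the underlying locking pattern only; it is a standalone approximation, not the project's code: message and lockedMap are stand-in names, and the relaxed comparable/any constraints replace the real MKey/MValue interfaces.

    package main

    import (
    	"fmt"
    	"sync"
    )

    // message mirrors the role/content pair stored in the chat context.
    type message struct {
    	Role    string
    	Content string
    }

    // lockedMap approximates types.LMap: a plain map guarded by a sync.RWMutex.
    type lockedMap[K comparable, V any] struct {
    	lock sync.RWMutex
    	data map[K]V
    }

    func newLockedMap[K comparable, V any]() *lockedMap[K, V] {
    	return &lockedMap[K, V]{data: make(map[K]V)}
    }

    // Put stores a value under the write lock.
    func (m *lockedMap[K, V]) Put(key K, value V) {
    	m.lock.Lock()
    	defer m.lock.Unlock()
    	m.data[key] = value
    }

    // Get reads a value under the read lock; a missing key yields the zero value.
    func (m *lockedMap[K, V]) Get(key K) V {
    	m.lock.RLock()
    	defer m.lock.RUnlock()
    	return m.data[key]
    }

    func main() {
    	// With []message as the value type, no type assertion is needed on read.
    	contexts := newLockedMap[string, []message]()
    	contexts.Put("chat-1", []message{{Role: "user", Content: "hello"}})
    	fmt.Println(contexts.Get("chat-1")[0].Content) // hello
    }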
@@ -252,7 +260,6 @@
 	}
 
 	var tools = make([]interface{}, 0)
-	var functions = make([]interface{}, 0)
 	for _, v := range items {
 		var parameters map[string]interface{}
 		err = utils.JsonDecode(v.Parameters, &parameters)
@@ -270,20 +277,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 				"required": required,
 			},
 		})
-		functions = append(functions, gin.H{
-			"name":        v.Name,
-			"description": v.Description,
-			"parameters":  parameters,
-			"required":    required,
-		})
 	}
 
-	//if len(tools) > 0 {
-	//	req.Tools = tools
-	//	req.ToolChoice = "auto"
-	//}
-	if len(functions) > 0 {
-		req.Functions = functions
+	if len(tools) > 0 {
+		req.Tools = tools
+		req.ToolChoice = "auto"
 	}
 
 case types.XunFei:
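This is the core of the patch: the request switches from the deprecated top-level functions parameter to the tools format of the OpenAI chat API, where each entry wraps a function spec in {"type": "function", "function": {...}} and tool_choice: "auto" lets the model decide when to call one. The sketch below shows the resulting payload shape as a minimal standalone program; the tool name get_weather and all field values are invented for illustration, and plain maps stand in for the gin.H values built above.

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    func main() {
    	// One tool entry: the old flat function spec is now wrapped in
    	// {"type": "function", "function": {...}}.
    	tool := map[string]interface{}{
    		"type": "function",
    		"function": map[string]interface{}{
    			"name":        "get_weather",
    			"description": "Query the weather for a given city",
    			"parameters": map[string]interface{}{
    				"type": "object",
    				"properties": map[string]interface{}{
    					"city": map[string]string{"type": "string"},
    				},
    				"required": []string{"city"},
    			},
    		},
    	}

    	// Request body: "functions" is gone; "tools" plus "tool_choice" replace it.
    	body := map[string]interface{}{
    		"model":       "gpt-3.5-turbo",
    		"stream":      true,
    		"tools":       []interface{}{tool},
    		"tool_choice": "auto",
    	}

    	out, _ := json.MarshalIndent(body, "", "  ")
    	fmt.Println(string(out))
    }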
and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id DESC").Find(&historyMessages) if res.Error == nil { for i := len(historyMessages) - 1; i >= 0; i-- { msg := historyMessages[i] - if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) { - break - } - tokens += msg.Tokens ms := types.Message{Role: "user", Content: msg.Content} if msg.Type == types.ReplyMsg { ms.Role = "assistant" @@ -344,6 +321,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio } } } + + // 计算当前请求的 token 总长度,确保不会超出最大上下文长度 + // MaxContextLength = Response + Tool + Prompt + Context + tokens := req.MaxTokens // 最大响应长度 + tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model) + tokens += tks + promptTokens + + for _, v := range messages { + tks, _ := utils.CalcTokens(v.Content, req.Model) + // 上下文 token 超出了模型的最大上下文长度 + if tokens+tks >= types.GetModelMaxToken(req.Model) { + break + } + + // 上下文的深度超出了模型的最大上下文深度 + if len(chatCtx) >= h.App.ChatConfig.ContextDeep { + break + } + + tokens += tks + chatCtx = append(chatCtx, v) + } + logger.Debugf("聊天上下文:%+v", chatCtx) } reqMgs := make([]interface{}, 0) diff --git a/api/handler/chatimpl/chatglm_handler.go b/api/handler/chatimpl/chatglm_handler.go index d2ced37d..96329c3c 100644 --- a/api/handler/chatimpl/chatglm_handler.go +++ b/api/handler/chatimpl/chatglm_handler.go @@ -20,7 +20,7 @@ import ( // 清华大学 ChatGML 消息发送实现 func (h *ChatHandler) sendChatGLMMessage( - chatCtx []interface{}, + chatCtx []types.Message, req types.ApiRequest, userVo vo.User, ctx context.Context, diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/chatimpl/openai_handler.go index 39811946..888f7f40 100644 --- a/api/handler/chatimpl/openai_handler.go +++ b/api/handler/chatimpl/openai_handler.go @@ -20,7 +20,7 @@ import ( // OPenAI 消息发送实现 func (h *ChatHandler) sendOpenAiMessage( - chatCtx []interface{}, + chatCtx []types.Message, req types.ApiRequest, userVo vo.User, ctx context.Context, @@ -46,8 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage( utils.ReplyMessage(ws, ErrorMsg) utils.ReplyMessage(ws, ErrImg) - all, _ := io.ReadAll(response.Body) - logger.Error(string(all)) + if response.Body != nil { + all, _ := io.ReadAll(response.Body) + logger.Error(string(all)) + } return err } else { defer response.Body.Close() diff --git a/api/handler/chatimpl/qwen_handler.go b/api/handler/chatimpl/qwen_handler.go index 116e8788..ccfd8fa4 100644 --- a/api/handler/chatimpl/qwen_handler.go +++ b/api/handler/chatimpl/qwen_handler.go @@ -31,7 +31,7 @@ type qWenResp struct { // 通义千问消息发送实现 func (h *ChatHandler) sendQWenMessage( - chatCtx []interface{}, + chatCtx []types.Message, req types.ApiRequest, userVo vo.User, ctx context.Context, diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go index 27e87eed..880dde1c 100644 --- a/api/handler/chatimpl/xunfei_handler.go +++ b/api/handler/chatimpl/xunfei_handler.go @@ -58,7 +58,7 @@ var Model2URL = map[string]string{ // 科大讯飞消息发送实现 func (h *ChatHandler) sendXunFeiMessage( - chatCtx []interface{}, + chatCtx []types.Message, req types.ApiRequest, userVo vo.User, ctx context.Context,