Mirror of https://github.com/yangjian102621/geekai.git (synced 2025-09-17 16:56:38 +08:00)
feat: replace Tools param with Function param for OpenAI chat API

parent 755273a898
commit 43bfac99b6
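In substance, the diff below narrows the chat-context storage type from []interface{} to the concrete []types.Message in AppServer, in the generic LMap value constraint, and in the chatCtx parameter of every provider handler (Azure, Baidu, ChatGLM, OpenAI, QWen, XunFei). It also adds an up-front prompt-length check to sendMessage, drops the functions payload in favor of the tools payload, and reworks context loading so that role context and recent history are first collected into a messages slice and then trimmed into chatCtx in a single pass bounded by tokens and by ContextDeep.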
@@ -28,7 +28,7 @@ type AppServer struct {
     Debug        bool
     Config       *types.AppConfig
     Engine       *gin.Engine
-    ChatContexts *types.LMap[string, []interface{}]   // chat context map: [chatId] => []Message
+    ChatContexts *types.LMap[string, []types.Message] // chat context map: [chatId] => []Message

     ChatConfig *types.ChatConfig   // chat config cache
     SysConfig  *types.SystemConfig // system config cache
@@ -47,7 +47,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
         Debug:         false,
         Config:        appConfig,
         Engine:        gin.Default(),
-        ChatContexts:  types.NewLMap[string, []interface{}](),
+        ChatContexts:  types.NewLMap[string, []types.Message](),
         ChatSession:   types.NewLMap[string, *types.ChatSession](),
         ChatClients:   types.NewLMap[string, *types.WsClient](),
         ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
@@ -9,7 +9,7 @@ type MKey interface {
     string | int | uint
 }
 type MValue interface {
-    *WsClient | *ChatSession | context.CancelFunc | []interface{}
+    *WsClient | *ChatSession | context.CancelFunc | []Message
 }
 type LMap[K MKey, T MValue] struct {
     lock sync.RWMutex
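To see what the narrowed MValue union buys, here is a minimal, self-contained sketch of an LMap-style generic map in the spirit of the type above. The Put/Get/Has methods and the Message struct are assumptions for illustration; only LMap, MKey, MValue and the lock field appear in this diff.

package main

import (
	"fmt"
	"sync"
)

type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type MKey interface {
	string | int | uint
}

// The commit narrows this union from []interface{} to []Message,
// so stored chat contexts become statically typed.
type MValue interface {
	[]Message // the repo's other value types (*WsClient etc.) are omitted here
}

type LMap[K MKey, T MValue] struct {
	lock sync.RWMutex
	data map[K]T
}

func NewLMap[K MKey, T MValue]() *LMap[K, T] {
	return &LMap[K, T]{data: make(map[K]T)}
}

// Put stores a value under key, taking the write lock.
func (m *LMap[K, T]) Put(key K, value T) {
	m.lock.Lock()
	defer m.lock.Unlock()
	m.data[key] = value
}

// Get returns the value for key (zero value if absent), taking the read lock.
func (m *LMap[K, T]) Get(key K) T {
	m.lock.RLock()
	defer m.lock.RUnlock()
	return m.data[key]
}

// Has reports whether key is present, taking the read lock.
func (m *LMap[K, T]) Has(key K) bool {
	m.lock.RLock()
	defer m.lock.RUnlock()
	_, ok := m.data[key]
	return ok
}

func main() {
	ctxs := NewLMap[string, []Message]()
	ctxs.Put("chat-1", []Message{{Role: "user", Content: "hello"}})
	if ctxs.Has("chat-1") {
		fmt.Printf("%+v\n", ctxs.Get("chat-1")) // [{Role:user Content:hello}]
	}
}

With []Message in the constraint, ChatContexts.Get returns a typed slice, so callers no longer need type assertions on the stored context.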
@@ -19,7 +19,7 @@ import (
 // Microsoft Azure model message-sending implementation

 func (h *ChatHandler) sendAzureMessage(
-    chatCtx []interface{},
+    chatCtx []types.Message,
     req types.ApiRequest,
     userVo vo.User,
     ctx context.Context,
@@ -36,7 +36,7 @@ type baiduResp struct {
 // Baidu ERNIE Bot (Wenxin Yiyan) message-sending implementation

 func (h *ChatHandler) sendBaiduMessage(
-    chatCtx []interface{},
+    chatCtx []types.Message,
     req types.ApiRequest,
     userVo vo.User,
     ctx context.Context,
@@ -224,6 +224,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
         utils.ReplyMessage(ws, ErrImg)
         return nil
     }
+
+    // check whether the prompt length exceeds the maximum context length allowed by the current model
+    promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
+    if promptTokens > types.GetModelMaxToken(session.Model.Value) {
+        utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
+        return nil
+    }
+
     var req = types.ApiRequest{
         Model:  session.Model.Value,
         Stream: true,
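The new guard relies on utils.CalcTokens and types.GetModelMaxToken, whose implementations are outside this diff. A rough, hedged sketch of what such a check involves, assuming a tiktoken-style tokenizer (github.com/pkoukk/tiktoken-go here) and an illustrative per-model limit table; the repo's real helpers may differ:

package main

import (
	"fmt"

	"github.com/pkoukk/tiktoken-go"
)

// illustrative stand-in for types.GetModelMaxToken
var modelMaxToken = map[string]int{
	"gpt-3.5-turbo": 4096,
	"gpt-4":         8192,
}

// calcTokens mirrors what a utils.CalcTokens helper would plausibly do:
// pick the model's encoding and count the encoded tokens.
func calcTokens(text, model string) (int, error) {
	enc, err := tiktoken.EncodingForModel(model)
	if err != nil {
		return 0, err
	}
	return len(enc.Encode(text, nil, nil)), nil
}

func main() {
	prompt := "Translate this sentence into French."
	model := "gpt-3.5-turbo"
	n, err := calcTokens(prompt, model)
	if err != nil {
		panic(err)
	}
	if n > modelMaxToken[model] {
		fmt.Println("prompt exceeds the model's maximum context length")
		return
	}
	fmt.Printf("prompt uses %d tokens\n", n)
}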
@@ -252,7 +260,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
         }

         var tools = make([]interface{}, 0)
-        var functions = make([]interface{}, 0)
         for _, v := range items {
             var parameters map[string]interface{}
             err = utils.JsonDecode(v.Parameters, &parameters)
@@ -270,20 +277,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
                     "required": required,
                 },
             })
-            functions = append(functions, gin.H{
-                "name":        v.Name,
-                "description": v.Description,
-                "parameters":  parameters,
-                "required":    required,
-            })
         }

-        //if len(tools) > 0 {
-        //	req.Tools = tools
-        //	req.ToolChoice = "auto"
-        //}
-        if len(functions) > 0 {
-            req.Functions = functions
+        if len(tools) > 0 {
+            req.Tools = tools
+            req.ToolChoice = "auto"
         }

     case types.XunFei:
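Only the tail of each tools entry (the closing "required" and braces) is visible in the hunk above. For orientation, a hedged, self-contained sketch of the shape such an entry plausibly takes under OpenAI's function-calling schema; the "type"/"function" wrapper and the "properties" nesting follow OpenAI's public API rather than the hidden lines of this file, and the sample values stand in for v.Name, v.Description, parameters and required from the loop:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/gin-gonic/gin"
)

func main() {
	// illustrative values standing in for the handler's loop variables
	name, description := "weather", "Query the current weather for a city"
	parameters := map[string]interface{}{
		"city": map[string]interface{}{"type": "string"},
	}
	required := []string{"city"}

	// one tools entry in OpenAI's function-calling shape, built with gin.H
	tool := gin.H{
		"type": "function",
		"function": gin.H{
			"name":        name,
			"description": description,
			"parameters": gin.H{
				"type":       "object",
				"properties": parameters,
				"required":   required,
			},
		},
	}

	// the request then carries the array plus tool_choice, as req.Tools /
	// req.ToolChoice do in the diff
	out, _ := json.MarshalIndent(gin.H{"tools": []interface{}{tool}, "tool_choice": "auto"}, "", "  ")
	fmt.Println(string(out))
}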
@@ -301,40 +299,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
     }

     // load the chat context
-    var chatCtx []interface{}
+    chatCtx := make([]types.Message, 0)
+    messages := make([]types.Message, 0)
     if h.App.ChatConfig.EnableContext {
         if h.App.ChatContexts.Has(session.ChatId) {
-            chatCtx = h.App.ChatContexts.Get(session.ChatId)
+            messages = h.App.ChatContexts.Get(session.ChatId)
         } else {
-            // calculate the tokens of current request, to prevent to exceeding the max tokens num
-            tokens := req.MaxTokens
-            tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
-            tokens += tks
             // loading the role context
-            var messages []types.Message
-            err := utils.JsonDecode(role.Context, &messages)
-            if err == nil {
-                for _, v := range messages {
-                    tks, _ := utils.CalcTokens(v.Content, req.Model)
-                    if tokens+tks >= types.GetModelMaxToken(req.Model) {
-                        break
-                    }
-                    tokens += tks
-                    chatCtx = append(chatCtx, v)
-                }
-            }
+            _ = utils.JsonDecode(role.Context, &messages)

             // loading recent chat history as chat context
             if chatConfig.ContextDeep > 0 {
                 var historyMessages []model.ChatMessage
-                res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
+                res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
                 if res.Error == nil {
                     for i := len(historyMessages) - 1; i >= 0; i-- {
                         msg := historyMessages[i]
-                        if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
-                            break
-                        }
-                        tokens += msg.Tokens
                         ms := types.Message{Role: "user", Content: msg.Content}
                         if msg.Type == types.ReplyMsg {
                             ms.Role = "assistant"
@@ -344,6 +321,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
                 }
             }
         }

+        // total token count of the current request; make sure it stays within the model's maximum context length:
+        // MaxContextLength = Response + Tool + Prompt + Context
+        tokens := req.MaxTokens // maximum response length
+        tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
+        tokens += tks + promptTokens
+
+        for _, v := range messages {
+            tks, _ := utils.CalcTokens(v.Content, req.Model)
+            // the accumulated context tokens would exceed the model's maximum context length
+            if tokens+tks >= types.GetModelMaxToken(req.Model) {
+                break
+            }
+
+            // the context depth exceeds the maximum configured context depth
+            if len(chatCtx) >= h.App.ChatConfig.ContextDeep {
+                break
+            }
+
+            tokens += tks
+            chatCtx = append(chatCtx, v)
+        }
+
+        logger.Debugf("聊天上下文:%+v", chatCtx)
     }
     reqMgs := make([]interface{}, 0)
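To make the budget arithmetic concrete with illustrative numbers: for a model with a 4096-token maximum context, req.MaxTokens = 1024 reserved for the response, 100 tokens of JSON-encoded tools and a 200-token prompt, the loop starts from tokens = 1024 + 100 + 200 = 1324, so collected messages are appended to chatCtx until their combined size reaches 4096 - 1324 = 2772 tokens, or until chatCtx already holds ContextDeep messages, whichever comes first.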
@@ -20,7 +20,7 @@ import (
 // Tsinghua University ChatGLM message-sending implementation

 func (h *ChatHandler) sendChatGLMMessage(
-    chatCtx []interface{},
+    chatCtx []types.Message,
     req types.ApiRequest,
     userVo vo.User,
     ctx context.Context,
@@ -20,7 +20,7 @@ import (

 // OpenAI message-sending implementation
 func (h *ChatHandler) sendOpenAiMessage(
-    chatCtx []interface{},
+    chatCtx []types.Message,
     req types.ApiRequest,
     userVo vo.User,
     ctx context.Context,
@@ -46,8 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage(

         utils.ReplyMessage(ws, ErrorMsg)
         utils.ReplyMessage(ws, ErrImg)
+        if response.Body != nil {
             all, _ := io.ReadAll(response.Body)
             logger.Error(string(all))
+        }
         return err
     } else {
         defer response.Body.Close()
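The nil check added above matters because, with net/http, a transport-level error normally leaves the *http.Response nil, so reading the body unconditionally risks a panic. A minimal standalone sketch of the pattern, using a plain http.Post in place of the handler's own client; the URL and empty payload are illustrative:

package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	response, err := http.Post("https://api.openai.com/v1/chat/completions", "application/json", nil)
	if err != nil {
		// On transport failures response is usually nil, so guard before
		// reading the body, as the diff above does.
		if response != nil && response.Body != nil {
			all, _ := io.ReadAll(response.Body)
			log.Println(string(all))
		}
		log.Fatal(err)
	}
	defer response.Body.Close()
	fmt.Println("status:", response.Status)
}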
@@ -31,7 +31,7 @@ type qWenResp struct {

 // Alibaba Tongyi Qianwen (QWen) message-sending implementation
 func (h *ChatHandler) sendQWenMessage(
-    chatCtx []interface{},
+    chatCtx []types.Message,
     req types.ApiRequest,
     userVo vo.User,
     ctx context.Context,
@@ -58,7 +58,7 @@ var Model2URL = map[string]string{
 // iFlytek Spark (XunFei) message-sending implementation

 func (h *ChatHandler) sendXunFeiMessage(
-    chatCtx []interface{},
+    chatCtx []types.Message,
     req types.ApiRequest,
     userVo vo.User,
     ctx context.Context,