mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
feat: replace Tools param with Function param for OpenAI chat API
This commit is contained in:
parent
755273a898
commit
43bfac99b6
@ -28,7 +28,7 @@ type AppServer struct {
|
|||||||
Debug bool
|
Debug bool
|
||||||
Config *types.AppConfig
|
Config *types.AppConfig
|
||||||
Engine *gin.Engine
|
Engine *gin.Engine
|
||||||
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
|
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
|
||||||
|
|
||||||
ChatConfig *types.ChatConfig // chat config cache
|
ChatConfig *types.ChatConfig // chat config cache
|
||||||
SysConfig *types.SystemConfig // system config cache
|
SysConfig *types.SystemConfig // system config cache
|
||||||
@ -47,7 +47,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
|
|||||||
Debug: false,
|
Debug: false,
|
||||||
Config: appConfig,
|
Config: appConfig,
|
||||||
Engine: gin.Default(),
|
Engine: gin.Default(),
|
||||||
ChatContexts: types.NewLMap[string, []interface{}](),
|
ChatContexts: types.NewLMap[string, []types.Message](),
|
||||||
ChatSession: types.NewLMap[string, *types.ChatSession](),
|
ChatSession: types.NewLMap[string, *types.ChatSession](),
|
||||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||||
|
@ -9,7 +9,7 @@ type MKey interface {
|
|||||||
string | int | uint
|
string | int | uint
|
||||||
}
|
}
|
||||||
type MValue interface {
|
type MValue interface {
|
||||||
*WsClient | *ChatSession | context.CancelFunc | []interface{}
|
*WsClient | *ChatSession | context.CancelFunc | []Message
|
||||||
}
|
}
|
||||||
type LMap[K MKey, T MValue] struct {
|
type LMap[K MKey, T MValue] struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
|
@ -19,7 +19,7 @@ import (
|
|||||||
// 微软 Azure 模型消息发送实现
|
// 微软 Azure 模型消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendAzureMessage(
|
func (h *ChatHandler) sendAzureMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
@ -36,7 +36,7 @@ type baiduResp struct {
|
|||||||
// 百度文心一言消息发送实现
|
// 百度文心一言消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendBaiduMessage(
|
func (h *ChatHandler) sendBaiduMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
@ -224,6 +224,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
utils.ReplyMessage(ws, ErrImg)
|
utils.ReplyMessage(ws, ErrImg)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
|
||||||
|
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
|
||||||
|
if promptTokens > types.GetModelMaxToken(session.Model.Value) {
|
||||||
|
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var req = types.ApiRequest{
|
var req = types.ApiRequest{
|
||||||
Model: session.Model.Value,
|
Model: session.Model.Value,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
@ -252,7 +260,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
}
|
}
|
||||||
|
|
||||||
var tools = make([]interface{}, 0)
|
var tools = make([]interface{}, 0)
|
||||||
var functions = make([]interface{}, 0)
|
|
||||||
for _, v := range items {
|
for _, v := range items {
|
||||||
var parameters map[string]interface{}
|
var parameters map[string]interface{}
|
||||||
err = utils.JsonDecode(v.Parameters, ¶meters)
|
err = utils.JsonDecode(v.Parameters, ¶meters)
|
||||||
@ -270,20 +277,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
"required": required,
|
"required": required,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
functions = append(functions, gin.H{
|
|
||||||
"name": v.Name,
|
|
||||||
"description": v.Description,
|
|
||||||
"parameters": parameters,
|
|
||||||
"required": required,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//if len(tools) > 0 {
|
if len(tools) > 0 {
|
||||||
// req.Tools = tools
|
req.Tools = tools
|
||||||
// req.ToolChoice = "auto"
|
req.ToolChoice = "auto"
|
||||||
//}
|
|
||||||
if len(functions) > 0 {
|
|
||||||
req.Functions = functions
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case types.XunFei:
|
case types.XunFei:
|
||||||
@ -301,40 +299,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 加载聊天上下文
|
// 加载聊天上下文
|
||||||
var chatCtx []interface{}
|
chatCtx := make([]types.Message, 0)
|
||||||
|
messages := make([]types.Message, 0)
|
||||||
if h.App.ChatConfig.EnableContext {
|
if h.App.ChatConfig.EnableContext {
|
||||||
if h.App.ChatContexts.Has(session.ChatId) {
|
if h.App.ChatContexts.Has(session.ChatId) {
|
||||||
chatCtx = h.App.ChatContexts.Get(session.ChatId)
|
messages = h.App.ChatContexts.Get(session.ChatId)
|
||||||
} else {
|
} else {
|
||||||
// calculate the tokens of current request, to prevent to exceeding the max tokens num
|
_ = utils.JsonDecode(role.Context, &messages)
|
||||||
tokens := req.MaxTokens
|
|
||||||
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
|
|
||||||
tokens += tks
|
|
||||||
// loading the role context
|
|
||||||
var messages []types.Message
|
|
||||||
err := utils.JsonDecode(role.Context, &messages)
|
|
||||||
if err == nil {
|
|
||||||
for _, v := range messages {
|
|
||||||
tks, _ := utils.CalcTokens(v.Content, req.Model)
|
|
||||||
if tokens+tks >= types.GetModelMaxToken(req.Model) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
tokens += tks
|
|
||||||
chatCtx = append(chatCtx, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// loading recent chat history as chat context
|
|
||||||
if chatConfig.ContextDeep > 0 {
|
if chatConfig.ContextDeep > 0 {
|
||||||
var historyMessages []model.ChatMessage
|
var historyMessages []model.ChatMessage
|
||||||
res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
|
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
|
||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
for i := len(historyMessages) - 1; i >= 0; i-- {
|
for i := len(historyMessages) - 1; i >= 0; i-- {
|
||||||
msg := historyMessages[i]
|
msg := historyMessages[i]
|
||||||
if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
tokens += msg.Tokens
|
|
||||||
ms := types.Message{Role: "user", Content: msg.Content}
|
ms := types.Message{Role: "user", Content: msg.Content}
|
||||||
if msg.Type == types.ReplyMsg {
|
if msg.Type == types.ReplyMsg {
|
||||||
ms.Role = "assistant"
|
ms.Role = "assistant"
|
||||||
@ -344,6 +321,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
|
||||||
|
// MaxContextLength = Response + Tool + Prompt + Context
|
||||||
|
tokens := req.MaxTokens // 最大响应长度
|
||||||
|
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
|
||||||
|
tokens += tks + promptTokens
|
||||||
|
|
||||||
|
for _, v := range messages {
|
||||||
|
tks, _ := utils.CalcTokens(v.Content, req.Model)
|
||||||
|
// 上下文 token 超出了模型的最大上下文长度
|
||||||
|
if tokens+tks >= types.GetModelMaxToken(req.Model) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 上下文的深度超出了模型的最大上下文深度
|
||||||
|
if len(chatCtx) >= h.App.ChatConfig.ContextDeep {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens += tks
|
||||||
|
chatCtx = append(chatCtx, v)
|
||||||
|
}
|
||||||
|
|
||||||
logger.Debugf("聊天上下文:%+v", chatCtx)
|
logger.Debugf("聊天上下文:%+v", chatCtx)
|
||||||
}
|
}
|
||||||
reqMgs := make([]interface{}, 0)
|
reqMgs := make([]interface{}, 0)
|
||||||
|
@ -20,7 +20,7 @@ import (
|
|||||||
// 清华大学 ChatGML 消息发送实现
|
// 清华大学 ChatGML 消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendChatGLMMessage(
|
func (h *ChatHandler) sendChatGLMMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
@ -20,7 +20,7 @@ import (
|
|||||||
|
|
||||||
// OPenAI 消息发送实现
|
// OPenAI 消息发送实现
|
||||||
func (h *ChatHandler) sendOpenAiMessage(
|
func (h *ChatHandler) sendOpenAiMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@ -46,8 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
|
|
||||||
utils.ReplyMessage(ws, ErrorMsg)
|
utils.ReplyMessage(ws, ErrorMsg)
|
||||||
utils.ReplyMessage(ws, ErrImg)
|
utils.ReplyMessage(ws, ErrImg)
|
||||||
all, _ := io.ReadAll(response.Body)
|
if response.Body != nil {
|
||||||
logger.Error(string(all))
|
all, _ := io.ReadAll(response.Body)
|
||||||
|
logger.Error(string(all))
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
@ -31,7 +31,7 @@ type qWenResp struct {
|
|||||||
|
|
||||||
// 通义千问消息发送实现
|
// 通义千问消息发送实现
|
||||||
func (h *ChatHandler) sendQWenMessage(
|
func (h *ChatHandler) sendQWenMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
@ -58,7 +58,7 @@ var Model2URL = map[string]string{
|
|||||||
// 科大讯飞消息发送实现
|
// 科大讯飞消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendXunFeiMessage(
|
func (h *ChatHandler) sendXunFeiMessage(
|
||||||
chatCtx []interface{},
|
chatCtx []types.Message,
|
||||||
req types.ApiRequest,
|
req types.ApiRequest,
|
||||||
userVo vo.User,
|
userVo vo.User,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
Loading…
Reference in New Issue
Block a user