feat: replace Tools param with Function param for OpenAI chat API

This commit is contained in:
RockYang 2024-03-11 14:09:19 +08:00
parent 755273a898
commit 43bfac99b6
9 changed files with 52 additions and 50 deletions

View File

@ -28,7 +28,7 @@ type AppServer struct {
Debug bool Debug bool
Config *types.AppConfig Config *types.AppConfig
Engine *gin.Engine Engine *gin.Engine
ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
ChatConfig *types.ChatConfig // chat config cache ChatConfig *types.ChatConfig // chat config cache
SysConfig *types.SystemConfig // system config cache SysConfig *types.SystemConfig // system config cache
@ -47,7 +47,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
Debug: false, Debug: false,
Config: appConfig, Config: appConfig,
Engine: gin.Default(), Engine: gin.Default(),
ChatContexts: types.NewLMap[string, []interface{}](), ChatContexts: types.NewLMap[string, []types.Message](),
ChatSession: types.NewLMap[string, *types.ChatSession](), ChatSession: types.NewLMap[string, *types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](), ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),

View File

@ -9,7 +9,7 @@ type MKey interface {
string | int | uint string | int | uint
} }
type MValue interface { type MValue interface {
*WsClient | *ChatSession | context.CancelFunc | []interface{} *WsClient | *ChatSession | context.CancelFunc | []Message
} }
type LMap[K MKey, T MValue] struct { type LMap[K MKey, T MValue] struct {
lock sync.RWMutex lock sync.RWMutex

View File

@ -19,7 +19,7 @@ import (
// 微软 Azure 模型消息发送实现 // 微软 Azure 模型消息发送实现
func (h *ChatHandler) sendAzureMessage( func (h *ChatHandler) sendAzureMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,

View File

@ -36,7 +36,7 @@ type baiduResp struct {
// 百度文心一言消息发送实现 // 百度文心一言消息发送实现
func (h *ChatHandler) sendBaiduMessage( func (h *ChatHandler) sendBaiduMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,

View File

@ -224,6 +224,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
utils.ReplyMessage(ws, ErrImg) utils.ReplyMessage(ws, ErrImg)
return nil return nil
} }
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > types.GetModelMaxToken(session.Model.Value) {
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
return nil
}
var req = types.ApiRequest{ var req = types.ApiRequest{
Model: session.Model.Value, Model: session.Model.Value,
Stream: true, Stream: true,
@ -252,7 +260,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
} }
var tools = make([]interface{}, 0) var tools = make([]interface{}, 0)
var functions = make([]interface{}, 0)
for _, v := range items { for _, v := range items {
var parameters map[string]interface{} var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters) err = utils.JsonDecode(v.Parameters, &parameters)
@ -270,20 +277,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
"required": required, "required": required,
}, },
}) })
functions = append(functions, gin.H{
"name": v.Name,
"description": v.Description,
"parameters": parameters,
"required": required,
})
} }
//if len(tools) > 0 { if len(tools) > 0 {
// req.Tools = tools req.Tools = tools
// req.ToolChoice = "auto" req.ToolChoice = "auto"
//}
if len(functions) > 0 {
req.Functions = functions
} }
case types.XunFei: case types.XunFei:
@ -301,40 +299,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
} }
// 加载聊天上下文 // 加载聊天上下文
var chatCtx []interface{} chatCtx := make([]types.Message, 0)
messages := make([]types.Message, 0)
if h.App.ChatConfig.EnableContext { if h.App.ChatConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) { if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId) messages = h.App.ChatContexts.Get(session.ChatId)
} else { } else {
// calculate the tokens of current request, to prevent to exceeding the max tokens num _ = utils.JsonDecode(role.Context, &messages)
tokens := req.MaxTokens
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks
// loading the role context
var messages []types.Message
err := utils.JsonDecode(role.Context, &messages)
if err == nil {
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
if tokens+tks >= types.GetModelMaxToken(req.Model) {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
}
// loading recent chat history as chat context
if chatConfig.ContextDeep > 0 { if chatConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage var historyMessages []model.ChatMessage
res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages) res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil { if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- { for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i] msg := historyMessages[i]
if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
break
}
tokens += msg.Tokens
ms := types.Message{Role: "user", Content: msg.Content} ms := types.Message{Role: "user", Content: msg.Content}
if msg.Type == types.ReplyMsg { if msg.Type == types.ReplyMsg {
ms.Role = "assistant" ms.Role = "assistant"
@ -344,6 +321,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
} }
} }
} }
// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
// MaxContextLength = Response + Tool + Prompt + Context
tokens := req.MaxTokens // 最大响应长度
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= types.GetModelMaxToken(req.Model) {
break
}
// 上下文的深度超出了模型的最大上下文深度
if len(chatCtx) >= h.App.ChatConfig.ContextDeep {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
logger.Debugf("聊天上下文:%+v", chatCtx) logger.Debugf("聊天上下文:%+v", chatCtx)
} }
reqMgs := make([]interface{}, 0) reqMgs := make([]interface{}, 0)

View File

@ -20,7 +20,7 @@ import (
// 清华大学 ChatGML 消息发送实现 // 清华大学 ChatGML 消息发送实现
func (h *ChatHandler) sendChatGLMMessage( func (h *ChatHandler) sendChatGLMMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,

View File

@ -20,7 +20,7 @@ import (
// OPenAI 消息发送实现 // OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage( func (h *ChatHandler) sendOpenAiMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,
@ -46,8 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage(
utils.ReplyMessage(ws, ErrorMsg) utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg) utils.ReplyMessage(ws, ErrImg)
all, _ := io.ReadAll(response.Body) if response.Body != nil {
logger.Error(string(all)) all, _ := io.ReadAll(response.Body)
logger.Error(string(all))
}
return err return err
} else { } else {
defer response.Body.Close() defer response.Body.Close()

View File

@ -31,7 +31,7 @@ type qWenResp struct {
// 通义千问消息发送实现 // 通义千问消息发送实现
func (h *ChatHandler) sendQWenMessage( func (h *ChatHandler) sendQWenMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,

View File

@ -58,7 +58,7 @@ var Model2URL = map[string]string{
// 科大讯飞消息发送实现 // 科大讯飞消息发送实现
func (h *ChatHandler) sendXunFeiMessage( func (h *ChatHandler) sendXunFeiMessage(
chatCtx []interface{}, chatCtx []types.Message,
req types.ApiRequest, req types.ApiRequest,
userVo vo.User, userVo vo.User,
ctx context.Context, ctx context.Context,