feat: replace Functions param with Tools param for OpenAI chat API

RockYang 2024-03-11 14:09:19 +08:00
parent 755273a898
commit 43bfac99b6
9 changed files with 52 additions and 50 deletions

View File

@@ -28,7 +28,7 @@ type AppServer struct {
Debug bool
Config *types.AppConfig
Engine *gin.Engine
ChatContexts *types.LMap[string, []interface{}] // chat context map: [chatId] => []Message
ChatContexts *types.LMap[string, []types.Message] // chat context map: [chatId] => []Message
ChatConfig *types.ChatConfig // chat config cache
SysConfig *types.SystemConfig // system config cache
@@ -47,7 +47,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
Debug: false,
Config: appConfig,
Engine: gin.Default(),
ChatContexts: types.NewLMap[string, []interface{}](),
ChatContexts: types.NewLMap[string, []types.Message](),
ChatSession: types.NewLMap[string, *types.ChatSession](),
ChatClients: types.NewLMap[string, *types.WsClient](),
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
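
The payoff of the []interface{} -> []types.Message switch is that context reads no longer need runtime type assertions. A standalone sketch of the difference, using a stand-in Message struct (the real one lives in the types package):

    package main

    import "fmt"

    // stand-in for types.Message
    type Message struct {
        Role    string
        Content string
    }

    func main() {
        // before: a context stored as []interface{} forces a type
        // assertion (and an ok-check) on every read
        old := []interface{}{Message{Role: "user", Content: "hi"}}
        if m, ok := old[0].(Message); ok {
            fmt.Println(m.Role)
        }

        // after: []Message allows direct, statically typed access
        now := []Message{{Role: "user", Content: "hi"}}
        fmt.Println(now[0].Role)
    }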

View File

@@ -9,7 +9,7 @@ type MKey interface {
string | int | uint
}
type MValue interface {
*WsClient | *ChatSession | context.CancelFunc | []interface{}
*WsClient | *ChatSession | context.CancelFunc | []Message
}
type LMap[K MKey, T MValue] struct {
lock sync.RWMutex
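
With []Message added to the MValue union, an LMap can now carry typed chat context directly. A self-contained sketch of how the constraint plays out; Put is an assumed method name (only Has/Get appear elsewhere in this diff), and the stand-ins below are not the repo's implementation:

    package main

    import (
        "fmt"
        "sync"
    )

    type Message struct{ Role, Content string }

    type MKey interface{ string | int | uint }
    type MValue interface{ []Message } // the real union also allows *WsClient etc.

    type LMap[K MKey, T MValue] struct {
        lock sync.RWMutex
        data map[K]T
    }

    func (m *LMap[K, T]) Put(key K, value T) {
        m.lock.Lock()
        defer m.lock.Unlock()
        m.data[key] = value
    }

    func (m *LMap[K, T]) Get(key K) T {
        m.lock.RLock()
        defer m.lock.RUnlock()
        return m.data[key]
    }

    func main() {
        ctxs := LMap[string, []Message]{data: map[string][]Message{}}
        ctxs.Put("chat-1", []Message{{Role: "user", Content: "hello"}})
        fmt.Println(ctxs.Get("chat-1")[0].Content) // no type assertion needed
    }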

View File

@@ -19,7 +19,7 @@ import (
// message sending implementation for Microsoft Azure models
func (h *ChatHandler) sendAzureMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,

View File

@@ -36,7 +36,7 @@ type baiduResp struct {
// message sending implementation for Baidu ERNIE Bot (文心一言)
func (h *ChatHandler) sendBaiduMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,

View File

@@ -224,6 +224,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession
utils.ReplyMessage(ws, ErrImg)
return nil
}
// check whether the prompt exceeds the maximum context length allowed by the current model
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > types.GetModelMaxToken(session.Model.Value) {
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
return nil
}
var req = types.ApiRequest{
Model: session.Model.Value,
Stream: true,
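
The new pre-check turns away prompts that alone exceed the model's window before any upstream call is made. Worth noting: the err returned by CalcTokens is discarded. A hedged variant that surfaces it, using only the calls visible in this hunk plus fmt:

    promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
    if err != nil {
        // assumption: errors propagate upward as elsewhere in sendMessage
        return fmt.Errorf("counting prompt tokens: %v", err)
    }
    if promptTokens > types.GetModelMaxToken(session.Model.Value) {
        utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
        return nil
    }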
@@ -252,7 +260,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession
}
var tools = make([]interface{}, 0)
var functions = make([]interface{}, 0)
for _, v := range items {
var parameters map[string]interface{}
err = utils.JsonDecode(v.Parameters, &parameters)
@@ -270,20 +277,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession
"required": required,
},
})
functions = append(functions, gin.H{
"name": v.Name,
"description": v.Description,
"parameters": parameters,
"required": required,
})
}
//if len(tools) > 0 {
// req.Tools = tools
// req.ToolChoice = "auto"
//}
if len(functions) > 0 {
req.Functions = functions
if len(tools) > 0 {
req.Tools = tools
req.ToolChoice = "auto"
}
case types.XunFei:
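
This hunk is the core of the commit: the deprecated functions/function_call request fields are dropped in favor of tools/tool_choice. The opening of the tools literal is cut off in this view; for reference, the shape OpenAI's chat completions endpoint expects wraps each function schema under a "function" key, so the full append most likely looks like this (the exact nesting is inferred from OpenAI's published format, not from this diff):

    tools = append(tools, gin.H{
        "type": "function",
        "function": gin.H{
            "name":        v.Name,
            "description": v.Description,
            "parameters": gin.H{ // JSON Schema describing the arguments
                "type":       "object",
                "properties": parameters,
                "required":   required,
            },
        },
    })
    // then, exactly as above: req.Tools = tools and req.ToolChoice = "auto"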
@@ -301,40 +299,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession
}
// load the chat context
var chatCtx []interface{}
chatCtx := make([]types.Message, 0)
messages := make([]types.Message, 0)
if h.App.ChatConfig.EnableContext {
if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId)
messages = h.App.ChatContexts.Get(session.ChatId)
} else {
// calculate the tokens of the current request, to avoid exceeding the model's max token count
tokens := req.MaxTokens
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks
// loading the role context
var messages []types.Message
err := utils.JsonDecode(role.Context, &messages)
if err == nil {
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
if tokens+tks >= types.GetModelMaxToken(req.Model) {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
}
// loading recent chat history as chat context
_ = utils.JsonDecode(role.Context, &messages)
if chatConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage
res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
break
}
tokens += msg.Tokens
ms := types.Message{Role: "user", Content: msg.Content}
if msg.Type == types.ReplyMsg {
ms.Role = "assistant"
@@ -344,6 +321,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession
}
}
}
// calculate the total token count of the current request to make sure it stays within the maximum context length
// MaxContextLength = Response + Tool + Prompt + Context
tokens := req.MaxTokens // maximum response length
tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
tokens += tks + promptTokens
for _, v := range messages {
tks, _ := utils.CalcTokens(v.Content, req.Model)
// the accumulated context tokens would exceed the model's maximum context length
if tokens+tks >= types.GetModelMaxToken(req.Model) {
break
}
// the context depth would exceed the configured maximum context depth
if len(chatCtx) >= h.App.ChatConfig.ContextDeep {
break
}
tokens += tks
chatCtx = append(chatCtx, v)
}
logger.Debugf("聊天上下文:%+v", chatCtx)
}
reqMgs := make([]interface{}, 0)
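
The rewritten budgeting pre-charges the response reservation, the tool definitions, and the prompt before admitting any history, matching the comment MaxContextLength = Response + Tool + Prompt + Context. A worked example with made-up numbers:

    // assume types.GetModelMaxToken(req.Model) == 4096 and:
    //   response reservation (req.MaxTokens)       = 1024
    //   tool definitions (JSON-encoded req.Tools)  =  300
    //   promptTokens                               =  200
    //   pre-charged total                          = 1524
    //
    // the loop then admits messages in stored order until either
    //   1524 + admitted tokens >= 4096   (token ceiling), or
    //   len(chatCtx) >= ContextDeep      (depth ceiling),
    // so at most 4096 - 1524 = 2572 tokens of context reach chatCtx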

View File

@@ -20,7 +20,7 @@ import (
// message sending implementation for Tsinghua ChatGLM
func (h *ChatHandler) sendChatGLMMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,

View File

@@ -20,7 +20,7 @@ import (
// message sending implementation for OpenAI
func (h *ChatHandler) sendOpenAiMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
@@ -46,8 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage(
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
all, _ := io.ReadAll(response.Body)
logger.Error(string(all))
if response.Body != nil {
all, _ := io.ReadAll(response.Body)
logger.Error(string(all))
}
return err
} else {
defer response.Body.Close()
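
One caveat about the new guard: when the HTTP round trip itself fails, net/http usually returns a nil *http.Response, so evaluating response.Body can panic before the Body nil-check helps. A hedged sketch of a fuller error path (client.Do and the surrounding names are assumptions; the call site is outside this hunk):

    response, err := client.Do(request)
    if err != nil {
        utils.ReplyMessage(ws, ErrorMsg)
        utils.ReplyMessage(ws, ErrImg)
        // response is typically nil on transport errors; check it first
        if response != nil && response.Body != nil {
            if all, readErr := io.ReadAll(response.Body); readErr == nil {
                logger.Error(string(all))
            }
            _ = response.Body.Close()
        }
        return err
    }
    defer response.Body.Close()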

View File

@@ -31,7 +31,7 @@ type qWenResp struct {
// message sending implementation for Alibaba Tongyi Qianwen (通义千问)
func (h *ChatHandler) sendQWenMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,

View File

@@ -58,7 +58,7 @@ var Model2URL = map[string]string{
// message sending implementation for iFLYTEK (科大讯飞) models
func (h *ChatHandler) sendXunFeiMessage(
chatCtx []interface{},
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,