mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	feat: replace Tools param with Function param for OpenAI chat API
This commit is contained in:
		@@ -28,7 +28,7 @@ type AppServer struct {
 | 
				
			|||||||
	Debug        bool
 | 
						Debug        bool
 | 
				
			||||||
	Config       *types.AppConfig
 | 
						Config       *types.AppConfig
 | 
				
			||||||
	Engine       *gin.Engine
 | 
						Engine       *gin.Engine
 | 
				
			||||||
	ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
 | 
						ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ChatConfig *types.ChatConfig   // chat config cache
 | 
						ChatConfig *types.ChatConfig   // chat config cache
 | 
				
			||||||
	SysConfig  *types.SystemConfig // system config cache
 | 
						SysConfig  *types.SystemConfig // system config cache
 | 
				
			||||||
@@ -47,7 +47,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
 | 
				
			|||||||
		Debug:         false,
 | 
							Debug:         false,
 | 
				
			||||||
		Config:        appConfig,
 | 
							Config:        appConfig,
 | 
				
			||||||
		Engine:        gin.Default(),
 | 
							Engine:        gin.Default(),
 | 
				
			||||||
		ChatContexts:  types.NewLMap[string, []interface{}](),
 | 
							ChatContexts:  types.NewLMap[string, []types.Message](),
 | 
				
			||||||
		ChatSession:   types.NewLMap[string, *types.ChatSession](),
 | 
							ChatSession:   types.NewLMap[string, *types.ChatSession](),
 | 
				
			||||||
		ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
							ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
				
			||||||
		ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
 | 
							ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,7 +9,7 @@ type MKey interface {
 | 
				
			|||||||
	string | int | uint
 | 
						string | int | uint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
type MValue interface {
 | 
					type MValue interface {
 | 
				
			||||||
	*WsClient | *ChatSession | context.CancelFunc | []interface{}
 | 
						*WsClient | *ChatSession | context.CancelFunc | []Message
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
type LMap[K MKey, T MValue] struct {
 | 
					type LMap[K MKey, T MValue] struct {
 | 
				
			||||||
	lock sync.RWMutex
 | 
						lock sync.RWMutex
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,7 +19,7 @@ import (
 | 
				
			|||||||
// 微软 Azure 模型消息发送实现
 | 
					// 微软 Azure 模型消息发送实现
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *ChatHandler) sendAzureMessage(
 | 
					func (h *ChatHandler) sendAzureMessage(
 | 
				
			||||||
	chatCtx []interface{},
 | 
						chatCtx []types.Message,
 | 
				
			||||||
	req types.ApiRequest,
 | 
						req types.ApiRequest,
 | 
				
			||||||
	userVo vo.User,
 | 
						userVo vo.User,
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -36,7 +36,7 @@ type baiduResp struct {
 | 
				
			|||||||
// 百度文心一言消息发送实现
 | 
					// 百度文心一言消息发送实现
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *ChatHandler) sendBaiduMessage(
 | 
					func (h *ChatHandler) sendBaiduMessage(
 | 
				
			||||||
	chatCtx []interface{},
 | 
						chatCtx []types.Message,
 | 
				
			||||||
	req types.ApiRequest,
 | 
						req types.ApiRequest,
 | 
				
			||||||
	userVo vo.User,
 | 
						userVo vo.User,
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -224,6 +224,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
				
			|||||||
		utils.ReplyMessage(ws, ErrImg)
 | 
							utils.ReplyMessage(ws, ErrImg)
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
 | 
				
			||||||
 | 
						promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
 | 
				
			||||||
 | 
						if promptTokens > types.GetModelMaxToken(session.Model.Value) {
 | 
				
			||||||
 | 
							utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var req = types.ApiRequest{
 | 
						var req = types.ApiRequest{
 | 
				
			||||||
		Model:  session.Model.Value,
 | 
							Model:  session.Model.Value,
 | 
				
			||||||
		Stream: true,
 | 
							Stream: true,
 | 
				
			||||||
@@ -252,7 +260,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		var tools = make([]interface{}, 0)
 | 
							var tools = make([]interface{}, 0)
 | 
				
			||||||
		var functions = make([]interface{}, 0)
 | 
					 | 
				
			||||||
		for _, v := range items {
 | 
							for _, v := range items {
 | 
				
			||||||
			var parameters map[string]interface{}
 | 
								var parameters map[string]interface{}
 | 
				
			||||||
			err = utils.JsonDecode(v.Parameters, ¶meters)
 | 
								err = utils.JsonDecode(v.Parameters, ¶meters)
 | 
				
			||||||
@@ -270,20 +277,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
				
			|||||||
					"required":    required,
 | 
										"required":    required,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
			functions = append(functions, gin.H{
 | 
					 | 
				
			||||||
				"name":        v.Name,
 | 
					 | 
				
			||||||
				"description": v.Description,
 | 
					 | 
				
			||||||
				"parameters":  parameters,
 | 
					 | 
				
			||||||
				"required":    required,
 | 
					 | 
				
			||||||
			})
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		//if len(tools) > 0 {
 | 
							if len(tools) > 0 {
 | 
				
			||||||
		//	req.Tools = tools
 | 
								req.Tools = tools
 | 
				
			||||||
		//	req.ToolChoice = "auto"
 | 
								req.ToolChoice = "auto"
 | 
				
			||||||
		//}
 | 
					 | 
				
			||||||
		if len(functions) > 0 {
 | 
					 | 
				
			||||||
			req.Functions = functions
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	case types.XunFei:
 | 
						case types.XunFei:
 | 
				
			||||||
@@ -301,40 +299,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 加载聊天上下文
 | 
						// 加载聊天上下文
 | 
				
			||||||
	var chatCtx []interface{}
 | 
						chatCtx := make([]types.Message, 0)
 | 
				
			||||||
 | 
						messages := make([]types.Message, 0)
 | 
				
			||||||
	if h.App.ChatConfig.EnableContext {
 | 
						if h.App.ChatConfig.EnableContext {
 | 
				
			||||||
		if h.App.ChatContexts.Has(session.ChatId) {
 | 
							if h.App.ChatContexts.Has(session.ChatId) {
 | 
				
			||||||
			chatCtx = h.App.ChatContexts.Get(session.ChatId)
 | 
								messages = h.App.ChatContexts.Get(session.ChatId)
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			// calculate the tokens of current request, to prevent to exceeding the max tokens num
 | 
								_ = utils.JsonDecode(role.Context, &messages)
 | 
				
			||||||
			tokens := req.MaxTokens
 | 
					 | 
				
			||||||
			tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
 | 
					 | 
				
			||||||
			tokens += tks
 | 
					 | 
				
			||||||
			// loading the role context
 | 
					 | 
				
			||||||
			var messages []types.Message
 | 
					 | 
				
			||||||
			err := utils.JsonDecode(role.Context, &messages)
 | 
					 | 
				
			||||||
			if err == nil {
 | 
					 | 
				
			||||||
				for _, v := range messages {
 | 
					 | 
				
			||||||
					tks, _ := utils.CalcTokens(v.Content, req.Model)
 | 
					 | 
				
			||||||
					if tokens+tks >= types.GetModelMaxToken(req.Model) {
 | 
					 | 
				
			||||||
						break
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					tokens += tks
 | 
					 | 
				
			||||||
					chatCtx = append(chatCtx, v)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			// loading recent chat history as chat context
 | 
					 | 
				
			||||||
			if chatConfig.ContextDeep > 0 {
 | 
								if chatConfig.ContextDeep > 0 {
 | 
				
			||||||
				var historyMessages []model.ChatMessage
 | 
									var historyMessages []model.ChatMessage
 | 
				
			||||||
				res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
 | 
									res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
 | 
				
			||||||
				if res.Error == nil {
 | 
									if res.Error == nil {
 | 
				
			||||||
					for i := len(historyMessages) - 1; i >= 0; i-- {
 | 
										for i := len(historyMessages) - 1; i >= 0; i-- {
 | 
				
			||||||
						msg := historyMessages[i]
 | 
											msg := historyMessages[i]
 | 
				
			||||||
						if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
 | 
					 | 
				
			||||||
							break
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
						tokens += msg.Tokens
 | 
					 | 
				
			||||||
						ms := types.Message{Role: "user", Content: msg.Content}
 | 
											ms := types.Message{Role: "user", Content: msg.Content}
 | 
				
			||||||
						if msg.Type == types.ReplyMsg {
 | 
											if msg.Type == types.ReplyMsg {
 | 
				
			||||||
							ms.Role = "assistant"
 | 
												ms.Role = "assistant"
 | 
				
			||||||
@@ -344,6 +321,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
 | 
				
			||||||
 | 
							// MaxContextLength = Response + Tool + Prompt + Context
 | 
				
			||||||
 | 
							tokens := req.MaxTokens // 最大响应长度
 | 
				
			||||||
 | 
							tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
 | 
				
			||||||
 | 
							tokens += tks + promptTokens
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for _, v := range messages {
 | 
				
			||||||
 | 
								tks, _ := utils.CalcTokens(v.Content, req.Model)
 | 
				
			||||||
 | 
								// 上下文 token 超出了模型的最大上下文长度
 | 
				
			||||||
 | 
								if tokens+tks >= types.GetModelMaxToken(req.Model) {
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// 上下文的深度超出了模型的最大上下文深度
 | 
				
			||||||
 | 
								if len(chatCtx) >= h.App.ChatConfig.ContextDeep {
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								tokens += tks
 | 
				
			||||||
 | 
								chatCtx = append(chatCtx, v)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		logger.Debugf("聊天上下文:%+v", chatCtx)
 | 
							logger.Debugf("聊天上下文:%+v", chatCtx)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	reqMgs := make([]interface{}, 0)
 | 
						reqMgs := make([]interface{}, 0)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,7 +20,7 @@ import (
 | 
				
			|||||||
// 清华大学 ChatGML 消息发送实现
 | 
					// 清华大学 ChatGML 消息发送实现
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *ChatHandler) sendChatGLMMessage(
 | 
					func (h *ChatHandler) sendChatGLMMessage(
 | 
				
			||||||
	chatCtx []interface{},
 | 
						chatCtx []types.Message,
 | 
				
			||||||
	req types.ApiRequest,
 | 
						req types.ApiRequest,
 | 
				
			||||||
	userVo vo.User,
 | 
						userVo vo.User,
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,7 +20,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// OPenAI 消息发送实现
 | 
					// OPenAI 消息发送实现
 | 
				
			||||||
func (h *ChatHandler) sendOpenAiMessage(
 | 
					func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			||||||
	chatCtx []interface{},
 | 
						chatCtx []types.Message,
 | 
				
			||||||
	req types.ApiRequest,
 | 
						req types.ApiRequest,
 | 
				
			||||||
	userVo vo.User,
 | 
						userVo vo.User,
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
@@ -46,8 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		utils.ReplyMessage(ws, ErrorMsg)
 | 
							utils.ReplyMessage(ws, ErrorMsg)
 | 
				
			||||||
		utils.ReplyMessage(ws, ErrImg)
 | 
							utils.ReplyMessage(ws, ErrImg)
 | 
				
			||||||
		all, _ := io.ReadAll(response.Body)
 | 
							if response.Body != nil {
 | 
				
			||||||
		logger.Error(string(all))
 | 
								all, _ := io.ReadAll(response.Body)
 | 
				
			||||||
 | 
								logger.Error(string(all))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		defer response.Body.Close()
 | 
							defer response.Body.Close()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -31,7 +31,7 @@ type qWenResp struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// 通义千问消息发送实现
 | 
					// 通义千问消息发送实现
 | 
				
			||||||
func (h *ChatHandler) sendQWenMessage(
 | 
					func (h *ChatHandler) sendQWenMessage(
 | 
				
			||||||
	chatCtx []interface{},
 | 
						chatCtx []types.Message,
 | 
				
			||||||
	req types.ApiRequest,
 | 
						req types.ApiRequest,
 | 
				
			||||||
	userVo vo.User,
 | 
						userVo vo.User,
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -58,7 +58,7 @@ var Model2URL = map[string]string{
 | 
				
			|||||||
// 科大讯飞消息发送实现
 | 
					// 科大讯飞消息发送实现
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *ChatHandler) sendXunFeiMessage(
 | 
					func (h *ChatHandler) sendXunFeiMessage(
 | 
				
			||||||
	chatCtx []interface{},
 | 
						chatCtx []types.Message,
 | 
				
			||||||
	req types.ApiRequest,
 | 
						req types.ApiRequest,
 | 
				
			||||||
	userVo vo.User,
 | 
						userVo vo.User,
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user