mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	feat: replace Tools param with Function param for OpenAI chat API
This commit is contained in:
		@@ -28,7 +28,7 @@ type AppServer struct {
 | 
			
		||||
	Debug        bool
 | 
			
		||||
	Config       *types.AppConfig
 | 
			
		||||
	Engine       *gin.Engine
 | 
			
		||||
	ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
 | 
			
		||||
	ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
 | 
			
		||||
 | 
			
		||||
	ChatConfig *types.ChatConfig   // chat config cache
 | 
			
		||||
	SysConfig  *types.SystemConfig // system config cache
 | 
			
		||||
@@ -47,7 +47,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
 | 
			
		||||
		Debug:         false,
 | 
			
		||||
		Config:        appConfig,
 | 
			
		||||
		Engine:        gin.Default(),
 | 
			
		||||
		ChatContexts:  types.NewLMap[string, []interface{}](),
 | 
			
		||||
		ChatContexts:  types.NewLMap[string, []types.Message](),
 | 
			
		||||
		ChatSession:   types.NewLMap[string, *types.ChatSession](),
 | 
			
		||||
		ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ type MKey interface {
 | 
			
		||||
	string | int | uint
 | 
			
		||||
}
 | 
			
		||||
type MValue interface {
 | 
			
		||||
	*WsClient | *ChatSession | context.CancelFunc | []interface{}
 | 
			
		||||
	*WsClient | *ChatSession | context.CancelFunc | []Message
 | 
			
		||||
}
 | 
			
		||||
type LMap[K MKey, T MValue] struct {
 | 
			
		||||
	lock sync.RWMutex
 | 
			
		||||
 
 | 
			
		||||
@@ -19,7 +19,7 @@ import (
 | 
			
		||||
// 微软 Azure 模型消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
 
 | 
			
		||||
@@ -36,7 +36,7 @@ type baiduResp struct {
 | 
			
		||||
// 百度文心一言消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendBaiduMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
 
 | 
			
		||||
@@ -224,6 +224,14 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
 | 
			
		||||
	promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
 | 
			
		||||
	if promptTokens > types.GetModelMaxToken(session.Model.Value) {
 | 
			
		||||
		utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var req = types.ApiRequest{
 | 
			
		||||
		Model:  session.Model.Value,
 | 
			
		||||
		Stream: true,
 | 
			
		||||
@@ -252,7 +260,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tools = make([]interface{}, 0)
 | 
			
		||||
		var functions = make([]interface{}, 0)
 | 
			
		||||
		for _, v := range items {
 | 
			
		||||
			var parameters map[string]interface{}
 | 
			
		||||
			err = utils.JsonDecode(v.Parameters, ¶meters)
 | 
			
		||||
@@ -270,20 +277,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
					"required":    required,
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			functions = append(functions, gin.H{
 | 
			
		||||
				"name":        v.Name,
 | 
			
		||||
				"description": v.Description,
 | 
			
		||||
				"parameters":  parameters,
 | 
			
		||||
				"required":    required,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		//if len(tools) > 0 {
 | 
			
		||||
		//	req.Tools = tools
 | 
			
		||||
		//	req.ToolChoice = "auto"
 | 
			
		||||
		//}
 | 
			
		||||
		if len(functions) > 0 {
 | 
			
		||||
			req.Functions = functions
 | 
			
		||||
		if len(tools) > 0 {
 | 
			
		||||
			req.Tools = tools
 | 
			
		||||
			req.ToolChoice = "auto"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	case types.XunFei:
 | 
			
		||||
@@ -301,40 +299,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 加载聊天上下文
 | 
			
		||||
	var chatCtx []interface{}
 | 
			
		||||
	chatCtx := make([]types.Message, 0)
 | 
			
		||||
	messages := make([]types.Message, 0)
 | 
			
		||||
	if h.App.ChatConfig.EnableContext {
 | 
			
		||||
		if h.App.ChatContexts.Has(session.ChatId) {
 | 
			
		||||
			chatCtx = h.App.ChatContexts.Get(session.ChatId)
 | 
			
		||||
			messages = h.App.ChatContexts.Get(session.ChatId)
 | 
			
		||||
		} else {
 | 
			
		||||
			// calculate the tokens of current request, to prevent to exceeding the max tokens num
 | 
			
		||||
			tokens := req.MaxTokens
 | 
			
		||||
			tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
 | 
			
		||||
			tokens += tks
 | 
			
		||||
			// loading the role context
 | 
			
		||||
			var messages []types.Message
 | 
			
		||||
			err := utils.JsonDecode(role.Context, &messages)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				for _, v := range messages {
 | 
			
		||||
					tks, _ := utils.CalcTokens(v.Content, req.Model)
 | 
			
		||||
					if tokens+tks >= types.GetModelMaxToken(req.Model) {
 | 
			
		||||
						break
 | 
			
		||||
					}
 | 
			
		||||
					tokens += tks
 | 
			
		||||
					chatCtx = append(chatCtx, v)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// loading recent chat history as chat context
 | 
			
		||||
			_ = utils.JsonDecode(role.Context, &messages)
 | 
			
		||||
			if chatConfig.ContextDeep > 0 {
 | 
			
		||||
				var historyMessages []model.ChatMessage
 | 
			
		||||
				res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
 | 
			
		||||
				res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
 | 
			
		||||
				if res.Error == nil {
 | 
			
		||||
					for i := len(historyMessages) - 1; i >= 0; i-- {
 | 
			
		||||
						msg := historyMessages[i]
 | 
			
		||||
						if tokens+msg.Tokens >= types.GetModelMaxToken(session.Model.Value) {
 | 
			
		||||
							break
 | 
			
		||||
						}
 | 
			
		||||
						tokens += msg.Tokens
 | 
			
		||||
						ms := types.Message{Role: "user", Content: msg.Content}
 | 
			
		||||
						if msg.Type == types.ReplyMsg {
 | 
			
		||||
							ms.Role = "assistant"
 | 
			
		||||
@@ -344,6 +321,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
 | 
			
		||||
		// MaxContextLength = Response + Tool + Prompt + Context
 | 
			
		||||
		tokens := req.MaxTokens // 最大响应长度
 | 
			
		||||
		tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
 | 
			
		||||
		tokens += tks + promptTokens
 | 
			
		||||
 | 
			
		||||
		for _, v := range messages {
 | 
			
		||||
			tks, _ := utils.CalcTokens(v.Content, req.Model)
 | 
			
		||||
			// 上下文 token 超出了模型的最大上下文长度
 | 
			
		||||
			if tokens+tks >= types.GetModelMaxToken(req.Model) {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 上下文的深度超出了模型的最大上下文深度
 | 
			
		||||
			if len(chatCtx) >= h.App.ChatConfig.ContextDeep {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			tokens += tks
 | 
			
		||||
			chatCtx = append(chatCtx, v)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logger.Debugf("聊天上下文:%+v", chatCtx)
 | 
			
		||||
	}
 | 
			
		||||
	reqMgs := make([]interface{}, 0)
 | 
			
		||||
 
 | 
			
		||||
@@ -20,7 +20,7 @@ import (
 | 
			
		||||
// 清华大学 ChatGML 消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendChatGLMMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
 
 | 
			
		||||
@@ -20,7 +20,7 @@ import (
 | 
			
		||||
 | 
			
		||||
// OPenAI 消息发送实现
 | 
			
		||||
func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
@@ -46,8 +46,10 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
 | 
			
		||||
		utils.ReplyMessage(ws, ErrorMsg)
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
		all, _ := io.ReadAll(response.Body)
 | 
			
		||||
		logger.Error(string(all))
 | 
			
		||||
		if response.Body != nil {
 | 
			
		||||
			all, _ := io.ReadAll(response.Body)
 | 
			
		||||
			logger.Error(string(all))
 | 
			
		||||
		}
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
 
 | 
			
		||||
@@ -31,7 +31,7 @@ type qWenResp struct {
 | 
			
		||||
 | 
			
		||||
// 通义千问消息发送实现
 | 
			
		||||
func (h *ChatHandler) sendQWenMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
 
 | 
			
		||||
@@ -58,7 +58,7 @@ var Model2URL = map[string]string{
 | 
			
		||||
// 科大讯飞消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user