Mirror of https://github.com/yangjian102621/geekai.git

Commit: feat: 支持上下文深度配置,计算每轮对话消耗的总 token 数量
(feat: support configurable chat context depth and calculate the total number of tokens consumed in each conversation round)
@@ -72,7 +72,9 @@ type ChatConfig struct {
 	MaxTokens     int     `json:"max_tokens"`
 	EnableContext bool    `json:"enable_context"` // 是否开启聊天上下文
 	EnableHistory bool    `json:"enable_history"` // 是否允许保存聊天记录
-	ApiKey        string  `json:"api_key"`        // OpenAI  API key
+	ApiKey        string  `json:"api_key"`
+	ContextDeep   int     `json:"context_deep"` // 上下文深度
+
 }
 
 type SystemConfig struct {
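The hunk above adds a ContextDeep field to ChatConfig; a later hunk decodes this struct from the JSON blob stored in the config table, so the new field only needs its json tag to become configurable. A minimal sketch of that round trip, using the standard encoding/json package rather than the project's utils.JsonDecode helper, with made-up sample values:

```go
// Minimal sketch (not the project's actual types package): how the new
// context_deep field round-trips through the JSON stored in the config table.
package main

import (
	"encoding/json"
	"fmt"
)

type ChatConfig struct {
	MaxTokens     int    `json:"max_tokens"`
	EnableContext bool   `json:"enable_context"` // whether chat context is enabled
	EnableHistory bool   `json:"enable_history"` // whether chat history may be saved
	ApiKey        string `json:"api_key"`
	ContextDeep   int    `json:"context_deep"` // context depth
}

func main() {
	raw := `{"max_tokens":1024,"enable_context":true,"enable_history":true,"context_deep":4}`

	var cfg ChatConfig
	if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.ContextDeep) // 4

	// Older config rows without the field decode to the zero value 0,
	// which the handler treats as "do not load history".
}
```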
@@ -40,6 +40,8 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, zaoBao *function.FuncZaoBa
 	return &handler
 }
 
+var chatConfig types.ChatConfig
+
 // ChatHandle 处理聊天 WebSocket 请求
 func (h *ChatHandler) ChatHandle(c *gin.Context) {
 	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
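This hunk introduces a package-level chatConfig variable that ChatHandle fills in on every new connection (see the next hunk). For context, the handler upgrades the HTTP request to a WebSocket with gorilla/websocket under gin; a stripped-down sketch of that pattern, independent of the project's own handler type and routes, could look roughly like this:

```go
// Hypothetical, minimal gin + gorilla/websocket handler, for illustration only.
package main

import (
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
)

func main() {
	r := gin.Default()
	r.GET("/api/chat", func(c *gin.Context) {
		// Accept any origin, as the diff does with CheckOrigin returning true.
		up := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
		ws, err := up.Upgrade(c.Writer, c.Request, nil)
		if err != nil {
			c.AbortWithStatus(http.StatusBadRequest)
			return
		}
		defer ws.Close()
		// Send a greeting; the real handler streams chat replies instead.
		_ = ws.WriteMessage(websocket.TextMessage, []byte("connected"))
	})
	_ = r.Run(":8080")
}
```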
@@ -84,7 +86,17 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 	var chatRole model.ChatRole
 	res = h.db.First(&chatRole, roleId)
 	if res.Error != nil || !chatRole.Enable {
-		replyMessage(client, "当前聊天角色不存在或者未启用!!!")
+		replyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
 		c.Abort()
 		return
 	}
+
+	// 初始化聊天配置
+	var config model.Config
+	h.db.Where("marker", "chat").First(&config)
+	err = utils.JsonDecode(config.Config, &chatConfig)
+	if err != nil {
+		replyMessage(client, "加载系统配置失败,连接已关闭!!!")
+		c.Abort()
+		return
+	}
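Besides the clearer error message, this hunk loads the chat configuration at connection time: it reads the config row whose marker is "chat", decodes its JSON payload into the shared chatConfig, and closes the connection with a message if decoding fails. Assuming a GORM model with a JSON string column (the actual model.Config definition is not shown in this diff), the load-and-decode step boils down to something like the sketch below, written with the standard "column = ?" placeholder form:

```go
// Illustrative sketch only; the field names on Config are assumptions.
package config

import (
	"encoding/json"

	"gorm.io/gorm"
)

// Config mirrors a row in the config table: a marker plus a JSON blob.
type Config struct {
	ID     uint
	Marker string
	Config string // JSON-encoded ChatConfig
}

// ChatConfig lists only the field relevant here.
type ChatConfig struct {
	ContextDeep int `json:"context_deep"`
}

// LoadChatConfig fetches the "chat" config row and decodes its JSON payload.
func LoadChatConfig(db *gorm.DB) (ChatConfig, error) {
	var row Config
	if err := db.Where("marker = ?", "chat").First(&row).Error; err != nil {
		return ChatConfig{}, err
	}
	var cfg ChatConfig
	if err := json.Unmarshal([]byte(row.Config), &cfg); err != nil {
		return ChatConfig{}, err
	}
	return cfg, nil
}
```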
@@ -174,16 +186,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
 				chatCtx = append(chatCtx, v)
 			}
 		}
-		// TODO: 这里默认加载最近 2 条聊天记录作为上下文,后期应该做成可配置的
-		var historyMessages []model.HistoryMessage
-		res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(2).Order("created_at desc").Find(&historyMessages)
-		if res.Error == nil {
-			for _, msg := range historyMessages {
-				ms := types.Message{Role: "user", Content: msg.Content}
-				if msg.Type == types.ReplyMsg {
-					ms.Role = "assistant"
+
+		// 加载最近的聊天记录作为聊天上下文
+		if chatConfig.ContextDeep > 0 {
+			var historyMessages []model.HistoryMessage
+			res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(2).Order("created_at desc").Find(&historyMessages)
+			if res.Error == nil {
+				for _, msg := range historyMessages {
+					ms := types.Message{Role: "user", Content: msg.Content}
+					if msg.Type == types.ReplyMsg {
+						ms.Role = "assistant"
+					}
+					chatCtx = append(chatCtx, ms)
 				}
-				chatCtx = append(chatCtx, ms)
 			}
 		}
 	}
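This hunk replaces the old "always load the last 2 records" TODO with a check on the new setting: when chatConfig.ContextDeep is 0, stored messages are skipped entirely and only the role's built-in context is used. Note that the query inside the new block still uses Limit(2); if the configured depth is also meant to control how many records are loaded (which this hunk alone does not show), the query would presumably take the depth as its limit. A hypothetical helper illustrating that idea, with assumed model and field names, is sketched below. Unlike the hunk above, this sketch also reverses the result so the context reads oldest-to-newest:

```go
// Hypothetical sketch: load up to `deep` recent context messages for a chat.
// The HistoryMessage fields are assumptions based on the columns used in the diff.
package chat

import (
	"time"

	"gorm.io/gorm"
)

type HistoryMessage struct {
	ID         uint
	ChatId     string
	Type       string
	Content    string
	UseContext bool
	CreatedAt  time.Time
}

type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// loadContext returns up to `deep` recent messages, oldest first, as chat context.
func loadContext(db *gorm.DB, chatId string, deep int) []Message {
	var ctx []Message
	if deep <= 0 {
		return ctx // depth 0: no history, role context only
	}
	var history []HistoryMessage
	err := db.Where("chat_id = ? and use_context = 1", chatId).
		Limit(deep).
		Order("created_at desc").
		Find(&history).Error
	if err != nil {
		return ctx
	}
	// The query returns newest-first; append in reverse for chronological order.
	for i := len(history) - 1; i >= 0; i-- {
		role := "user"
		if history[i].Type == "reply" { // assumed value of types.ReplyMsg
			role = "assistant"
		}
		ctx = append(ctx, Message{Role: role, Content: history[i].Content})
	}
	return ctx
}
```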
@@ -324,8 +339,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
 			useMsg := types.Message{Role: "user", Content: prompt}
 
 			// 计算本次对话消耗的总 token 数量
-			req.Messages = append(req.Messages, message)
-			totalTokens := getTotalTokens(req)
+			var totalTokens = 0
+			if functionCall { // 函数名 + 参数 token
+				tokens, _ := utils.CalcTokens(functionName, req.Model)
+				totalTokens += tokens
+				tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
+				totalTokens += tokens
+			} else {
+				req.Messages = append(req.Messages, message)
+				totalTokens += getTotalTokens(req)
+			}
 			replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
 
 			// 更新上下文消息,如果是调用函数则不需要更新上下文
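The token accounting now distinguishes two cases: for a function call it counts the tokens of the function name plus the serialized arguments via utils.CalcTokens, otherwise it appends the assistant reply to req.Messages and lets getTotalTokens count the whole request. Those helpers belong to the project and are not shown in this diff; as a rough illustration of what a CalcTokens-style helper can look like, here is a sketch built on the third-party tiktoken-go library (an assumption about tooling, not necessarily what geekai uses):

```go
// Rough sketch of counting tokens for a piece of text against a model's encoding.
// Uses github.com/pkoukk/tiktoken-go; this is NOT the project's utils.CalcTokens.
package tokens

import (
	"github.com/pkoukk/tiktoken-go"
)

// CalcTokens returns the number of tokens `text` occupies under `model`'s encoding.
func CalcTokens(text string, model string) (int, error) {
	enc, err := tiktoken.EncodingForModel(model)
	if err != nil {
		// Fall back to a common base encoding when the model name is unknown.
		enc, err = tiktoken.GetEncoding("cl100k_base")
		if err != nil {
			return 0, err
		}
	}
	return len(enc.Encode(text, nil, nil)), nil
}
```

A per-round total such as getTotalTokens would then sum a helper like this over every message's role and content in the request, typically plus the small fixed per-message overhead described in OpenAI's token-counting guidance.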
@@ -65,6 +65,14 @@
           <el-switch v-model="chat['enable_history']"/>
         </el-form-item>
 
+        <el-alert type="info" show-icon :closable="false">
+          <p>会话上下文深度:在老会话中继续会话,默认加载多少条聊天记录作为上下文。如果设置为 0
+            则不加载聊天记录,仅仅使用当前角色的上下文。该配置参数最好设置为 2 的整数倍。</p>
+        </el-alert>
+        <el-form-item label="会话上下文深度">
+          <el-input-number v-model="chat['context_deep']" :min="0" :max="10"/>
+        </el-form-item>
+
         <el-form-item style="text-align: right">
           <el-button type="primary" @click="save('chat')">保存</el-button>
         </el-form-item>
@@ -178,6 +186,10 @@ const addModel = function () {
             font-size 16px;
           }
         }
+
+        .tip-text {
+          padding-left 10px;
+        }
       }
     }
 