mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	refactor chat message body struct
This commit is contained in:
		@@ -3,6 +3,8 @@
 | 
				
			|||||||
* 功能优化:用户文件列表组件增加分页功能支持
 | 
					* 功能优化:用户文件列表组件增加分页功能支持
 | 
				
			||||||
* Bug修复:修复用户注册失败Bug,注册操作只弹出一次行为验证码
 | 
					* Bug修复:修复用户注册失败Bug,注册操作只弹出一次行为验证码
 | 
				
			||||||
* 功能优化:首次登录不需要验证码,直接登录,登录失败之后才弹出验证码
 | 
					* 功能优化:首次登录不需要验证码,直接登录,登录失败之后才弹出验证码
 | 
				
			||||||
 | 
					* 功能新增:给 AI 应用(角色)增加分类
 | 
				
			||||||
 | 
					* 功能优化:允许用户在聊天页面设置是否使用流式输出或者一次性输出,兼容 GPT-O1 模型。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## v4.1.3
 | 
					## v4.1.3
 | 
				
			||||||
* 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信
 | 
					* 功能优化:重构用户登录模块,给所有的登录组件增加行为验证码功能,支持用户绑定手机,邮箱和微信
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,14 +9,14 @@ package types
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// ApiRequest API 请求实体
 | 
					// ApiRequest API 请求实体
 | 
				
			||||||
type ApiRequest struct {
 | 
					type ApiRequest struct {
 | 
				
			||||||
	Model       string        `json:"model,omitempty"` // 兼容百度文心一言
 | 
						Model               string        `json:"model,omitempty"`
 | 
				
			||||||
	Temperature float32       `json:"temperature"`
 | 
						Temperature         float32       `json:"temperature"`
 | 
				
			||||||
	MaxTokens   int           `json:"max_tokens,omitempty"` // 兼容百度文心一言
 | 
						MaxTokens           int           `json:"max_tokens,omitempty"`
 | 
				
			||||||
	Stream      bool          `json:"stream"`
 | 
						MaxCompletionTokens int           `json:"max_completion_tokens,omitempty"` // 兼容GPT O1 模型
 | 
				
			||||||
	Messages    []interface{} `json:"messages,omitempty"`
 | 
						Stream              bool          `json:"stream,omitempty"`
 | 
				
			||||||
	Prompt      []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
 | 
						Messages            []interface{} `json:"messages,omitempty"`
 | 
				
			||||||
	Tools       []Tool        `json:"tools,omitempty"`
 | 
						Tools               []Tool        `json:"tools,omitempty"`
 | 
				
			||||||
	Functions   []interface{} `json:"functions,omitempty"` // 兼容中转平台
 | 
						Functions           []interface{} `json:"functions,omitempty"` // 兼容中转平台
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ToolChoice string `json:"tool_choice,omitempty"`
 | 
						ToolChoice string `json:"tool_choice,omitempty"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -57,7 +57,8 @@ type ChatSession struct {
 | 
				
			|||||||
	ClientIP  string    `json:"client_ip"` // 客户端 IP
 | 
						ClientIP  string    `json:"client_ip"` // 客户端 IP
 | 
				
			||||||
	ChatId    string    `json:"chat_id"`   // 客户端聊天会话 ID, 多会话模式专用字段
 | 
						ChatId    string    `json:"chat_id"`   // 客户端聊天会话 ID, 多会话模式专用字段
 | 
				
			||||||
	Model     ChatModel `json:"model"`     // GPT 模型
 | 
						Model     ChatModel `json:"model"`     // GPT 模型
 | 
				
			||||||
	Tools     string    `json:"tools"`     // 函数
 | 
						Tools     []int     `json:"tools"`     // 工具函数列表
 | 
				
			||||||
 | 
						Stream    bool      `json:"stream"`    // 是否采用流式输出
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ChatModel struct {
 | 
					type ChatModel struct {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -17,8 +17,8 @@ type BizVo struct {
 | 
				
			|||||||
	Data     interface{} `json:"data,omitempty"`
 | 
						Data     interface{} `json:"data,omitempty"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// WsMessage Websocket message
 | 
					// ReplyMessage 对话回复消息结构
 | 
				
			||||||
type WsMessage struct {
 | 
					type ReplyMessage struct {
 | 
				
			||||||
	Type    WsMsgType   `json:"type"` // 消息类别,start, end, img
 | 
						Type    WsMsgType   `json:"type"` // 消息类别,start, end, img
 | 
				
			||||||
	Content interface{} `json:"content"`
 | 
						Content interface{} `json:"content"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -32,6 +32,13 @@ const (
 | 
				
			|||||||
	WsErr    = WsMsgType("error")
 | 
						WsErr    = WsMsgType("error")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// InputMessage 对话输入消息结构
 | 
				
			||||||
 | 
					type InputMessage struct {
 | 
				
			||||||
 | 
						Content string `json:"content"`
 | 
				
			||||||
 | 
						Tools   []int  `json:"tools"`  // 允许调用工具列表
 | 
				
			||||||
 | 
						Stream  bool   `json:"stream"` // 是否采用流式输出
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type BizCode int
 | 
					type BizCode int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -73,13 +73,12 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
				
			|||||||
	roleId := h.GetInt(c, "role_id", 0)
 | 
						roleId := h.GetInt(c, "role_id", 0)
 | 
				
			||||||
	chatId := c.Query("chat_id")
 | 
						chatId := c.Query("chat_id")
 | 
				
			||||||
	modelId := h.GetInt(c, "model_id", 0)
 | 
						modelId := h.GetInt(c, "model_id", 0)
 | 
				
			||||||
	tools := c.Query("tools")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	client := types.NewWsClient(ws)
 | 
						client := types.NewWsClient(ws)
 | 
				
			||||||
	var chatRole model.ChatRole
 | 
						var chatRole model.ChatRole
 | 
				
			||||||
	res := h.DB.First(&chatRole, roleId)
 | 
						res := h.DB.First(&chatRole, roleId)
 | 
				
			||||||
	if res.Error != nil || !chatRole.Enable {
 | 
						if res.Error != nil || !chatRole.Enable {
 | 
				
			||||||
		utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
 | 
							utils.ReplyErrorMessage(client, "当前聊天角色不存在或者未启用,对话已关闭!!!")
 | 
				
			||||||
		c.Abort()
 | 
							c.Abort()
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -91,7 +90,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
				
			|||||||
	var chatModel model.ChatModel
 | 
						var chatModel model.ChatModel
 | 
				
			||||||
	res = h.DB.First(&chatModel, modelId)
 | 
						res = h.DB.First(&chatModel, modelId)
 | 
				
			||||||
	if res.Error != nil || chatModel.Enabled == false {
 | 
						if res.Error != nil || chatModel.Enabled == false {
 | 
				
			||||||
		utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
 | 
							utils.ReplyErrorMessage(client, "当前AI模型暂未启用,对话已关闭!!!")
 | 
				
			||||||
		c.Abort()
 | 
							c.Abort()
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -100,7 +99,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
				
			|||||||
		SessionId: sessionId,
 | 
							SessionId: sessionId,
 | 
				
			||||||
		ClientIP:  c.ClientIP(),
 | 
							ClientIP:  c.ClientIP(),
 | 
				
			||||||
		UserId:    h.GetLoginUserId(c),
 | 
							UserId:    h.GetLoginUserId(c),
 | 
				
			||||||
		Tools:     tools,
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// use old chat data override the chat model and role ID
 | 
						// use old chat data override the chat model and role ID
 | 
				
			||||||
@@ -137,20 +135,16 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
				
			|||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			var message types.WsMessage
 | 
								var message types.InputMessage
 | 
				
			||||||
			err = utils.JsonDecode(string(msg), &message)
 | 
								err = utils.JsonDecode(string(msg), &message)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// 心跳消息
 | 
								logger.Infof("Receive a message:%+v", message)
 | 
				
			||||||
			if message.Type == "heartbeat" {
 | 
					 | 
				
			||||||
				logger.Debug("收到 Chat 心跳消息:", message.Content)
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			logger.Info("Receive a message: ", message.Content)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								session.Tools = message.Tools
 | 
				
			||||||
 | 
								session.Stream = message.Stream
 | 
				
			||||||
			ctx, cancel := context.WithCancel(context.Background())
 | 
								ctx, cancel := context.WithCancel(context.Background())
 | 
				
			||||||
			h.ReqCancelFunc.Put(sessionId, cancel)
 | 
								h.ReqCancelFunc.Put(sessionId, cancel)
 | 
				
			||||||
			// 回复消息
 | 
								// 回复消息
 | 
				
			||||||
@@ -159,7 +153,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
				
			|||||||
				logger.Error(err)
 | 
									logger.Error(err)
 | 
				
			||||||
				utils.ReplyMessage(client, err.Error())
 | 
									utils.ReplyMessage(client, err.Error())
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
 | 
									utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
 | 
				
			||||||
				logger.Infof("回答完毕: %v", message.Content)
 | 
									logger.Infof("回答完毕: %v", message.Content)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -208,16 +202,21 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var req = types.ApiRequest{
 | 
						var req = types.ApiRequest{
 | 
				
			||||||
		Model:  session.Model.Value,
 | 
							Model:       session.Model.Value,
 | 
				
			||||||
		Stream: true,
 | 
							Temperature: session.Model.Temperature,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// 兼容 GPT-O1 模型
 | 
				
			||||||
 | 
						if strings.HasPrefix(session.Model.Value, "o1-") {
 | 
				
			||||||
 | 
							req.MaxCompletionTokens = session.Model.MaxTokens
 | 
				
			||||||
 | 
							req.Stream = false
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							req.MaxTokens = session.Model.MaxTokens
 | 
				
			||||||
 | 
							req.Stream = session.Stream
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req.Temperature = session.Model.Temperature
 | 
					 | 
				
			||||||
	req.MaxTokens = session.Model.MaxTokens
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if session.Tools != "" {
 | 
						if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
 | 
				
			||||||
		toolIds := strings.Split(session.Tools, ",")
 | 
					 | 
				
			||||||
		var items []model.Function
 | 
							var items []model.Function
 | 
				
			||||||
		res = h.DB.Where("enabled", true).Where("id IN ?", toolIds).Find(&items)
 | 
							res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
 | 
				
			||||||
		if res.Error == nil {
 | 
							if res.Error == nil {
 | 
				
			||||||
			var tools = make([]types.Tool, 0)
 | 
								var tools = make([]types.Tool, 0)
 | 
				
			||||||
			for _, v := range items {
 | 
								for _, v := range items {
 | 
				
			||||||
@@ -279,7 +278,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		for i := len(messages) - 1; i >= 0; i-- {
 | 
							for i := len(messages) - 1; i >= 0; i-- {
 | 
				
			||||||
			v := messages[i]
 | 
								v := messages[i]
 | 
				
			||||||
			tks, _ := utils.CalcTokens(v.Content, req.Model)
 | 
								tks, _ = utils.CalcTokens(v.Content, req.Model)
 | 
				
			||||||
			// 上下文 token 超出了模型的最大上下文长度
 | 
								// 上下文 token 超出了模型的最大上下文长度
 | 
				
			||||||
			if tokens+tks >= session.Model.MaxContext {
 | 
								if tokens+tks >= session.Model.MaxContext {
 | 
				
			||||||
				break
 | 
									break
 | 
				
			||||||
@@ -500,10 +499,17 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Usage struct {
 | 
				
			||||||
 | 
						Prompt           string
 | 
				
			||||||
 | 
						Content          string
 | 
				
			||||||
 | 
						PromptTokens     int
 | 
				
			||||||
 | 
						CompletionTokens int
 | 
				
			||||||
 | 
						TotalTokens      int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *ChatHandler) saveChatHistory(
 | 
					func (h *ChatHandler) saveChatHistory(
 | 
				
			||||||
	req types.ApiRequest,
 | 
						req types.ApiRequest,
 | 
				
			||||||
	prompt string,
 | 
						usage Usage,
 | 
				
			||||||
	contents []string,
 | 
					 | 
				
			||||||
	message types.Message,
 | 
						message types.Message,
 | 
				
			||||||
	chatCtx []types.Message,
 | 
						chatCtx []types.Message,
 | 
				
			||||||
	session *types.ChatSession,
 | 
						session *types.ChatSession,
 | 
				
			||||||
@@ -514,8 +520,8 @@ func (h *ChatHandler) saveChatHistory(
 | 
				
			|||||||
	if message.Role == "" {
 | 
						if message.Role == "" {
 | 
				
			||||||
		message.Role = "assistant"
 | 
							message.Role = "assistant"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	message.Content = strings.Join(contents, "")
 | 
						message.Content = usage.Content
 | 
				
			||||||
	useMsg := types.Message{Role: "user", Content: prompt}
 | 
						useMsg := types.Message{Role: "user", Content: usage.Prompt}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
						// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
				
			||||||
	if h.App.SysConfig.EnableContext {
 | 
						if h.App.SysConfig.EnableContext {
 | 
				
			||||||
@@ -526,42 +532,52 @@ func (h *ChatHandler) saveChatHistory(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// 追加聊天记录
 | 
						// 追加聊天记录
 | 
				
			||||||
	// for prompt
 | 
						// for prompt
 | 
				
			||||||
	promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
						var promptTokens, replyTokens, totalTokens int
 | 
				
			||||||
	if err != nil {
 | 
						if usage.PromptTokens > 0 {
 | 
				
			||||||
		logger.Error(err)
 | 
							promptTokens = usage.PromptTokens
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							promptTokens, _ = utils.CalcTokens(usage.Content, req.Model)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	historyUserMsg := model.ChatMessage{
 | 
						historyUserMsg := model.ChatMessage{
 | 
				
			||||||
		UserId:     userVo.Id,
 | 
							UserId:      userVo.Id,
 | 
				
			||||||
		ChatId:     session.ChatId,
 | 
							ChatId:      session.ChatId,
 | 
				
			||||||
		RoleId:     role.Id,
 | 
							RoleId:      role.Id,
 | 
				
			||||||
		Type:       types.PromptMsg,
 | 
							Type:        types.PromptMsg,
 | 
				
			||||||
		Icon:       userVo.Avatar,
 | 
							Icon:        userVo.Avatar,
 | 
				
			||||||
		Content:    template.HTMLEscapeString(prompt),
 | 
							Content:     template.HTMLEscapeString(usage.Prompt),
 | 
				
			||||||
		Tokens:     promptToken,
 | 
							Tokens:      promptTokens,
 | 
				
			||||||
		UseContext: true,
 | 
							TotalTokens: promptTokens,
 | 
				
			||||||
		Model:      req.Model,
 | 
							UseContext:  true,
 | 
				
			||||||
 | 
							Model:       req.Model,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	historyUserMsg.CreatedAt = promptCreatedAt
 | 
						historyUserMsg.CreatedAt = promptCreatedAt
 | 
				
			||||||
	historyUserMsg.UpdatedAt = promptCreatedAt
 | 
						historyUserMsg.UpdatedAt = promptCreatedAt
 | 
				
			||||||
	err = h.DB.Save(&historyUserMsg).Error
 | 
						err := h.DB.Save(&historyUserMsg).Error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Error("failed to save prompt history message: ", err)
 | 
							logger.Error("failed to save prompt history message: ", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// for reply
 | 
						// for reply
 | 
				
			||||||
	// 计算本次对话消耗的总 token 数量
 | 
						// 计算本次对话消耗的总 token 数量
 | 
				
			||||||
	replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
						if usage.CompletionTokens > 0 {
 | 
				
			||||||
	totalTokens := replyTokens + getTotalTokens(req)
 | 
							replyTokens = usage.CompletionTokens
 | 
				
			||||||
 | 
							totalTokens = usage.TotalTokens
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
 | 
				
			||||||
 | 
							totalTokens = replyTokens + getTotalTokens(req)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	historyReplyMsg := model.ChatMessage{
 | 
						historyReplyMsg := model.ChatMessage{
 | 
				
			||||||
		UserId:     userVo.Id,
 | 
							UserId:      userVo.Id,
 | 
				
			||||||
		ChatId:     session.ChatId,
 | 
							ChatId:      session.ChatId,
 | 
				
			||||||
		RoleId:     role.Id,
 | 
							RoleId:      role.Id,
 | 
				
			||||||
		Type:       types.ReplyMsg,
 | 
							Type:        types.ReplyMsg,
 | 
				
			||||||
		Icon:       role.Icon,
 | 
							Icon:        role.Icon,
 | 
				
			||||||
		Content:    message.Content,
 | 
							Content:     message.Content,
 | 
				
			||||||
		Tokens:     totalTokens,
 | 
							Tokens:      replyTokens,
 | 
				
			||||||
		UseContext: true,
 | 
							TotalTokens: totalTokens,
 | 
				
			||||||
		Model:      req.Model,
 | 
							UseContext:  true,
 | 
				
			||||||
 | 
							Model:       req.Model,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	historyReplyMsg.CreatedAt = replyCreatedAt
 | 
						historyReplyMsg.CreatedAt = replyCreatedAt
 | 
				
			||||||
	historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
						historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
				
			||||||
@@ -572,7 +588,7 @@ func (h *ChatHandler) saveChatHistory(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// 更新用户算力
 | 
						// 更新用户算力
 | 
				
			||||||
	if session.Model.Power > 0 {
 | 
						if session.Model.Power > 0 {
 | 
				
			||||||
		h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
							h.subUserPower(userVo, session, promptTokens, replyTokens)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// 保存当前会话
 | 
						// 保存当前会话
 | 
				
			||||||
	var chatItem model.ChatItem
 | 
						var chatItem model.ChatItem
 | 
				
			||||||
@@ -582,10 +598,10 @@ func (h *ChatHandler) saveChatHistory(
 | 
				
			|||||||
		chatItem.UserId = userVo.Id
 | 
							chatItem.UserId = userVo.Id
 | 
				
			||||||
		chatItem.RoleId = role.Id
 | 
							chatItem.RoleId = role.Id
 | 
				
			||||||
		chatItem.ModelId = session.Model.Id
 | 
							chatItem.ModelId = session.Model.Id
 | 
				
			||||||
		if utf8.RuneCountInString(prompt) > 30 {
 | 
							if utf8.RuneCountInString(usage.Prompt) > 30 {
 | 
				
			||||||
			chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
								chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			chatItem.Title = prompt
 | 
								chatItem.Title = usage.Prompt
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		chatItem.Model = req.Model
 | 
							chatItem.Model = req.Model
 | 
				
			||||||
		err = h.DB.Create(&chatItem).Error
 | 
							err = h.DB.Create(&chatItem).Error
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -23,6 +23,28 @@ import (
 | 
				
			|||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type respVo struct {
 | 
				
			||||||
 | 
						Id                string `json:"id"`
 | 
				
			||||||
 | 
						Object            string `json:"object"`
 | 
				
			||||||
 | 
						Created           int    `json:"created"`
 | 
				
			||||||
 | 
						Model             string `json:"model"`
 | 
				
			||||||
 | 
						SystemFingerprint string `json:"system_fingerprint"`
 | 
				
			||||||
 | 
						Choices           []struct {
 | 
				
			||||||
 | 
							Index   int `json:"index"`
 | 
				
			||||||
 | 
							Message struct {
 | 
				
			||||||
 | 
								Role    string `json:"role"`
 | 
				
			||||||
 | 
								Content string `json:"content"`
 | 
				
			||||||
 | 
							} `json:"message"`
 | 
				
			||||||
 | 
							Logprobs     interface{} `json:"logprobs"`
 | 
				
			||||||
 | 
							FinishReason string      `json:"finish_reason"`
 | 
				
			||||||
 | 
						} `json:"choices"`
 | 
				
			||||||
 | 
						Usage struct {
 | 
				
			||||||
 | 
							PromptTokens     int `json:"prompt_tokens"`
 | 
				
			||||||
 | 
							CompletionTokens int `json:"completion_tokens"`
 | 
				
			||||||
 | 
							TotalTokens      int `json:"total_tokens"`
 | 
				
			||||||
 | 
						} `json:"usage"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// OPenAI 消息发送实现
 | 
					// OPenAI 消息发送实现
 | 
				
			||||||
func (h *ChatHandler) sendOpenAiMessage(
 | 
					func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			||||||
	chatCtx []types.Message,
 | 
						chatCtx []types.Message,
 | 
				
			||||||
@@ -49,6 +71,10 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
		defer response.Body.Close()
 | 
							defer response.Body.Close()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if response.StatusCode != 200 {
 | 
				
			||||||
 | 
							body, _ := io.ReadAll(response.Body)
 | 
				
			||||||
 | 
							return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, body)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	contentType := response.Header.Get("Content-Type")
 | 
						contentType := response.Header.Get("Content-Type")
 | 
				
			||||||
	if strings.Contains(contentType, "text/event-stream") {
 | 
						if strings.Contains(contentType, "text/event-stream") {
 | 
				
			||||||
		replyCreatedAt := time.Now() // 记录回复时间
 | 
							replyCreatedAt := time.Now() // 记录回复时间
 | 
				
			||||||
@@ -106,8 +132,8 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
				if res.Error == nil {
 | 
									if res.Error == nil {
 | 
				
			||||||
					toolCall = true
 | 
										toolCall = true
 | 
				
			||||||
					callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
 | 
										callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
 | 
				
			||||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
										utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
 | 
				
			||||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
 | 
										utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: callMsg})
 | 
				
			||||||
					contents = append(contents, callMsg)
 | 
										contents = append(contents, callMsg)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
@@ -125,10 +151,10 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
				content := responseBody.Choices[0].Delta.Content
 | 
									content := responseBody.Choices[0].Delta.Content
 | 
				
			||||||
				contents = append(contents, utils.InterfaceToString(content))
 | 
									contents = append(contents, utils.InterfaceToString(content))
 | 
				
			||||||
				if isNew {
 | 
									if isNew {
 | 
				
			||||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
										utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
 | 
				
			||||||
					isNew = false
 | 
										isNew = false
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
									utils.ReplyChunkMessage(ws, types.ReplyMessage{
 | 
				
			||||||
					Type:    types.WsMiddle,
 | 
										Type:    types.WsMiddle,
 | 
				
			||||||
					Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
										Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
@@ -161,13 +187,13 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
			if errMsg != "" || apiRes.Code != types.Success {
 | 
								if errMsg != "" || apiRes.Code != types.Success {
 | 
				
			||||||
				msg := "调用函数工具出错:" + apiRes.Message + errMsg
 | 
									msg := "调用函数工具出错:" + apiRes.Message + errMsg
 | 
				
			||||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
									utils.ReplyChunkMessage(ws, types.ReplyMessage{
 | 
				
			||||||
					Type:    types.WsMiddle,
 | 
										Type:    types.WsMiddle,
 | 
				
			||||||
					Content: msg,
 | 
										Content: msg,
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
				contents = append(contents, msg)
 | 
									contents = append(contents, msg)
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
									utils.ReplyChunkMessage(ws, types.ReplyMessage{
 | 
				
			||||||
					Type:    types.WsMiddle,
 | 
										Type:    types.WsMiddle,
 | 
				
			||||||
					Content: apiRes.Data,
 | 
										Content: apiRes.Data,
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
@@ -177,11 +203,17 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		// 消息发送成功
 | 
							// 消息发送成功
 | 
				
			||||||
		if len(contents) > 0 {
 | 
							if len(contents) > 0 {
 | 
				
			||||||
			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
								usage := Usage{
 | 
				
			||||||
 | 
									Prompt:           prompt,
 | 
				
			||||||
 | 
									Content:          strings.Join(contents, ""),
 | 
				
			||||||
 | 
									PromptTokens:     0,
 | 
				
			||||||
 | 
									CompletionTokens: 0,
 | 
				
			||||||
 | 
									TotalTokens:      0,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else {
 | 
						} else { // 非流式输出
 | 
				
			||||||
		body, _ := io.ReadAll(response.Body)
 | 
							
 | 
				
			||||||
		return fmt.Errorf("请求 OpenAI API 失败:%s", body)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -73,7 +73,7 @@ func (h *DallJobHandler) Client(c *gin.Context) {
 | 
				
			|||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			var message types.WsMessage
 | 
								var message types.ReplyMessage
 | 
				
			||||||
			err = utils.JsonDecode(string(msg), &message)
 | 
								err = utils.JsonDecode(string(msg), &message)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -64,7 +64,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
 | 
				
			|||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			var message types.WsMessage
 | 
								var message types.ReplyMessage
 | 
				
			||||||
			err = utils.JsonDecode(string(msg), &message)
 | 
								err = utils.JsonDecode(string(msg), &message)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
@@ -85,7 +85,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
 | 
				
			|||||||
			err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
 | 
								err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				logger.Error(err)
 | 
									logger.Error(err)
 | 
				
			||||||
				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
 | 
									utils.ReplyErrorMessage(client, err.Error())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -170,16 +170,16 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if isNew {
 | 
								if isNew {
 | 
				
			||||||
				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
 | 
									utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsStart})
 | 
				
			||||||
				isNew = false
 | 
									isNew = false
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			utils.ReplyChunkMessage(client, types.WsMessage{
 | 
								utils.ReplyChunkMessage(client, types.ReplyMessage{
 | 
				
			||||||
				Type:    types.WsMiddle,
 | 
									Type:    types.WsMiddle,
 | 
				
			||||||
				Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
									Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
		} // end for
 | 
							} // end for
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
 | 
							utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		body, _ := io.ReadAll(response.Body)
 | 
							body, _ := io.ReadAll(response.Body)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,16 +4,17 @@ import "gorm.io/gorm"
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
type ChatMessage struct {
 | 
					type ChatMessage struct {
 | 
				
			||||||
	BaseModel
 | 
						BaseModel
 | 
				
			||||||
	ChatId     string // 会话 ID
 | 
						ChatId      string // 会话 ID
 | 
				
			||||||
	UserId     uint   // 用户 ID
 | 
						UserId      uint   // 用户 ID
 | 
				
			||||||
	RoleId     uint   // 角色 ID
 | 
						RoleId      uint   // 角色 ID
 | 
				
			||||||
	Model      string // AI模型
 | 
						Model       string // AI模型
 | 
				
			||||||
	Type       string
 | 
						Type        string
 | 
				
			||||||
	Icon       string
 | 
						Icon        string
 | 
				
			||||||
	Tokens     int
 | 
						Tokens      int
 | 
				
			||||||
	Content    string
 | 
						TotalTokens int // 总 token 消耗
 | 
				
			||||||
	UseContext bool // 是否可以作为聊天上下文
 | 
						Content     string
 | 
				
			||||||
	DeletedAt  gorm.DeletedAt
 | 
						UseContext  bool // 是否可以作为聊天上下文
 | 
				
			||||||
 | 
						DeletedAt   gorm.DeletedAt
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (ChatMessage) TableName() string {
 | 
					func (ChatMessage) TableName() string {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,7 +5,7 @@ import "geekai/core/types"
 | 
				
			|||||||
type ChatRole struct {
 | 
					type ChatRole struct {
 | 
				
			||||||
	BaseVo
 | 
						BaseVo
 | 
				
			||||||
	Key       string          `json:"key"` // 角色唯一标识
 | 
						Key       string          `json:"key"` // 角色唯一标识
 | 
				
			||||||
	Tid       uint            `json:"tid"`
 | 
						Tid       int             `json:"tid"`
 | 
				
			||||||
	Name      string          `json:"name"`       // 角色名称
 | 
						Name      string          `json:"name"`       // 角色名称
 | 
				
			||||||
	Context   []types.Message `json:"context"`    // 角色语料信息
 | 
						Context   []types.Message `json:"context"`    // 角色语料信息
 | 
				
			||||||
	HelloMsg  string          `json:"hello_msg"`  // 打招呼的消息
 | 
						HelloMsg  string          `json:"hello_msg"`  // 打招呼的消息
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -33,9 +33,14 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// ReplyMessage 回复客户端一条完整的消息
 | 
					// ReplyMessage 回复客户端一条完整的消息
 | 
				
			||||||
func ReplyMessage(ws *types.WsClient, message interface{}) {
 | 
					func ReplyMessage(ws *types.WsClient, message interface{}) {
 | 
				
			||||||
	ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
						ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
 | 
				
			||||||
	ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
 | 
						ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: message})
 | 
				
			||||||
	ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
 | 
						ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ReplyErrorMessage 向客户端发送错误消息
 | 
				
			||||||
 | 
					func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
 | 
				
			||||||
 | 
						ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func DownloadImage(imageURL string, proxy string) ([]byte, error) {
 | 
					func DownloadImage(imageURL string, proxy string) ([]byte, error) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,4 +9,6 @@ CREATE TABLE `chatgpt_app_types` (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
ALTER TABLE `chatgpt_app_types`ADD PRIMARY KEY (`id`);
 | 
					ALTER TABLE `chatgpt_app_types`ADD PRIMARY KEY (`id`);
 | 
				
			||||||
ALTER TABLE `chatgpt_app_types` MODIFY `id` int NOT NULL AUTO_INCREMENT;
 | 
					ALTER TABLE `chatgpt_app_types` MODIFY `id` int NOT NULL AUTO_INCREMENT;
 | 
				
			||||||
ALTER TABLE `chatgpt_chat_roles` ADD `tid` INT NOT NULL COMMENT '分类ID' AFTER `name`;
 | 
					ALTER TABLE `chatgpt_chat_roles` ADD `tid` INT NOT NULL COMMENT '分类ID' AFTER `name`;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ALTER TABLE `chatgpt_chat_history` ADD `total_tokens` INT NOT NULL COMMENT '消耗总Token长度' AFTER `tokens`;
 | 
				
			||||||
		Reference in New Issue
	
	Block a user