	refactor chat message body struct
@@ -3,6 +3,8 @@
 * Enhancement: add pagination support to the user file list component
 * Bugfix: fix user registration failures; the behavior captcha now pops up only once during registration
 * Enhancement: skip the captcha on the first login attempt and sign in directly; show the captcha only after a failed login
+* New feature: add categories for AI applications (roles)
+* Enhancement: let users choose streamed or one-shot output on the chat page, for compatibility with the GPT-O1 models.
 
 ## v4.1.3
 
 * Enhancement: refactor the user login module, add behavior captcha to all login components, and support binding phone, email, and WeChat

@@ -9,12 +9,12 @@ package types
 
 // ApiRequest API request entity
 type ApiRequest struct {
-	Model       string        `json:"model,omitempty"` // for Baidu ERNIE Bot compatibility
-	MaxTokens   int           `json:"max_tokens,omitempty"` // for Baidu ERNIE Bot compatibility
-	Stream      bool          `json:"stream"`
-	Prompt      []interface{} `json:"prompt,omitempty"` // for ChatGLM compatibility
+	Model               string        `json:"model,omitempty"`
+	Temperature         float32       `json:"temperature"`
+	MaxTokens           int           `json:"max_tokens,omitempty"`
+	MaxCompletionTokens int           `json:"max_completion_tokens,omitempty"` // for GPT O1 model compatibility
+	Stream              bool          `json:"stream,omitempty"`
+	Messages            []interface{} `json:"messages,omitempty"`
+	Tools               []Tool        `json:"tools,omitempty"`
+	Functions           []interface{} `json:"functions,omitempty"` // for API relay platform compatibility

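The new fields make the request shape model-dependent: the handler change further down picks `max_completion_tokens` and disables streaming for o1-series models, and `max_tokens` otherwise. A minimal, self-contained sketch of that branching, with the struct trimmed to the fields involved and illustrative values:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Trimmed stand-in for geekai's types.ApiRequest.
type ApiRequest struct {
	Model               string `json:"model,omitempty"`
	MaxTokens           int    `json:"max_tokens,omitempty"`
	MaxCompletionTokens int    `json:"max_completion_tokens,omitempty"`
	Stream              bool   `json:"stream,omitempty"`
}

// buildRequest mirrors the branching added to sendMessage below.
func buildRequest(model string, maxTokens int, stream bool) ApiRequest {
	req := ApiRequest{Model: model}
	if strings.HasPrefix(model, "o1-") {
		// o1-series: use max_completion_tokens and force one-shot output
		req.MaxCompletionTokens = maxTokens
		req.Stream = false
	} else {
		req.MaxTokens = maxTokens
		req.Stream = stream
	}
	return req
}

func main() {
	b, _ := json.Marshal(buildRequest("o1-mini", 1024, true))
	fmt.Println(string(b)) // {"model":"o1-mini","max_completion_tokens":1024}
}
```
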
@@ -57,7 +57,8 @@ type ChatSession struct {
 	ClientIP  string    `json:"client_ip"` // client IP
 	ChatId    string    `json:"chat_id"`   // client chat session ID, used only in multi-session mode
 	Model     ChatModel `json:"model"`     // GPT model
-	Tools     string    `json:"tools"`     // functions
+	Tools     []int     `json:"tools"`     // list of tool function IDs
+	Stream    bool      `json:"stream"`    // whether to use streamed output
 }
 
 type ChatModel struct {

@@ -17,8 +17,8 @@ type BizVo struct {
 	Data     interface{} `json:"data,omitempty"`
 }
 
-// WsMessage Websocket message
-type WsMessage struct {
+// ReplyMessage chat reply message struct
+type ReplyMessage struct {
 	Type    WsMsgType   `json:"type"` // message type: start, end, img
 	Content interface{} `json:"content"`
 }

@@ -32,6 +32,13 @@ const (
 	WsErr    = WsMsgType("error")
 )
 
+// InputMessage chat input message struct
+type InputMessage struct {
+	Content string `json:"content"`
+	Tools   []int  `json:"tools"`  // list of tools the model is allowed to call
+	Stream  bool   `json:"stream"` // whether to use streamed output
+}
+
 type BizCode int
 
 const (

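With the split, clients now send one `InputMessage` per question, carrying the tool list and the stream flag, while the server answers with `ReplyMessage` frames. A sketch of the input payload, using a local copy of the struct and illustrative values:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copy of the InputMessage shape added in this commit.
type InputMessage struct {
	Content string `json:"content"`
	Tools   []int  `json:"tools"`
	Stream  bool   `json:"stream"`
}

func main() {
	// Tools and stream now travel with each message instead of the
	// connection query string.
	b, _ := json.Marshal(InputMessage{
		Content: "Summarize this document",
		Tools:   []int{1, 3}, // IDs of enabled tool functions
		Stream:  false,       // one-shot output, e.g. for o1 models
	})
	fmt.Println(string(b))
	// {"content":"Summarize this document","tools":[1,3],"stream":false}
}
```
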
@@ -73,13 +73,12 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 	roleId := h.GetInt(c, "role_id", 0)
 	chatId := c.Query("chat_id")
 	modelId := h.GetInt(c, "model_id", 0)
-	tools := c.Query("tools")
 
 	client := types.NewWsClient(ws)
 	var chatRole model.ChatRole
 	res := h.DB.First(&chatRole, roleId)
 	if res.Error != nil || !chatRole.Enable {
-		utils.ReplyMessage(client, "The current chat role does not exist or is not enabled; the connection has been closed!!!")
+		utils.ReplyErrorMessage(client, "The current chat role does not exist or is not enabled; the conversation has been closed!!!")
 		c.Abort()
 		return
 	}

@@ -91,7 +90,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 	var chatModel model.ChatModel
 	res = h.DB.First(&chatModel, modelId)
 	if res.Error != nil || chatModel.Enabled == false {
-		utils.ReplyMessage(client, "The current AI model is not yet enabled; the connection has been closed!!!")
+		utils.ReplyErrorMessage(client, "The current AI model is not yet enabled; the conversation has been closed!!!")
 		c.Abort()
 		return
 	}

@@ -100,7 +99,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 		SessionId: sessionId,
 		ClientIP:  c.ClientIP(),
 		UserId:    h.GetLoginUserId(c),
-		Tools:     tools,
 	}
 
 	// use old chat data to override the chat model and role ID

@@ -137,20 +135,16 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 				return
 			}
 
-			var message types.WsMessage
+			var message types.InputMessage
 			err = utils.JsonDecode(string(msg), &message)
 			if err != nil {
 				continue
 			}
 
-			// heartbeat message
-			if message.Type == "heartbeat" {
-				logger.Debug("received Chat heartbeat message: ", message.Content)
-				continue
-			}
-
-			logger.Info("Receive a message: ", message.Content)
+			logger.Infof("Receive a message:%+v", message)
 
+			session.Tools = message.Tools
+			session.Stream = message.Stream
 			ctx, cancel := context.WithCancel(context.Background())
 			h.ReqCancelFunc.Put(sessionId, cancel)
 			// reply message

@@ -159,7 +153,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 				logger.Error(err)
 				utils.ReplyMessage(client, err.Error())
 			} else {
-				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
+				utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
 				logger.Infof("reply finished: %v", message.Content)
 			}

@@ -209,15 +203,20 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 
 	var req = types.ApiRequest{
 		Model:       session.Model.Value,
-		Stream: true,
+		Temperature: session.Model.Temperature,
 	}
-	req.Temperature = session.Model.Temperature
+	// GPT-O1 model compatibility
+	if strings.HasPrefix(session.Model.Value, "o1-") {
+		req.MaxCompletionTokens = session.Model.MaxTokens
+		req.Stream = false
+	} else {
+		req.MaxTokens = session.Model.MaxTokens
+		req.Stream = session.Stream
+	}
 
-	if session.Tools != "" {
-		toolIds := strings.Split(session.Tools, ",")
+	if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
 		var items []model.Function
-		res = h.DB.Where("enabled", true).Where("id IN ?", toolIds).Find(&items)
+		res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
 		if res.Error == nil {
 			var tools = make([]types.Tool, 0)
 			for _, v := range items {

@@ -279,7 +278,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 
 		for i := len(messages) - 1; i >= 0; i-- {
 			v := messages[i]
-			tks, _ := utils.CalcTokens(v.Content, req.Model)
+			tks, _ = utils.CalcTokens(v.Content, req.Model)
 			// context tokens exceed the model's maximum context length
 			if tokens+tks >= session.Model.MaxContext {
 				break

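The `tks, _ =` change above is a redeclaration fix: the loop now assigns to a variable declared earlier instead of shadowing it each iteration. The truncation rule itself is unchanged: walk the history from newest to oldest and stop once the running token count would reach the model's max context. A self-contained sketch, with precomputed counts standing in for `utils.CalcTokens`:

```go
package main

import "fmt"

type msg struct {
	Content string
	Tokens  int // stand-in for utils.CalcTokens(content, model)
}

// trimContext keeps the newest messages whose combined token count stays
// below maxContext, mirroring the backward loop in sendMessage.
func trimContext(history []msg, used, maxContext int) []msg {
	var keep []msg
	for i := len(history) - 1; i >= 0; i-- {
		if used+history[i].Tokens >= maxContext {
			break
		}
		used += history[i].Tokens
		keep = append([]msg{history[i]}, keep...) // prepend to preserve order
	}
	return keep
}

func main() {
	h := []msg{{"q1", 300}, {"a1", 500}, {"q2", 200}, {"a2", 400}}
	fmt.Println(trimContext(h, 100, 800)) // [{q2 200} {a2 400}]
}
```
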
@@ -500,10 +499,17 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
 	}
 }
 
+type Usage struct {
+	Prompt           string
+	Content          string
+	PromptTokens     int
+	CompletionTokens int
+	TotalTokens      int
+}
+
 func (h *ChatHandler) saveChatHistory(
 	req types.ApiRequest,
-	prompt string,
-	contents []string,
+	usage Usage,
 	message types.Message,
 	chatCtx []types.Message,
 	session *types.ChatSession,

@@ -514,8 +520,8 @@ func (h *ChatHandler) saveChatHistory(
 	if message.Role == "" {
 		message.Role = "assistant"
 	}
-	message.Content = strings.Join(contents, "")
-	useMsg := types.Message{Role: "user", Content: prompt}
+	message.Content = usage.Content
+	useMsg := types.Message{Role: "user", Content: usage.Prompt}
 
 	// update the context messages; no context update is needed for function calls
 	if h.App.SysConfig.EnableContext {

@@ -526,32 +532,41 @@ func (h *ChatHandler) saveChatHistory(
 
 	// append the chat history records
-	// for prompt
-	promptToken, err := utils.CalcTokens(prompt, req.Model)
-	if err != nil {
-		logger.Error(err)
+	var promptTokens, replyTokens, totalTokens int
+	if usage.PromptTokens > 0 {
+		promptTokens = usage.PromptTokens
+	} else {
+		promptTokens, _ = utils.CalcTokens(usage.Content, req.Model)
 	}
 
 	historyUserMsg := model.ChatMessage{
 		UserId:      userVo.Id,
 		ChatId:      session.ChatId,
 		RoleId:      role.Id,
 		Type:        types.PromptMsg,
 		Icon:        userVo.Avatar,
-		Content:    template.HTMLEscapeString(prompt),
-		Tokens:     promptToken,
+		Content:     template.HTMLEscapeString(usage.Prompt),
+		Tokens:      promptTokens,
+		TotalTokens: promptTokens,
 		UseContext:  true,
 		Model:       req.Model,
 	}
 	historyUserMsg.CreatedAt = promptCreatedAt
 	historyUserMsg.UpdatedAt = promptCreatedAt
-	err = h.DB.Save(&historyUserMsg).Error
+	err := h.DB.Save(&historyUserMsg).Error
 	if err != nil {
 		logger.Error("failed to save prompt history message: ", err)
 	}
 
-	// for reply
-	replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
-	totalTokens := replyTokens + getTotalTokens(req)
+	// compute the total number of tokens consumed by this conversation
+	if usage.CompletionTokens > 0 {
+		replyTokens = usage.CompletionTokens
+		totalTokens = usage.TotalTokens
+	} else {
+		replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
+		totalTokens = replyTokens + getTotalTokens(req)
+	}
 	historyReplyMsg := model.ChatMessage{
 		UserId:      userVo.Id,
 		ChatId:      session.ChatId,

@@ -559,7 +574,8 @@ func (h *ChatHandler) saveChatHistory(
 		Type:        types.ReplyMsg,
 		Icon:        role.Icon,
 		Content:     message.Content,
-		Tokens:     totalTokens,
+		Tokens:      replyTokens,
+		TotalTokens: totalTokens,
 		UseContext:  true,
 		Model:       req.Model,
 	}

@@ -572,7 +588,7 @@ func (h *ChatHandler) saveChatHistory(
 
 	// deduct the user's power (compute credits)
 	if session.Model.Power > 0 {
-		h.subUserPower(userVo, session, promptToken, replyTokens)
+		h.subUserPower(userVo, session, promptTokens, replyTokens)
 	}
 	// save the current chat session
 	var chatItem model.ChatItem

@@ -582,10 +598,10 @@ func (h *ChatHandler) saveChatHistory(
 		chatItem.UserId = userVo.Id
 		chatItem.RoleId = role.Id
 		chatItem.ModelId = session.Model.Id
-		if utf8.RuneCountInString(prompt) > 30 {
-			chatItem.Title = string([]rune(prompt)[:30]) + "..."
+		if utf8.RuneCountInString(usage.Prompt) > 30 {
+			chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
 		} else {
-			chatItem.Title = prompt
+			chatItem.Title = usage.Prompt
 		}
 		chatItem.Model = req.Model
 		err = h.DB.Create(&chatItem).Error

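Token accounting in `saveChatHistory` now prefers the counts reported by the API and only falls back to local estimation; note that, as committed, the prompt fallback estimates from `usage.Content` rather than `usage.Prompt`. A sketch of the decision, with a crude length-based stand-in for `utils.CalcTokens`:

```go
package main

import "fmt"

// Usage mirrors the struct added in this commit.
type Usage struct {
	Prompt           string
	Content          string
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// estimate is a crude stand-in for utils.CalcTokens.
func estimate(s string) int { return len(s)/4 + 1 }

// tokenCounts prefers API-reported usage and falls back to estimates,
// following the two branches in saveChatHistory.
func tokenCounts(u Usage, contextTokens int) (prompt, reply, total int) {
	if u.PromptTokens > 0 {
		prompt = u.PromptTokens
	} else {
		prompt = estimate(u.Content) // as committed: estimates from Content
	}
	if u.CompletionTokens > 0 {
		reply = u.CompletionTokens
		total = u.TotalTokens
	} else {
		reply = estimate(u.Content)
		total = reply + contextTokens
	}
	return
}

func main() {
	// The streaming branch fills only Prompt/Content, so estimates are used.
	fmt.Println(tokenCounts(Usage{Prompt: "hi", Content: "hello there"}, 42))
}
```
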
@@ -23,6 +23,28 @@ import (
 	"time"
 )
 
+type respVo struct {
+	Id                string `json:"id"`
+	Object            string `json:"object"`
+	Created           int    `json:"created"`
+	Model             string `json:"model"`
+	SystemFingerprint string `json:"system_fingerprint"`
+	Choices           []struct {
+		Index   int `json:"index"`
+		Message struct {
+			Role    string `json:"role"`
+			Content string `json:"content"`
+		} `json:"message"`
+		Logprobs     interface{} `json:"logprobs"`
+		FinishReason string      `json:"finish_reason"`
+	} `json:"choices"`
+	Usage struct {
+		PromptTokens     int `json:"prompt_tokens"`
+		CompletionTokens int `json:"completion_tokens"`
+		TotalTokens      int `json:"total_tokens"`
+	} `json:"usage"`
+}
+
 // OpenAI message sending implementation
 func (h *ChatHandler) sendOpenAiMessage(
 	chatCtx []types.Message,

@@ -49,6 +71,10 @@ func (h *ChatHandler) sendOpenAiMessage(
 		defer response.Body.Close()
 	}
 
+	if response.StatusCode != 200 {
+		body, _ := io.ReadAll(response.Body)
+		return fmt.Errorf("OpenAI API request failed: %d, %v", response.StatusCode, body)
+	}
 	contentType := response.Header.Get("Content-Type")
 	if strings.Contains(contentType, "text/event-stream") {
 		replyCreatedAt := time.Now() // record the reply time

@@ -106,8 +132,8 @@ func (h *ChatHandler) sendOpenAiMessage(
 				if res.Error == nil {
 					toolCall = true
 					callMsg := fmt.Sprintf("Calling tool `%s` to answer ...\n\n", function.Label)
-					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
-					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
+					utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
+					utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: callMsg})
 					contents = append(contents, callMsg)
 				}
 				continue

@@ -125,10 +151,10 @@ func (h *ChatHandler) sendOpenAiMessage(
 				content := responseBody.Choices[0].Delta.Content
 				contents = append(contents, utils.InterfaceToString(content))
 				if isNew {
-					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
+					utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
 					isNew = false
 				}
-				utils.ReplyChunkMessage(ws, types.WsMessage{
+				utils.ReplyChunkMessage(ws, types.ReplyMessage{
 					Type:    types.WsMiddle,
 					Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 				})

@@ -161,13 +187,13 @@ func (h *ChatHandler) sendOpenAiMessage(
 			}
 			if errMsg != "" || apiRes.Code != types.Success {
 				msg := "failed to call the function tool: " + apiRes.Message + errMsg
-				utils.ReplyChunkMessage(ws, types.WsMessage{
+				utils.ReplyChunkMessage(ws, types.ReplyMessage{
 					Type:    types.WsMiddle,
 					Content: msg,
 				})
 				contents = append(contents, msg)
 			} else {
-				utils.ReplyChunkMessage(ws, types.WsMessage{
+				utils.ReplyChunkMessage(ws, types.ReplyMessage{
 					Type:    types.WsMiddle,
 					Content: apiRes.Data,
 				})

@@ -177,11 +203,17 @@ func (h *ChatHandler) sendOpenAiMessage(
 
 		// message sent successfully
 		if len(contents) > 0 {
-			h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
+			usage := Usage{
+				Prompt:           prompt,
+				Content:          strings.Join(contents, ""),
+				PromptTokens:     0,
+				CompletionTokens: 0,
+				TotalTokens:      0,
+			}
+			h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
 		}
-	} else {
-		body, _ := io.ReadAll(response.Body)
-		return fmt.Errorf("OpenAI API request failed: %s", body)
+	} else { // non-streamed output
+
 	}
 
 	return nil

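The non-streamed branch is left as a stub in this commit. A hypothetical sketch of how it might later use the `respVo` shape added above to read the one-shot reply and pass real token counts to `saveChatHistory` (struct trimmed to the fields needed):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copy of the respVo shape added in this commit.
type respVo struct {
	Choices []struct {
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}

func main() {
	body := []byte(`{"choices":[{"message":{"role":"assistant","content":"4"}}],
		"usage":{"prompt_tokens":9,"completion_tokens":1,"total_tokens":10}}`)
	var res respVo
	if err := json.Unmarshal(body, &res); err != nil {
		panic(err)
	}
	// A one-shot response reports exact token usage, so the zero
	// placeholders used in the streaming branch could be replaced with
	// res.Usage values when building the Usage argument.
	fmt.Println(res.Choices[0].Message.Content, res.Usage.TotalTokens)
}
```
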
@@ -73,7 +73,7 @@ func (h *DallJobHandler) Client(c *gin.Context) {
 				return
 			}
 
-			var message types.WsMessage
+			var message types.ReplyMessage
 			err = utils.JsonDecode(string(msg), &message)
 			if err != nil {
 				continue

@@ -64,7 +64,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
 				return
 			}
 
-			var message types.WsMessage
+			var message types.ReplyMessage
 			err = utils.JsonDecode(string(msg), &message)
 			if err != nil {
 				continue

@@ -85,7 +85,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
 			err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
 			if err != nil {
 				logger.Error(err)
-				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
+				utils.ReplyErrorMessage(client, err.Error())
 			}
 
 		}

@@ -170,16 +170,16 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 		}
 
 		if isNew {
-			utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
+			utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsStart})
 			isNew = false
 		}
-		utils.ReplyChunkMessage(client, types.WsMessage{
+		utils.ReplyChunkMessage(client, types.ReplyMessage{
 			Type:    types.WsMiddle,
 			Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 		})
 	} // end for
 
-	utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
+	utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
 
 } else {
 	body, _ := io.ReadAll(response.Body)

@@ -11,6 +11,7 @@ type ChatMessage struct {
 	Type        string
 	Icon        string
 	Tokens      int
+	TotalTokens int // total tokens consumed
 	Content     string
 	UseContext  bool // whether it can be used as chat context
 	DeletedAt   gorm.DeletedAt

@@ -5,7 +5,7 @@ import "geekai/core/types"
 type ChatRole struct {
 	BaseVo
 	Key       string          `json:"key"` // unique role key
-	Tid       uint            `json:"tid"`
+	Tid       int             `json:"tid"`
 	Name      string          `json:"name"`       // role name
 	Context   []types.Message `json:"context"`    // role context (corpus) messages
 	HelloMsg  string          `json:"hello_msg"`  // greeting message

@@ -33,9 +33,14 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) {
 
 // ReplyMessage sends the client one complete message
 func ReplyMessage(ws *types.WsClient, message interface{}) {
-	ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
-	ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
-	ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
+	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
+	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: message})
+	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd})
 }
 
+// ReplyErrorMessage sends an error message to the client
+func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
+	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message})
+}
+
 func DownloadImage(imageURL string, proxy string) ([]byte, error) {

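For reference, `ReplyMessage` still wraps a complete start/middle/end sequence, while the new `ReplyErrorMessage` emits a single error frame; the handlers above now call it instead of hand-building `WsMessage` values. A runnable sketch of the frames produced, with `send` standing in for `utils.ReplyChunkMessage` and plain strings for `WsMsgType`:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copy of the ReplyMessage shape.
type ReplyMessage struct {
	Type    string      `json:"type"`
	Content interface{} `json:"content"`
}

// send is a stand-in for utils.ReplyChunkMessage: one JSON frame per call.
func send(m ReplyMessage) {
	b, _ := json.Marshal(m)
	fmt.Println(string(b))
}

// replyMessage mirrors utils.ReplyMessage: a full start/middle/end sequence.
func replyMessage(content interface{}) {
	send(ReplyMessage{Type: "start"})
	send(ReplyMessage{Type: "middle", Content: content})
	send(ReplyMessage{Type: "end"})
}

// replyErrorMessage mirrors the new utils.ReplyErrorMessage: one error frame.
func replyErrorMessage(content interface{}) {
	send(ReplyMessage{Type: "error", Content: content})
}

func main() {
	replyMessage("done")
	replyErrorMessage("model not enabled")
}
```
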
@@ -10,3 +10,5 @@ CREATE TABLE `chatgpt_app_types` (
 ALTER TABLE `chatgpt_app_types` ADD PRIMARY KEY (`id`);
 ALTER TABLE `chatgpt_app_types` MODIFY `id` int NOT NULL AUTO_INCREMENT;
 ALTER TABLE `chatgpt_chat_roles` ADD `tid` INT NOT NULL COMMENT 'category ID' AFTER `name`;
+
+ALTER TABLE `chatgpt_chat_history` ADD `total_tokens` INT NOT NULL COMMENT 'total tokens consumed' AFTER `tokens`;