mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	refactor: 更改 OpenAI 请求 Body 数据结构,兼容函数调用请求
This commit is contained in:
		@@ -22,8 +22,8 @@ type AppServer struct {
 | 
			
		||||
	Debug        bool
 | 
			
		||||
	Config       *types.AppConfig
 | 
			
		||||
	Engine       *gin.Engine
 | 
			
		||||
	ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
 | 
			
		||||
	ChatConfig   *types.ChatConfig                    // 聊天配置
 | 
			
		||||
	ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
 | 
			
		||||
	ChatConfig   *types.ChatConfig                  // 聊天配置
 | 
			
		||||
 | 
			
		||||
	// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
 | 
			
		||||
	// 防止第三方直接连接 socket 调用 OpenAI API
 | 
			
		||||
@@ -39,7 +39,7 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
 | 
			
		||||
		Debug:         false,
 | 
			
		||||
		Config:        appConfig,
 | 
			
		||||
		Engine:        gin.Default(),
 | 
			
		||||
		ChatContexts:  types.NewLMap[string, []types.Message](),
 | 
			
		||||
		ChatContexts:  types.NewLMap[string, []interface{}](),
 | 
			
		||||
		ChatSession:   types.NewLMap[string, types.ChatSession](),
 | 
			
		||||
		ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
 | 
			
		||||
 
 | 
			
		||||
@@ -2,17 +2,17 @@ package types
 | 
			
		||||
 | 
			
		||||
// ApiRequest API 请求实体
 | 
			
		||||
type ApiRequest struct {
 | 
			
		||||
	Model       string    `json:"model"`
 | 
			
		||||
	Temperature float32   `json:"temperature"`
 | 
			
		||||
	MaxTokens   int       `json:"max_tokens"`
 | 
			
		||||
	Stream      bool      `json:"stream"`
 | 
			
		||||
	Messages    []Message `json:"messages"`
 | 
			
		||||
	Model       string        `json:"model"`
 | 
			
		||||
	Temperature float32       `json:"temperature"`
 | 
			
		||||
	MaxTokens   int           `json:"max_tokens"`
 | 
			
		||||
	Stream      bool          `json:"stream"`
 | 
			
		||||
	Messages    []interface{} `json:"messages"`
 | 
			
		||||
	Functions   []Function    `json:"functions"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
	Role         string       `json:"role"`
 | 
			
		||||
	Content      string       `json:"content"`
 | 
			
		||||
	FunctionCall FunctionCall `json:"function_call"`
 | 
			
		||||
	Role    string `json:"role"`
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ApiResponse struct {
 | 
			
		||||
@@ -21,8 +21,15 @@ type ApiResponse struct {
 | 
			
		||||
 | 
			
		||||
// ChoiceItem API 响应实体
 | 
			
		||||
type ChoiceItem struct {
 | 
			
		||||
	Delta        Message `json:"delta"`
 | 
			
		||||
	FinishReason string  `json:"finish_reason"`
 | 
			
		||||
	Delta        Delta  `json:"delta"`
 | 
			
		||||
	FinishReason string `json:"finish_reason"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Delta struct {
 | 
			
		||||
	Role         string       `json:"role"`
 | 
			
		||||
	Name         string       `json:"name"`
 | 
			
		||||
	Content      interface{}  `json:"content"`
 | 
			
		||||
	FunctionCall FunctionCall `json:"function_call,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ChatSession 聊天会话对象
 | 
			
		||||
 
 | 
			
		||||
@@ -6,13 +6,71 @@ type FunctionCall struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Function struct {
 | 
			
		||||
	Name        string
 | 
			
		||||
	Description string
 | 
			
		||||
	Parameters  []Parameter
 | 
			
		||||
	Name        string     `json:"name"`
 | 
			
		||||
	Description string     `json:"description"`
 | 
			
		||||
	Parameters  Parameters `json:"parameters"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Parameter struct {
 | 
			
		||||
	Type       string
 | 
			
		||||
	Required   []string
 | 
			
		||||
	Properties map[string]interface{}
 | 
			
		||||
type Parameters struct {
 | 
			
		||||
	Type       string              `json:"type"`
 | 
			
		||||
	Required   []string            `json:"required"`
 | 
			
		||||
	Properties map[string]Property `json:"properties"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Property struct {
 | 
			
		||||
	Type        string `json:"type"`
 | 
			
		||||
	Description string `json:"description"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var InnerFunctions = []Function{
 | 
			
		||||
	{
 | 
			
		||||
		Name:        "zao_bao",
 | 
			
		||||
		Description: "每日早报,获取当天全球的热门新闻事件列表",
 | 
			
		||||
		Parameters: Parameters{
 | 
			
		||||
 | 
			
		||||
			Type: "object",
 | 
			
		||||
			Properties: map[string]Property{
 | 
			
		||||
				"text": {
 | 
			
		||||
					Type:        "string",
 | 
			
		||||
					Description: "",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			Required: []string{},
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
	{
 | 
			
		||||
		Name:        "weibo_hot",
 | 
			
		||||
		Description: "新浪微博热搜榜,微博当日热搜榜单",
 | 
			
		||||
		Parameters: Parameters{
 | 
			
		||||
			Type: "object",
 | 
			
		||||
			Properties: map[string]Property{
 | 
			
		||||
				"text": {
 | 
			
		||||
					Type:        "string",
 | 
			
		||||
					Description: "",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			Required: []string{},
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		Name:        "zhihu_top",
 | 
			
		||||
		Description: "知乎热榜,知乎当日话题讨论榜单",
 | 
			
		||||
		Parameters: Parameters{
 | 
			
		||||
			Type: "object",
 | 
			
		||||
			Properties: map[string]Property{
 | 
			
		||||
				"text": {
 | 
			
		||||
					Type:        "string",
 | 
			
		||||
					Description: "",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			Required: []string{},
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var FunctionNameMap = map[string]string{
 | 
			
		||||
	"zao_bao":   "每日早报",
 | 
			
		||||
	"weibo_hot": "微博热搜",
 | 
			
		||||
	"zhihu_top": "知乎热榜",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ type MKey interface {
 | 
			
		||||
	string | int
 | 
			
		||||
}
 | 
			
		||||
type MValue interface {
 | 
			
		||||
	*WsClient | ChatSession | []Message | context.CancelFunc
 | 
			
		||||
	*WsClient | ChatSession | context.CancelFunc | []interface{}
 | 
			
		||||
}
 | 
			
		||||
type LMap[K MKey, T MValue] struct {
 | 
			
		||||
	lock sync.RWMutex
 | 
			
		||||
 
 | 
			
		||||
@@ -157,10 +157,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
 | 
			
		||||
		Temperature: userVo.ChatConfig.Temperature,
 | 
			
		||||
		MaxTokens:   userVo.ChatConfig.MaxTokens,
 | 
			
		||||
		Stream:      true,
 | 
			
		||||
		Functions:   types.InnerFunctions,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 加载聊天上下文
 | 
			
		||||
	var chatCtx []types.Message
 | 
			
		||||
	var chatCtx []interface{}
 | 
			
		||||
	if userVo.ChatConfig.EnableContext {
 | 
			
		||||
		if h.App.ChatContexts.Has(session.ChatId) {
 | 
			
		||||
			chatCtx = h.App.ChatContexts.Get(session.ChatId)
 | 
			
		||||
@@ -169,11 +170,13 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
 | 
			
		||||
			var messages []types.Message
 | 
			
		||||
			err := utils.JsonDecode(role.Context, &messages)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				chatCtx = messages
 | 
			
		||||
				for _, v := range messages {
 | 
			
		||||
					chatCtx = append(chatCtx, v)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			// TODO: 这里默认加载最近 4 条聊天记录作为上下文,后期应该做成可配置的
 | 
			
		||||
			// TODO: 这里默认加载最近 2 条聊天记录作为上下文,后期应该做成可配置的
 | 
			
		||||
			var historyMessages []model.HistoryMessage
 | 
			
		||||
			res := h.db.Where("chat_id = ?", session.ChatId).Limit(4).Order("created_at desc").Find(&historyMessages)
 | 
			
		||||
			res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(2).Order("created_at desc").Find(&historyMessages)
 | 
			
		||||
			if res.Error == nil {
 | 
			
		||||
				for _, msg := range historyMessages {
 | 
			
		||||
					ms := types.Message{Role: "user", Content: msg.Content}
 | 
			
		||||
@@ -189,12 +192,17 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
 | 
			
		||||
			logger.Info("聊天上下文:", chatCtx)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	req.Messages = append(chatCtx, types.Message{
 | 
			
		||||
		Role:    "user",
 | 
			
		||||
		Content: prompt,
 | 
			
		||||
	reqMgs := make([]interface{}, 0)
 | 
			
		||||
	for _, m := range chatCtx {
 | 
			
		||||
		reqMgs = append(reqMgs, m)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req.Messages = append(reqMgs, map[string]interface{}{
 | 
			
		||||
		"role":    "user",
 | 
			
		||||
		"content": prompt,
 | 
			
		||||
	})
 | 
			
		||||
	var apiKey string
 | 
			
		||||
	response, err := h.fakeRequest(ctx, userVo, &apiKey, req)
 | 
			
		||||
	response, err := h.doRequest(ctx, userVo, &apiKey, req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
@@ -213,176 +221,191 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//contentType := response.Header.Get("Content-Type")
 | 
			
		||||
	//if strings.Contains(contentType, "text/event-stream") || true {
 | 
			
		||||
	if true {
 | 
			
		||||
		replyCreatedAt := time.Now()
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		var functionCall = false
 | 
			
		||||
		var functionName string
 | 
			
		||||
		var arguments = make([]string, 0)
 | 
			
		||||
		reader := bufio.NewReader(response.Body)
 | 
			
		||||
		for {
 | 
			
		||||
			line, err := reader.ReadString('\n')
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
					logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
				} else {
 | 
			
		||||
					logger.Error(err)
 | 
			
		||||
				}
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			if !strings.Contains(line, "data:") || len(line) < 30 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var responseBody = types.ApiResponse{}
 | 
			
		||||
			err = json.Unmarshal([]byte(line[6:]), &responseBody)
 | 
			
		||||
			if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
 | 
			
		||||
				logger.Error(err, line)
 | 
			
		||||
				replyMessage(ws, ErrorMsg)
 | 
			
		||||
				replyMessage(ws, "")
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			fun := responseBody.Choices[0].Delta.FunctionCall
 | 
			
		||||
			if functionCall && fun.Name == "" {
 | 
			
		||||
				arguments = append(arguments, fun.Arguments)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !utils.IsEmptyValue(fun) {
 | 
			
		||||
				functionCall = true
 | 
			
		||||
				functionName = fun.Name
 | 
			
		||||
				replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
				replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 %s 作答 ...\n\n", functionName)})
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 初始化 role
 | 
			
		||||
			if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
 | 
			
		||||
				message.Role = responseBody.Choices[0].Delta.Role
 | 
			
		||||
				replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
				continue
 | 
			
		||||
			} else if responseBody.Choices[0].FinishReason != "" {
 | 
			
		||||
				break // 输出完成或者输出中断了
 | 
			
		||||
			} else {
 | 
			
		||||
				content := responseBody.Choices[0].Delta.Content
 | 
			
		||||
				contents = append(contents, content)
 | 
			
		||||
				replyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: responseBody.Choices[0].Delta.Content,
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		} // end for
 | 
			
		||||
 | 
			
		||||
		if functionCall { // 调用函数完成任务
 | 
			
		||||
			// TODO 调用函数完成任务
 | 
			
		||||
			data, err := h.funcZaoBao.Fetch()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				replyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: "调用函数出错",
 | 
			
		||||
				})
 | 
			
		||||
			} else {
 | 
			
		||||
				replyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: data,
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
			contents = append(contents, data)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			// 更新用户的对话次数
 | 
			
		||||
			res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				return res.Error
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
			message.Content = strings.Join(contents, "")
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if userVo.ChatConfig.EnableContext && functionCall == false {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			if userVo.ChatConfig.EnableHistory {
 | 
			
		||||
				// for prompt
 | 
			
		||||
				token, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
	contentType := response.Header.Get("Content-Type")
 | 
			
		||||
	if strings.Contains(contentType, "text/event-stream") {
 | 
			
		||||
		if true {
 | 
			
		||||
			replyCreatedAt := time.Now()
 | 
			
		||||
			// 循环读取 Chunk 消息
 | 
			
		||||
			var message = types.Message{}
 | 
			
		||||
			var contents = make([]string, 0)
 | 
			
		||||
			var functionCall = false
 | 
			
		||||
			var functionName string
 | 
			
		||||
			var arguments = make([]string, 0)
 | 
			
		||||
			reader := bufio.NewReader(response.Body)
 | 
			
		||||
			for {
 | 
			
		||||
				line, err := reader.ReadString('\n')
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error(err)
 | 
			
		||||
					if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
						logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
					} else if err != io.EOF {
 | 
			
		||||
						logger.Error(err)
 | 
			
		||||
					}
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:  userVo.Id,
 | 
			
		||||
					ChatId:  session.ChatId,
 | 
			
		||||
					RoleId:  role.Id,
 | 
			
		||||
					Type:    types.PromptMsg,
 | 
			
		||||
					Icon:    user.Avatar,
 | 
			
		||||
					Content: prompt,
 | 
			
		||||
					Tokens:  token,
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
				historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
				res := h.db.Save(&historyUserMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
				if !strings.Contains(line, "data:") || len(line) < 30 {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// for reply
 | 
			
		||||
				token, err = utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
				var responseBody = types.ApiResponse{}
 | 
			
		||||
				err = json.Unmarshal([]byte(line[6:]), &responseBody)
 | 
			
		||||
				if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
 | 
			
		||||
					logger.Error(err, line)
 | 
			
		||||
					replyMessage(ws, ErrorMsg)
 | 
			
		||||
					replyMessage(ws, "")
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				fun := responseBody.Choices[0].Delta.FunctionCall
 | 
			
		||||
				if functionCall && fun.Name == "" {
 | 
			
		||||
					arguments = append(arguments, fun.Arguments)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if !utils.IsEmptyValue(fun) {
 | 
			
		||||
					functionCall = true
 | 
			
		||||
					functionName = fun.Name
 | 
			
		||||
					replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
					replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", types.FunctionNameMap[functionName])})
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 初始化 role
 | 
			
		||||
				if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
 | 
			
		||||
					message.Role = responseBody.Choices[0].Delta.Role
 | 
			
		||||
					replyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
					continue
 | 
			
		||||
				} else if responseBody.Choices[0].FinishReason != "" {
 | 
			
		||||
					break // 输出完成或者输出中断了
 | 
			
		||||
				} else {
 | 
			
		||||
					content := responseBody.Choices[0].Delta.Content
 | 
			
		||||
					contents = append(contents, utils.InterfaceToString(content))
 | 
			
		||||
					replyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
						Type:    types.WsMiddle,
 | 
			
		||||
						Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
			
		||||
					})
 | 
			
		||||
				}
 | 
			
		||||
			} // end for
 | 
			
		||||
 | 
			
		||||
			if functionCall { // 调用函数完成任务
 | 
			
		||||
				logger.Info(functionName)
 | 
			
		||||
				logger.Info(arguments)
 | 
			
		||||
				// TODO 调用函数完成任务
 | 
			
		||||
				data, err := h.funcZaoBao.Fetch()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error(err)
 | 
			
		||||
					replyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
						Type:    types.WsMiddle,
 | 
			
		||||
						Content: "调用函数出错",
 | 
			
		||||
					})
 | 
			
		||||
				} else {
 | 
			
		||||
					replyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
						Type:    types.WsMiddle,
 | 
			
		||||
						Content: data,
 | 
			
		||||
					})
 | 
			
		||||
				}
 | 
			
		||||
				historyReplyMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:  userVo.Id,
 | 
			
		||||
					ChatId:  session.ChatId,
 | 
			
		||||
					RoleId:  role.Id,
 | 
			
		||||
					Type:    types.ReplyMsg,
 | 
			
		||||
					Icon:    role.Icon,
 | 
			
		||||
					Content: message.Content,
 | 
			
		||||
					Tokens:  token,
 | 
			
		||||
				}
 | 
			
		||||
				historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
				historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
				res = h.db.Create(&historyReplyMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 统计用户 token 数量
 | 
			
		||||
				h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
 | 
			
		||||
					historyUserMsg.Tokens+historyReplyMsg.Tokens))
 | 
			
		||||
				contents = append(contents, data)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
				chatItem.RoleId = role.Id
 | 
			
		||||
				chatItem.Model = session.Model
 | 
			
		||||
				if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
					chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
			// 消息发送成功
 | 
			
		||||
			if len(contents) > 0 {
 | 
			
		||||
				// 更新用户的对话次数
 | 
			
		||||
				res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					return res.Error
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if message.Role == "" {
 | 
			
		||||
					message.Role = "assistant"
 | 
			
		||||
				}
 | 
			
		||||
				message.Content = strings.Join(contents, "")
 | 
			
		||||
				useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
				// 计算本次对话消耗的总 token 数量
 | 
			
		||||
				req.Messages = append(req.Messages, message)
 | 
			
		||||
				totalTokens := getTotalTokens(req)
 | 
			
		||||
				replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("`本轮对话共消耗 Token 数量: %d`", totalTokens)})
 | 
			
		||||
 | 
			
		||||
				// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
				if userVo.ChatConfig.EnableContext && functionCall == false {
 | 
			
		||||
					chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
					chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
					h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 追加聊天记录
 | 
			
		||||
				if userVo.ChatConfig.EnableHistory {
 | 
			
		||||
					useContext := true
 | 
			
		||||
					if functionCall {
 | 
			
		||||
						useContext = false
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// for prompt
 | 
			
		||||
					token, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logger.Error(err)
 | 
			
		||||
					}
 | 
			
		||||
					historyUserMsg := model.HistoryMessage{
 | 
			
		||||
						UserId:     userVo.Id,
 | 
			
		||||
						ChatId:     session.ChatId,
 | 
			
		||||
						RoleId:     role.Id,
 | 
			
		||||
						Type:       types.PromptMsg,
 | 
			
		||||
						Icon:       user.Avatar,
 | 
			
		||||
						Content:    prompt,
 | 
			
		||||
						Tokens:     token,
 | 
			
		||||
						UseContext: useContext,
 | 
			
		||||
					}
 | 
			
		||||
					historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
					historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
					res := h.db.Save(&historyUserMsg)
 | 
			
		||||
					if res.Error != nil {
 | 
			
		||||
						logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// for reply
 | 
			
		||||
					token, err = utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logger.Error(err)
 | 
			
		||||
					}
 | 
			
		||||
					historyReplyMsg := model.HistoryMessage{
 | 
			
		||||
						UserId:     userVo.Id,
 | 
			
		||||
						ChatId:     session.ChatId,
 | 
			
		||||
						RoleId:     role.Id,
 | 
			
		||||
						Type:       types.ReplyMsg,
 | 
			
		||||
						Icon:       role.Icon,
 | 
			
		||||
						Content:    message.Content,
 | 
			
		||||
						Tokens:     token,
 | 
			
		||||
						UseContext: useContext,
 | 
			
		||||
					}
 | 
			
		||||
					historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
					historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
					res = h.db.Create(&historyReplyMsg)
 | 
			
		||||
					if res.Error != nil {
 | 
			
		||||
						logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// 统计用户 token 数量
 | 
			
		||||
					h.db.Model(&user).UpdateColumn("tokens", gorm.Expr("tokens + ?",
 | 
			
		||||
						historyUserMsg.Tokens+historyReplyMsg.Tokens))
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 保存当前会话
 | 
			
		||||
				var chatItem model.ChatItem
 | 
			
		||||
				res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					chatItem.ChatId = session.ChatId
 | 
			
		||||
					chatItem.UserId = session.UserId
 | 
			
		||||
					chatItem.RoleId = role.Id
 | 
			
		||||
					chatItem.Model = session.Model
 | 
			
		||||
					if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
						chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
					} else {
 | 
			
		||||
						chatItem.Title = prompt
 | 
			
		||||
					}
 | 
			
		||||
					h.db.Create(&chatItem)
 | 
			
		||||
				}
 | 
			
		||||
				h.db.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
@@ -469,13 +492,6 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
 | 
			
		||||
	return client.Do(request)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) fakeRequest(ctx context.Context, user vo.User, apiKey *string, req types.ApiRequest) (*http.Response, error) {
 | 
			
		||||
	link := "https://img.r9it.com/chatgpt/response"
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	request, _ := http.NewRequest(http.MethodGet, link, nil)
 | 
			
		||||
	return client.Do(request)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 回复客户片段端消息
 | 
			
		||||
func replyChunkMessage(client types.Client, message types.WsMessage) {
 | 
			
		||||
	msg, err := json.Marshal(message)
 | 
			
		||||
@@ -509,6 +525,26 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c, tokens)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTotalTokens(req types.ApiRequest) int {
 | 
			
		||||
	encode := utils.JsonEncode(req.Messages)
 | 
			
		||||
	var items []map[string]interface{}
 | 
			
		||||
	err := utils.JsonDecode(encode, &items)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	tokens := 0
 | 
			
		||||
	for _, item := range items {
 | 
			
		||||
		content, ok := item["content"]
 | 
			
		||||
		if ok && !utils.IsEmptyValue(content) {
 | 
			
		||||
			t, err := utils.CalcTokens(utils.InterfaceToString(content), req.Model)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				tokens += t
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return tokens
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// StopGenerate 停止生成
 | 
			
		||||
func (h *ChatHandler) StopGenerate(c *gin.Context) {
 | 
			
		||||
	sessionId := c.Query("session_id")
 | 
			
		||||
 
 | 
			
		||||
@@ -11,6 +11,7 @@ import (
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"context"
 | 
			
		||||
	"embed"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
@@ -101,9 +102,12 @@ func main() {
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 创建函数
 | 
			
		||||
		fx.Provide(func() *function.FuncZaoBao {
 | 
			
		||||
		fx.Provide(func() (*function.FuncZaoBao, error) {
 | 
			
		||||
			token := os.Getenv("AL_API_TOKEN")
 | 
			
		||||
			return function.NewZaoBao(token)
 | 
			
		||||
			if token == "" {
 | 
			
		||||
				return nil, errors.New("invalid AL api token")
 | 
			
		||||
			}
 | 
			
		||||
			return function.NewZaoBao(token), nil
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 创建控制器
 | 
			
		||||
 
 | 
			
		||||
@@ -2,13 +2,14 @@ package model
 | 
			
		||||
 | 
			
		||||
type HistoryMessage struct {
 | 
			
		||||
	BaseModel
 | 
			
		||||
	ChatId  string // 会话 ID
 | 
			
		||||
	UserId  uint   // 用户 ID
 | 
			
		||||
	RoleId  uint   // 角色 ID
 | 
			
		||||
	Type    string
 | 
			
		||||
	Icon    string
 | 
			
		||||
	Tokens  int
 | 
			
		||||
	Content string
 | 
			
		||||
	ChatId     string // 会话 ID
 | 
			
		||||
	UserId     uint   // 用户 ID
 | 
			
		||||
	RoleId     uint   // 角色 ID
 | 
			
		||||
	Type       string
 | 
			
		||||
	Icon       string
 | 
			
		||||
	Tokens     int
 | 
			
		||||
	Content    string
 | 
			
		||||
	UseContext bool // 是否可以作为聊天上下文
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (HistoryMessage) TableName() string {
 | 
			
		||||
 
 | 
			
		||||
@@ -2,13 +2,14 @@ package vo
 | 
			
		||||
 | 
			
		||||
type HistoryMessage struct {
 | 
			
		||||
	BaseVo
 | 
			
		||||
	ChatId  string `json:"chat_id"`
 | 
			
		||||
	UserId  uint   `json:"user_id"`
 | 
			
		||||
	RoleId  uint   `json:"role_id"`
 | 
			
		||||
	Type    string `json:"type"`
 | 
			
		||||
	Icon    string `json:"icon"`
 | 
			
		||||
	Tokens  int    `json:"tokens"`
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
	ChatId     string `json:"chat_id"`
 | 
			
		||||
	UserId     uint   `json:"user_id"`
 | 
			
		||||
	RoleId     uint   `json:"role_id"`
 | 
			
		||||
	Type       string `json:"type"`
 | 
			
		||||
	Icon       string `json:"icon"`
 | 
			
		||||
	Tokens     int    `json:"tokens"`
 | 
			
		||||
	Content    string `json:"content"`
 | 
			
		||||
	UseContext bool   `json:"use_context"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (HistoryMessage) TableName() string {
 | 
			
		||||
 
 | 
			
		||||
@@ -81,3 +81,10 @@ func JsonEncode(value interface{}) string {
 | 
			
		||||
func JsonDecode(src string, dest interface{}) error {
 | 
			
		||||
	return json.Unmarshal([]byte(src), dest)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InterfaceToString(value interface{}) string {
 | 
			
		||||
	if str, ok := value.(string); ok {
 | 
			
		||||
		return str
 | 
			
		||||
	}
 | 
			
		||||
	return JsonEncode(value)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										1
									
								
								database/plugins.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								database/plugins.sql
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
ALTER TABLE `chatgpt_chat_history` ADD `use_context` TINYINT(1) NOT NULL COMMENT '是否允许作为上下文语料' AFTER `tokens`;
 | 
			
		||||
@@ -94,6 +94,7 @@
 | 
			
		||||
      <el-main v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.3)">
 | 
			
		||||
        <div class="chat-head">
 | 
			
		||||
          <div class="chat-config">
 | 
			
		||||
            <span class="role-select-label">聊天角色:</span>
 | 
			
		||||
            <el-select v-model="roleId" filterable placeholder="角色" class="role-select">
 | 
			
		||||
              <el-option
 | 
			
		||||
                  v-for="item in roles"
 | 
			
		||||
@@ -210,7 +211,8 @@ import {
 | 
			
		||||
  Check,
 | 
			
		||||
  Close,
 | 
			
		||||
  Delete,
 | 
			
		||||
  Edit, Iphone,
 | 
			
		||||
  Edit,
 | 
			
		||||
  Iphone,
 | 
			
		||||
  Plus,
 | 
			
		||||
  Promotion,
 | 
			
		||||
  RefreshRight,
 | 
			
		||||
@@ -920,6 +922,10 @@ $borderColor = #4676d0;
 | 
			
		||||
          justify-content center;
 | 
			
		||||
          padding-top 10px;
 | 
			
		||||
 | 
			
		||||
          .role-select-label {
 | 
			
		||||
            color #ffffff
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          .el-select {
 | 
			
		||||
            //max-width 150px;
 | 
			
		||||
            margin-right 10px;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user