SSE 消息重构已完成

This commit is contained in:
GeekMaster
2025-05-27 15:48:07 +08:00
parent e685876cc0
commit 32fc4d86a2
15 changed files with 394 additions and 339 deletions

View File

@@ -21,11 +21,11 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"html/template"
"io"
"net/http"
"net/url"
"os"
"path"
"strings"
"time"
"unicode/utf8"
@@ -45,14 +45,17 @@ const (
)
type ChatInput struct {
UserId uint `json:"user_id"`
RoleId int `json:"role_id"`
ModelId int `json:"model_id"`
ChatId string `json:"chat_id"`
Content string `json:"content"`
Tools []int `json:"tools"`
Stream bool `json:"stream"`
Files []vo.File `json:"files"`
UserId uint `json:"user_id"`
RoleId uint `json:"role_id"`
ModelId uint `json:"model_id"`
ChatId string `json:"chat_id"`
Prompt string `json:"prompt"`
Tools []uint `json:"tools"`
Stream bool `json:"stream"`
Files []vo.File `json:"files"`
ChatModel model.ChatModel `json:"chat_model,omitempty"`
ChatRole model.ChatRole `json:"chat_role,omitempty"`
LastMsgId uint `json:"last_msg_id,omitempty"` // 最后的消息ID用于重新生成答案的时候过滤上下文
}
type ChatHandler struct {
@@ -79,14 +82,14 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag
// Chat 处理聊天请求
func (h *ChatHandler) Chat(c *gin.Context) {
var data ChatInput
if err := c.ShouldBindJSON(&data); err != nil {
var input ChatInput
if err := c.ShouldBindJSON(&input); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Prompt-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
@@ -94,44 +97,34 @@ func (h *ChatHandler) Chat(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
// 使用旧的聊天数据覆盖模型和角色ID
var chat model.ChatItem
h.DB.Where("chat_id", input.ChatId).First(&chat)
if chat.Id > 0 {
input.ModelId = chat.ModelId
input.RoleId = chat.RoleId
}
// 验证聊天角色
var chatRole model.ChatRole
err := h.DB.First(&chatRole, data.RoleId).Error
err := h.DB.First(&chatRole, input.RoleId).Error
if err != nil || !chatRole.Enable {
pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!")
return
}
// 如果角色绑定了模型ID使用角色的模型ID
if chatRole.ModelId > 0 {
data.ModelId = int(chatRole.ModelId)
}
input.ChatRole = chatRole
// 获取模型信息
var chatModel model.ChatModel
err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
err = h.DB.Where("id", input.ModelId).First(&chatModel).Error
if err != nil || !chatModel.Enabled {
pushMessage(c, ChatEventError, "当前AI模型暂未启用请更换模型后再发起对话")
return
}
// 使用旧的聊天数据覆盖模型和角色ID
var chat model.ChatItem
h.DB.Where("chat_id", data.ChatId).First(&chat)
if chat.Id > 0 {
chatModel.Id = chat.ModelId
data.RoleId = int(chat.RoleId)
}
// 复制模型数据
err = utils.CopyObject(chatModel, &session.Model)
if err != nil {
logger.Error(err, chatModel)
}
session.Model.Id = chatModel.Id
input.ChatModel = chatModel
// 发送消息
err = h.sendMessage(ctx, session, chatRole, data.Content, c)
err = h.sendMessage(ctx, input, c)
if err != nil {
pushMessage(c, ChatEventError, err.Error())
return
@@ -148,9 +141,9 @@ func pushMessage(c *gin.Context, msgType string, content interface{}) {
c.Writer.Flush()
}
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, c *gin.Context) error {
func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.Context) error {
var user model.User
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
res := h.DB.Model(&model.User{}).First(&user, input.UserId)
if res.Error != nil {
return errors.New("未授权用户,您正在进行非法操作!")
}
@@ -165,8 +158,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!")
}
if userVo.Power < session.Model.Power {
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d[立即购买](/member)。", userVo.Power, session.Model.Power)
if userVo.Power < input.ChatModel.Power {
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d[立即购买](/member)。", userVo.Power, input.ChatModel.Power)
}
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
@@ -174,30 +167,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
promptTokens, _ := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > session.Model.MaxContext {
promptTokens, _ := utils.CalcTokens(input.Prompt, input.ChatModel.Value)
if promptTokens > input.ChatModel.MaxContext {
return errors.New("对话内容超出了当前模型允许的最大上下文长度!")
}
var req = types.ApiRequest{
Model: session.Model.Value,
Stream: session.Stream,
Temperature: session.Model.Temperature,
Model: input.ChatModel.Value,
Stream: input.Stream,
Temperature: input.ChatModel.Temperature,
}
// 兼容 OpenAI 模型
if strings.HasPrefix(session.Model.Value, "o1-") ||
strings.HasPrefix(session.Model.Value, "o3-") ||
strings.HasPrefix(session.Model.Value, "gpt") {
req.MaxCompletionTokens = session.Model.MaxTokens
session.Start = time.Now().Unix()
if strings.HasPrefix(input.ChatModel.Value, "o1-") ||
strings.HasPrefix(input.ChatModel.Value, "o3-") ||
strings.HasPrefix(input.ChatModel.Value, "gpt") {
req.MaxCompletionTokens = input.ChatModel.MaxTokens
} else {
req.MaxTokens = session.Model.MaxTokens
req.MaxTokens = input.ChatModel.MaxTokens
}
if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
if len(input.Tools) > 0 && !strings.HasPrefix(input.ChatModel.Value, "o1-") {
var items []model.Function
res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
res = h.DB.Where("enabled", true).Where("id IN ?", input.Tools).Find(&items)
if res.Error == nil {
var tools = make([]types.Tool, 0)
for _, v := range items {
@@ -231,14 +223,18 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
chatCtx := make([]interface{}, 0)
messages := make([]interface{}, 0)
if h.App.SysConfig.EnableContext {
if h.ChatContexts.Has(session.ChatId) {
messages = h.ChatContexts.Get(session.ChatId)
if h.ChatContexts.Has(input.ChatId) {
messages = h.ChatContexts.Get(input.ChatId)
} else {
_ = utils.JsonDecode(role.Context, &messages)
_ = utils.JsonDecode(input.ChatRole.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil {
dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId)
if input.LastMsgId > 0 { // 重新生成逻辑
dbSession = dbSession.Where("id < ?", input.LastMsgId)
}
err = dbSession.Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages).Error
if err == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
ms := types.Message{Role: "user", Content: msg.Content}
@@ -261,7 +257,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
v := messages[i]
tks, _ = utils.CalcTokens(utils.JsonEncode(v), req.Model)
// 上下文 token 超出了模型的最大上下文长度
if tokens+tks >= session.Model.MaxContext {
if tokens+tks >= input.ChatModel.MaxContext {
break
}
@@ -282,71 +278,101 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
reqMgs = append(reqMgs, chatCtx[i])
}
fullPrompt := prompt
text := prompt
for _, file := range session.Files {
// extract files in prompt
files := utils.ExtractFileURLs(prompt)
logger.Debugf("detected FILES: %+v", files)
// 如果不是逆向模型,则提取文件内容
if len(files) > 0 && !(session.Model.Value == "gpt-4-all" ||
strings.HasPrefix(session.Model.Value, "gpt-4-gizmo") ||
strings.HasPrefix(session.Model.Value, "claude-3")) {
contents := make([]string, 0)
var file model.File
for _, v := range files {
h.DB.Where("url = ?", v).First(&file)
content, err := utils.ReadFileContent(v, h.App.Config.TikaHost)
if err != nil {
logger.Error("error with read file: ", err)
} else {
contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
}
text = strings.Replace(text, v, "", 1)
}
if len(contents) > 0 {
fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML)\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text)
}
tokens, _ := utils.CalcTokens(fullPrompt, req.Model)
if tokens > session.Model.MaxContext {
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
}
}
logger.Debug("最终Prompt", fullPrompt)
// extract images from prompt
imgURLs := utils.ExtractImgURLs(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
fileContents := make([]string, 0) // 文件内容
var finalPrompt = input.Prompt
imgList := make([]any, 0)
for _, file := range input.Files {
logger.Debugf("detected file: %+v", file.URL)
// 处理图片
if isImageURL(file.URL) {
imgList = append(imgList, gin.H{
"type": "image_url",
"image_url": gin.H{
"url": v,
"url": file.URL,
},
})
} else {
// 如果不是逆向模型,则提取文件内容
modelValue := input.ChatModel.Value
if !(strings.Contains(modelValue, "-all") || strings.HasPrefix(modelValue, "gpt-4-gizmo") || strings.HasPrefix(modelValue, "claude")) {
content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost)
if err != nil {
logger.Error("error with read file: ", err)
continue
} else {
fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
}
}
}
data = append(data, gin.H{
"type": "text",
"text": strings.TrimSpace(text),
})
content = data
} else {
content = fullPrompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
})
logger.Debugf("%+v", req.Messages)
if len(fileContents) > 0 {
finalPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML)\n\n %s\n\n 问题:%s", strings.Join(fileContents, "\n"), input.Prompt)
tokens, _ := utils.CalcTokens(finalPrompt, req.Model)
if tokens > input.ChatModel.MaxContext {
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
}
} else {
finalPrompt = input.Prompt
}
return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, c)
if len(imgList) > 0 {
imgList = append(imgList, map[string]interface{}{
"type": "text",
"text": input.Prompt,
})
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": imgList,
})
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": finalPrompt,
})
}
logger.Debugf("请求消息: %+v", req.Messages)
return h.sendOpenAiMessage(req, userVo, ctx, input, c)
}
// 判断一个 URL 是否图片链接
func isImageURL(url string) bool {
// 检查是否是有效的URL
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
return false
}
// 检查文件扩展名
ext := strings.ToLower(path.Ext(url))
validImageExts := map[string]bool{
".jpg": true,
".jpeg": true,
".png": true,
".gif": true,
".bmp": true,
".webp": true,
".svg": true,
".ico": true,
}
if !validImageExts[ext] {
return false
}
// 发送HEAD请求检查Content-Type
client := &http.Client{
Timeout: 5 * time.Second,
}
resp, err := client.Head(url)
if err != nil {
return false
}
defer resp.Body.Close()
contentType := resp.Header.Get("Content-Type")
return strings.HasPrefix(contentType, "image/")
}
// Tokens 统计 token 数量
@@ -415,10 +441,10 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input ChatInput, apiKey *model.ApiKey) (*http.Response, error) {
// if the chat model bind a KEY, use it directly
if session.Model.KeyId > 0 {
h.DB.Where("id", session.Model.KeyId).Find(apiKey)
if input.ChatModel.KeyId > 0 {
h.DB.Where("id", input.ChatModel.KeyId).Find(apiKey)
} else { // use the last unused key
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
}
@@ -472,16 +498,16 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
}
// 扣减用户算力
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
func (h *ChatHandler) subUserPower(userVo vo.User, input ChatInput, promptTokens int, replyTokens int) {
power := 1
if session.Model.Power > 0 {
power = session.Model.Power
if input.ChatModel.Power > 0 {
power = input.ChatModel.Power
}
err := h.userService.DecreasePower(userVo.Id, power, model.PowerLog{
Type: types.PowerConsume,
Model: session.Model.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", session.Model.Name, promptTokens, replyTokens),
Model: input.ChatModel.Value,
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d回复长度%d", input.ChatModel.Name, promptTokens, replyTokens),
})
if err != nil {
logger.Error(err)
@@ -492,8 +518,7 @@ func (h *ChatHandler) saveChatHistory(
req types.ApiRequest,
usage Usage,
message types.Message,
session *types.ChatSession,
role model.ChatRole,
input ChatInput,
userVo vo.User,
promptCreatedAt time.Time,
replyCreatedAt time.Time) {
@@ -502,7 +527,7 @@ func (h *ChatHandler) saveChatHistory(
if h.App.SysConfig.EnableContext {
chatCtx := req.Messages // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.ChatContexts.Put(session.ChatId, chatCtx)
h.ChatContexts.Put(input.ChatId, chatCtx)
}
// 追加聊天记录
@@ -515,12 +540,15 @@ func (h *ChatHandler) saveChatHistory(
}
historyUserMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: template.HTMLEscapeString(usage.Prompt),
UserId: userVo.Id,
ChatId: input.ChatId,
RoleId: input.RoleId,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: utils.JsonEncode(vo.MsgContent{
Text: usage.Prompt,
Files: input.Files,
}),
Tokens: promptTokens,
TotalTokens: promptTokens,
UseContext: true,
@@ -543,12 +571,15 @@ func (h *ChatHandler) saveChatHistory(
totalTokens = replyTokens + getTotalTokens(req)
}
historyReplyMsg := model.ChatMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: usage.Content,
UserId: userVo.Id,
ChatId: input.ChatId,
RoleId: input.RoleId,
Type: types.ReplyMsg,
Icon: input.ChatRole.Icon,
Content: utils.JsonEncode(vo.MsgContent{
Text: message.Content,
Files: input.Files,
}),
Tokens: replyTokens,
TotalTokens: totalTokens,
UseContext: true,
@@ -562,17 +593,17 @@ func (h *ChatHandler) saveChatHistory(
}
// 更新用户算力
if session.Model.Power > 0 {
h.subUserPower(userVo, session, promptTokens, replyTokens)
if input.ChatModel.Power > 0 {
h.subUserPower(userVo, input, promptTokens, replyTokens)
}
// 保存当前会话
var chatItem model.ChatItem
err = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem).Error
err = h.DB.Where("chat_id = ?", input.ChatId).First(&chatItem).Error
if err != nil {
chatItem.ChatId = session.ChatId
chatItem.ChatId = input.ChatId
chatItem.UserId = userVo.Id
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
chatItem.RoleId = input.RoleId
chatItem.ModelId = input.ModelId
if utf8.RuneCountInString(usage.Prompt) > 30 {
chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
} else {
@@ -586,7 +617,7 @@ func (h *ChatHandler) saveChatHistory(
}
}
// 文本生成语音
// TextToSpeech 文本生成语音
func (h *ChatHandler) TextToSpeech(c *gin.Context) {
var data struct {
ModelId int `json:"model_id"`
@@ -600,13 +631,19 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
textHash := utils.Sha256(fmt.Sprintf("%d/%s", data.ModelId, data.Text))
audioFile := fmt.Sprintf("%s/audio", h.App.Config.StaticDir)
if _, err := os.Stat(audioFile); err != nil {
os.MkdirAll(audioFile, 0755)
resp.ERROR(c, err.Error())
return
}
if err := os.MkdirAll(audioFile, 0755); err != nil {
resp.ERROR(c, err.Error())
return
}
audioFile = fmt.Sprintf("%s/%s.mp3", audioFile, textHash)
if _, err := os.Stat(audioFile); err == nil {
// 设置响应头
c.Header("Content-Type", "audio/mpeg")
c.Header("Content-Disposition", "attachment; filename=speech.mp3")
c.Header("Prompt-Type", "audio/mpeg")
c.Header("Prompt-Disposition", "attachment; filename=speech.mp3")
c.File(audioFile)
return
}
@@ -670,11 +707,14 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
}
// 设置响应头
c.Header("Content-Type", "audio/mpeg")
c.Header("Content-Disposition", "attachment; filename=speech.mp3")
c.Header("Prompt-Type", "audio/mpeg")
c.Header("Prompt-Disposition", "attachment; filename=speech.mp3")
// 直接写入完整的音频数据到响应
c.Writer.Write(audioBytes)
_, err = c.Writer.Write(audioBytes)
if err != nil {
logger.Error("写入音频数据到响应失败:", err)
}
}
// // OPenAI 消息发送实现
@@ -707,7 +747,7 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
// return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
// }
// contentType := response.Header.Get("Content-Type")
// contentType := response.Header.Get("Prompt-Type")
// if strings.Contains(contentType, "text/event-stream") {
// replyCreatedAt := time.Now() // 记录回复时间
// // 循环读取 Chunk 消息
@@ -733,7 +773,7 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
// if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
// continue
// }
// if responseBody.Choices[0].Delta.Content == nil &&
// if responseBody.Choices[0].Delta.Prompt == nil &&
// responseBody.Choices[0].Delta.ToolCalls == nil &&
// responseBody.Choices[0].Delta.ReasoningContent == "" {
// continue
@@ -799,10 +839,10 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
// "content": reasoningContent,
// })
// contents = append(contents, reasoningContent)
// } else if responseBody.Choices[0].Delta.Content != "" {
// finalContent := responseBody.Choices[0].Delta.Content
// } else if responseBody.Choices[0].Delta.Prompt != "" {
// finalContent := responseBody.Choices[0].Delta.Prompt
// if reasoning {
// finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Content)
// finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Prompt)
// reasoning = false
// }
// contents = append(contents, utils.InterfaceToString(finalContent))
@@ -861,12 +901,12 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
// if len(contents) > 0 {
// usage := Usage{
// Prompt: prompt,
// Content: strings.Join(contents, ""),
// Prompt: strings.Join(contents, ""),
// PromptTokens: 0,
// CompletionTokens: 0,
// TotalTokens: 0,
// }
// message.Content = usage.Content
// message.Prompt = usage.Prompt
// h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
// }
// } else {
@@ -879,16 +919,16 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
// if err != nil {
// return fmt.Errorf("解析响应失败:%v", body)
// }
// content := respVo.Choices[0].Message.Content
// content := respVo.Choices[0].Message.Prompt
// if strings.HasPrefix(req.Model, "o1-") {
// content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
// content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Prompt)
// }
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
// "type": "text",
// "content": content,
// })
// respVo.Usage.Prompt = prompt
// respVo.Usage.Content = content
// respVo.Usage.Prompt = content
// h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
// }