mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-02-15 10:54:24 +08:00
SSE 消息重构已完成
This commit is contained in:
@@ -21,11 +21,11 @@ import (
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@@ -45,14 +45,17 @@ const (
|
||||
)
|
||||
|
||||
type ChatInput struct {
|
||||
UserId uint `json:"user_id"`
|
||||
RoleId int `json:"role_id"`
|
||||
ModelId int `json:"model_id"`
|
||||
ChatId string `json:"chat_id"`
|
||||
Content string `json:"content"`
|
||||
Tools []int `json:"tools"`
|
||||
Stream bool `json:"stream"`
|
||||
Files []vo.File `json:"files"`
|
||||
UserId uint `json:"user_id"`
|
||||
RoleId uint `json:"role_id"`
|
||||
ModelId uint `json:"model_id"`
|
||||
ChatId string `json:"chat_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Tools []uint `json:"tools"`
|
||||
Stream bool `json:"stream"`
|
||||
Files []vo.File `json:"files"`
|
||||
ChatModel model.ChatModel `json:"chat_model,omitempty"`
|
||||
ChatRole model.ChatRole `json:"chat_role,omitempty"`
|
||||
LastMsgId uint `json:"last_msg_id,omitempty"` // 最后的消息ID,用于重新生成答案的时候过滤上下文
|
||||
}
|
||||
|
||||
type ChatHandler struct {
|
||||
@@ -79,14 +82,14 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag
|
||||
|
||||
// Chat 处理聊天请求
|
||||
func (h *ChatHandler) Chat(c *gin.Context) {
|
||||
var data ChatInput
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
var input ChatInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置SSE响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Prompt-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
@@ -94,44 +97,34 @@ func (h *ChatHandler) Chat(c *gin.Context) {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
// 使用旧的聊天数据覆盖模型和角色ID
|
||||
var chat model.ChatItem
|
||||
h.DB.Where("chat_id", input.ChatId).First(&chat)
|
||||
if chat.Id > 0 {
|
||||
input.ModelId = chat.ModelId
|
||||
input.RoleId = chat.RoleId
|
||||
}
|
||||
|
||||
// 验证聊天角色
|
||||
var chatRole model.ChatRole
|
||||
err := h.DB.First(&chatRole, data.RoleId).Error
|
||||
err := h.DB.First(&chatRole, input.RoleId).Error
|
||||
if err != nil || !chatRole.Enable {
|
||||
pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!")
|
||||
return
|
||||
}
|
||||
|
||||
// 如果角色绑定了模型ID,使用角色的模型ID
|
||||
if chatRole.ModelId > 0 {
|
||||
data.ModelId = int(chatRole.ModelId)
|
||||
}
|
||||
input.ChatRole = chatRole
|
||||
|
||||
// 获取模型信息
|
||||
var chatModel model.ChatModel
|
||||
err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
|
||||
err = h.DB.Where("id", input.ModelId).First(&chatModel).Error
|
||||
if err != nil || !chatModel.Enabled {
|
||||
pushMessage(c, ChatEventError, "当前AI模型暂未启用,请更换模型后再发起对话!")
|
||||
return
|
||||
}
|
||||
|
||||
// 使用旧的聊天数据覆盖模型和角色ID
|
||||
var chat model.ChatItem
|
||||
h.DB.Where("chat_id", data.ChatId).First(&chat)
|
||||
if chat.Id > 0 {
|
||||
chatModel.Id = chat.ModelId
|
||||
data.RoleId = int(chat.RoleId)
|
||||
}
|
||||
|
||||
// 复制模型数据
|
||||
err = utils.CopyObject(chatModel, &session.Model)
|
||||
if err != nil {
|
||||
logger.Error(err, chatModel)
|
||||
}
|
||||
session.Model.Id = chatModel.Id
|
||||
input.ChatModel = chatModel
|
||||
|
||||
// 发送消息
|
||||
err = h.sendMessage(ctx, session, chatRole, data.Content, c)
|
||||
err = h.sendMessage(ctx, input, c)
|
||||
if err != nil {
|
||||
pushMessage(c, ChatEventError, err.Error())
|
||||
return
|
||||
@@ -148,9 +141,9 @@ func pushMessage(c *gin.Context, msgType string, content interface{}) {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, c *gin.Context) error {
|
||||
func (h *ChatHandler) sendMessage(ctx context.Context, input ChatInput, c *gin.Context) error {
|
||||
var user model.User
|
||||
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
|
||||
res := h.DB.Model(&model.User{}).First(&user, input.UserId)
|
||||
if res.Error != nil {
|
||||
return errors.New("未授权用户,您正在进行非法操作!")
|
||||
}
|
||||
@@ -165,8 +158,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!")
|
||||
}
|
||||
|
||||
if userVo.Power < session.Model.Power {
|
||||
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d,[立即购买](/member)。", userVo.Power, session.Model.Power)
|
||||
if userVo.Power < input.ChatModel.Power {
|
||||
return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d,[立即购买](/member)。", userVo.Power, input.ChatModel.Power)
|
||||
}
|
||||
|
||||
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
||||
@@ -174,30 +167,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
}
|
||||
|
||||
// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
|
||||
promptTokens, _ := utils.CalcTokens(prompt, session.Model.Value)
|
||||
if promptTokens > session.Model.MaxContext {
|
||||
promptTokens, _ := utils.CalcTokens(input.Prompt, input.ChatModel.Value)
|
||||
if promptTokens > input.ChatModel.MaxContext {
|
||||
|
||||
return errors.New("对话内容超出了当前模型允许的最大上下文长度!")
|
||||
}
|
||||
|
||||
var req = types.ApiRequest{
|
||||
Model: session.Model.Value,
|
||||
Stream: session.Stream,
|
||||
Temperature: session.Model.Temperature,
|
||||
Model: input.ChatModel.Value,
|
||||
Stream: input.Stream,
|
||||
Temperature: input.ChatModel.Temperature,
|
||||
}
|
||||
// 兼容 OpenAI 模型
|
||||
if strings.HasPrefix(session.Model.Value, "o1-") ||
|
||||
strings.HasPrefix(session.Model.Value, "o3-") ||
|
||||
strings.HasPrefix(session.Model.Value, "gpt") {
|
||||
req.MaxCompletionTokens = session.Model.MaxTokens
|
||||
session.Start = time.Now().Unix()
|
||||
if strings.HasPrefix(input.ChatModel.Value, "o1-") ||
|
||||
strings.HasPrefix(input.ChatModel.Value, "o3-") ||
|
||||
strings.HasPrefix(input.ChatModel.Value, "gpt") {
|
||||
req.MaxCompletionTokens = input.ChatModel.MaxTokens
|
||||
} else {
|
||||
req.MaxTokens = session.Model.MaxTokens
|
||||
req.MaxTokens = input.ChatModel.MaxTokens
|
||||
}
|
||||
|
||||
if len(session.Tools) > 0 && !strings.HasPrefix(session.Model.Value, "o1-") {
|
||||
if len(input.Tools) > 0 && !strings.HasPrefix(input.ChatModel.Value, "o1-") {
|
||||
var items []model.Function
|
||||
res = h.DB.Where("enabled", true).Where("id IN ?", session.Tools).Find(&items)
|
||||
res = h.DB.Where("enabled", true).Where("id IN ?", input.Tools).Find(&items)
|
||||
if res.Error == nil {
|
||||
var tools = make([]types.Tool, 0)
|
||||
for _, v := range items {
|
||||
@@ -231,14 +223,18 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
chatCtx := make([]interface{}, 0)
|
||||
messages := make([]interface{}, 0)
|
||||
if h.App.SysConfig.EnableContext {
|
||||
if h.ChatContexts.Has(session.ChatId) {
|
||||
messages = h.ChatContexts.Get(session.ChatId)
|
||||
if h.ChatContexts.Has(input.ChatId) {
|
||||
messages = h.ChatContexts.Get(input.ChatId)
|
||||
} else {
|
||||
_ = utils.JsonDecode(role.Context, &messages)
|
||||
_ = utils.JsonDecode(input.ChatRole.Context, &messages)
|
||||
if h.App.SysConfig.ContextDeep > 0 {
|
||||
var historyMessages []model.ChatMessage
|
||||
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
|
||||
if res.Error == nil {
|
||||
dbSession := h.DB.Session(&gorm.Session{}).Where("chat_id", input.ChatId)
|
||||
if input.LastMsgId > 0 { // 重新生成逻辑
|
||||
dbSession = dbSession.Where("id < ?", input.LastMsgId)
|
||||
}
|
||||
err = dbSession.Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages).Error
|
||||
if err == nil {
|
||||
for i := len(historyMessages) - 1; i >= 0; i-- {
|
||||
msg := historyMessages[i]
|
||||
ms := types.Message{Role: "user", Content: msg.Content}
|
||||
@@ -261,7 +257,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
v := messages[i]
|
||||
tks, _ = utils.CalcTokens(utils.JsonEncode(v), req.Model)
|
||||
// 上下文 token 超出了模型的最大上下文长度
|
||||
if tokens+tks >= session.Model.MaxContext {
|
||||
if tokens+tks >= input.ChatModel.MaxContext {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -282,71 +278,101 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
||||
reqMgs = append(reqMgs, chatCtx[i])
|
||||
}
|
||||
|
||||
fullPrompt := prompt
|
||||
text := prompt
|
||||
|
||||
for _, file := range session.Files {
|
||||
// extract files in prompt
|
||||
files := utils.ExtractFileURLs(prompt)
|
||||
logger.Debugf("detected FILES: %+v", files)
|
||||
// 如果不是逆向模型,则提取文件内容
|
||||
if len(files) > 0 && !(session.Model.Value == "gpt-4-all" ||
|
||||
strings.HasPrefix(session.Model.Value, "gpt-4-gizmo") ||
|
||||
strings.HasPrefix(session.Model.Value, "claude-3")) {
|
||||
contents := make([]string, 0)
|
||||
var file model.File
|
||||
for _, v := range files {
|
||||
h.DB.Where("url = ?", v).First(&file)
|
||||
content, err := utils.ReadFileContent(v, h.App.Config.TikaHost)
|
||||
if err != nil {
|
||||
logger.Error("error with read file: ", err)
|
||||
} else {
|
||||
contents = append(contents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
||||
}
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
fullPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(contents, "\n"), text)
|
||||
}
|
||||
|
||||
tokens, _ := utils.CalcTokens(fullPrompt, req.Model)
|
||||
if tokens > session.Model.MaxContext {
|
||||
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
|
||||
}
|
||||
}
|
||||
logger.Debug("最终Prompt:", fullPrompt)
|
||||
|
||||
// extract images from prompt
|
||||
imgURLs := utils.ExtractImgURLs(prompt)
|
||||
logger.Debugf("detected IMG: %+v", imgURLs)
|
||||
var content interface{}
|
||||
if len(imgURLs) > 0 {
|
||||
data := make([]interface{}, 0)
|
||||
for _, v := range imgURLs {
|
||||
text = strings.Replace(text, v, "", 1)
|
||||
data = append(data, gin.H{
|
||||
fileContents := make([]string, 0) // 文件内容
|
||||
var finalPrompt = input.Prompt
|
||||
imgList := make([]any, 0)
|
||||
for _, file := range input.Files {
|
||||
logger.Debugf("detected file: %+v", file.URL)
|
||||
// 处理图片
|
||||
if isImageURL(file.URL) {
|
||||
imgList = append(imgList, gin.H{
|
||||
"type": "image_url",
|
||||
"image_url": gin.H{
|
||||
"url": v,
|
||||
"url": file.URL,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
// 如果不是逆向模型,则提取文件内容
|
||||
modelValue := input.ChatModel.Value
|
||||
if !(strings.Contains(modelValue, "-all") || strings.HasPrefix(modelValue, "gpt-4-gizmo") || strings.HasPrefix(modelValue, "claude")) {
|
||||
content, err := utils.ReadFileContent(file.URL, h.App.Config.TikaHost)
|
||||
if err != nil {
|
||||
logger.Error("error with read file: ", err)
|
||||
continue
|
||||
} else {
|
||||
fileContents = append(fileContents, fmt.Sprintf("%s 文件内容:%s", file.Name, content))
|
||||
}
|
||||
}
|
||||
}
|
||||
data = append(data, gin.H{
|
||||
"type": "text",
|
||||
"text": strings.TrimSpace(text),
|
||||
})
|
||||
content = data
|
||||
} else {
|
||||
content = fullPrompt
|
||||
}
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
})
|
||||
|
||||
logger.Debugf("%+v", req.Messages)
|
||||
if len(fileContents) > 0 {
|
||||
finalPrompt = fmt.Sprintf("请根据提供的文件内容信息回答问题(其中Excel 已转成 HTML):\n\n %s\n\n 问题:%s", strings.Join(fileContents, "\n"), input.Prompt)
|
||||
tokens, _ := utils.CalcTokens(finalPrompt, req.Model)
|
||||
if tokens > input.ChatModel.MaxContext {
|
||||
return fmt.Errorf("文件的长度超出模型允许的最大上下文长度,请减少文件内容数量或文件大小。")
|
||||
}
|
||||
} else {
|
||||
finalPrompt = input.Prompt
|
||||
}
|
||||
|
||||
return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, c)
|
||||
if len(imgList) > 0 {
|
||||
imgList = append(imgList, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": input.Prompt,
|
||||
})
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": imgList,
|
||||
})
|
||||
} else {
|
||||
req.Messages = append(reqMgs, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": finalPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
logger.Debugf("请求消息: %+v", req.Messages)
|
||||
|
||||
return h.sendOpenAiMessage(req, userVo, ctx, input, c)
|
||||
}
|
||||
|
||||
// 判断一个 URL 是否图片链接
|
||||
func isImageURL(url string) bool {
|
||||
// 检查是否是有效的URL
|
||||
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查文件扩展名
|
||||
ext := strings.ToLower(path.Ext(url))
|
||||
validImageExts := map[string]bool{
|
||||
".jpg": true,
|
||||
".jpeg": true,
|
||||
".png": true,
|
||||
".gif": true,
|
||||
".bmp": true,
|
||||
".webp": true,
|
||||
".svg": true,
|
||||
".ico": true,
|
||||
}
|
||||
|
||||
if !validImageExts[ext] {
|
||||
return false
|
||||
}
|
||||
|
||||
// 发送HEAD请求检查Content-Type
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
resp, err := client.Head(url)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
return strings.HasPrefix(contentType, "image/")
|
||||
}
|
||||
|
||||
// Tokens 统计 token 数量
|
||||
@@ -415,10 +441,10 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
|
||||
// 发送请求到 OpenAI 服务器
|
||||
// useOwnApiKey: 是否使用了用户自己的 API KEY
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, input ChatInput, apiKey *model.ApiKey) (*http.Response, error) {
|
||||
// if the chat model bind a KEY, use it directly
|
||||
if session.Model.KeyId > 0 {
|
||||
h.DB.Where("id", session.Model.KeyId).Find(apiKey)
|
||||
if input.ChatModel.KeyId > 0 {
|
||||
h.DB.Where("id", input.ChatModel.KeyId).Find(apiKey)
|
||||
} else { // use the last unused key
|
||||
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
|
||||
}
|
||||
@@ -472,16 +498,16 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
|
||||
}
|
||||
|
||||
// 扣减用户算力
|
||||
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
|
||||
func (h *ChatHandler) subUserPower(userVo vo.User, input ChatInput, promptTokens int, replyTokens int) {
|
||||
power := 1
|
||||
if session.Model.Power > 0 {
|
||||
power = session.Model.Power
|
||||
if input.ChatModel.Power > 0 {
|
||||
power = input.ChatModel.Power
|
||||
}
|
||||
|
||||
err := h.userService.DecreasePower(userVo.Id, power, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: session.Model.Value,
|
||||
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens),
|
||||
Model: input.ChatModel.Value,
|
||||
Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", input.ChatModel.Name, promptTokens, replyTokens),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
@@ -492,8 +518,7 @@ func (h *ChatHandler) saveChatHistory(
|
||||
req types.ApiRequest,
|
||||
usage Usage,
|
||||
message types.Message,
|
||||
session *types.ChatSession,
|
||||
role model.ChatRole,
|
||||
input ChatInput,
|
||||
userVo vo.User,
|
||||
promptCreatedAt time.Time,
|
||||
replyCreatedAt time.Time) {
|
||||
@@ -502,7 +527,7 @@ func (h *ChatHandler) saveChatHistory(
|
||||
if h.App.SysConfig.EnableContext {
|
||||
chatCtx := req.Messages // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
h.ChatContexts.Put(input.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
@@ -515,12 +540,15 @@ func (h *ChatHandler) saveChatHistory(
|
||||
}
|
||||
|
||||
historyUserMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: template.HTMLEscapeString(usage.Prompt),
|
||||
UserId: userVo.Id,
|
||||
ChatId: input.ChatId,
|
||||
RoleId: input.RoleId,
|
||||
Type: types.PromptMsg,
|
||||
Icon: userVo.Avatar,
|
||||
Content: utils.JsonEncode(vo.MsgContent{
|
||||
Text: usage.Prompt,
|
||||
Files: input.Files,
|
||||
}),
|
||||
Tokens: promptTokens,
|
||||
TotalTokens: promptTokens,
|
||||
UseContext: true,
|
||||
@@ -543,12 +571,15 @@ func (h *ChatHandler) saveChatHistory(
|
||||
totalTokens = replyTokens + getTotalTokens(req)
|
||||
}
|
||||
historyReplyMsg := model.ChatMessage{
|
||||
UserId: userVo.Id,
|
||||
ChatId: session.ChatId,
|
||||
RoleId: role.Id,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: role.Icon,
|
||||
Content: usage.Content,
|
||||
UserId: userVo.Id,
|
||||
ChatId: input.ChatId,
|
||||
RoleId: input.RoleId,
|
||||
Type: types.ReplyMsg,
|
||||
Icon: input.ChatRole.Icon,
|
||||
Content: utils.JsonEncode(vo.MsgContent{
|
||||
Text: message.Content,
|
||||
Files: input.Files,
|
||||
}),
|
||||
Tokens: replyTokens,
|
||||
TotalTokens: totalTokens,
|
||||
UseContext: true,
|
||||
@@ -562,17 +593,17 @@ func (h *ChatHandler) saveChatHistory(
|
||||
}
|
||||
|
||||
// 更新用户算力
|
||||
if session.Model.Power > 0 {
|
||||
h.subUserPower(userVo, session, promptTokens, replyTokens)
|
||||
if input.ChatModel.Power > 0 {
|
||||
h.subUserPower(userVo, input, promptTokens, replyTokens)
|
||||
}
|
||||
// 保存当前会话
|
||||
var chatItem model.ChatItem
|
||||
err = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem).Error
|
||||
err = h.DB.Where("chat_id = ?", input.ChatId).First(&chatItem).Error
|
||||
if err != nil {
|
||||
chatItem.ChatId = session.ChatId
|
||||
chatItem.ChatId = input.ChatId
|
||||
chatItem.UserId = userVo.Id
|
||||
chatItem.RoleId = role.Id
|
||||
chatItem.ModelId = session.Model.Id
|
||||
chatItem.RoleId = input.RoleId
|
||||
chatItem.ModelId = input.ModelId
|
||||
if utf8.RuneCountInString(usage.Prompt) > 30 {
|
||||
chatItem.Title = string([]rune(usage.Prompt)[:30]) + "..."
|
||||
} else {
|
||||
@@ -586,7 +617,7 @@ func (h *ChatHandler) saveChatHistory(
|
||||
}
|
||||
}
|
||||
|
||||
// 文本生成语音
|
||||
// TextToSpeech 文本生成语音
|
||||
func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
||||
var data struct {
|
||||
ModelId int `json:"model_id"`
|
||||
@@ -600,13 +631,19 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
||||
textHash := utils.Sha256(fmt.Sprintf("%d/%s", data.ModelId, data.Text))
|
||||
audioFile := fmt.Sprintf("%s/audio", h.App.Config.StaticDir)
|
||||
if _, err := os.Stat(audioFile); err != nil {
|
||||
os.MkdirAll(audioFile, 0755)
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(audioFile, 0755); err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
audioFile = fmt.Sprintf("%s/%s.mp3", audioFile, textHash)
|
||||
if _, err := os.Stat(audioFile); err == nil {
|
||||
// 设置响应头
|
||||
c.Header("Content-Type", "audio/mpeg")
|
||||
c.Header("Content-Disposition", "attachment; filename=speech.mp3")
|
||||
c.Header("Prompt-Type", "audio/mpeg")
|
||||
c.Header("Prompt-Disposition", "attachment; filename=speech.mp3")
|
||||
c.File(audioFile)
|
||||
return
|
||||
}
|
||||
@@ -670,11 +707,14 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
c.Header("Content-Type", "audio/mpeg")
|
||||
c.Header("Content-Disposition", "attachment; filename=speech.mp3")
|
||||
c.Header("Prompt-Type", "audio/mpeg")
|
||||
c.Header("Prompt-Disposition", "attachment; filename=speech.mp3")
|
||||
|
||||
// 直接写入完整的音频数据到响应
|
||||
c.Writer.Write(audioBytes)
|
||||
_, err = c.Writer.Write(audioBytes)
|
||||
if err != nil {
|
||||
logger.Error("写入音频数据到响应失败:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// // OPenAI 消息发送实现
|
||||
@@ -707,7 +747,7 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
||||
// return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
|
||||
// }
|
||||
|
||||
// contentType := response.Header.Get("Content-Type")
|
||||
// contentType := response.Header.Get("Prompt-Type")
|
||||
// if strings.Contains(contentType, "text/event-stream") {
|
||||
// replyCreatedAt := time.Now() // 记录回复时间
|
||||
// // 循环读取 Chunk 消息
|
||||
@@ -733,7 +773,7 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
||||
// if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
|
||||
// continue
|
||||
// }
|
||||
// if responseBody.Choices[0].Delta.Content == nil &&
|
||||
// if responseBody.Choices[0].Delta.Prompt == nil &&
|
||||
// responseBody.Choices[0].Delta.ToolCalls == nil &&
|
||||
// responseBody.Choices[0].Delta.ReasoningContent == "" {
|
||||
// continue
|
||||
@@ -799,10 +839,10 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
||||
// "content": reasoningContent,
|
||||
// })
|
||||
// contents = append(contents, reasoningContent)
|
||||
// } else if responseBody.Choices[0].Delta.Content != "" {
|
||||
// finalContent := responseBody.Choices[0].Delta.Content
|
||||
// } else if responseBody.Choices[0].Delta.Prompt != "" {
|
||||
// finalContent := responseBody.Choices[0].Delta.Prompt
|
||||
// if reasoning {
|
||||
// finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Content)
|
||||
// finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Prompt)
|
||||
// reasoning = false
|
||||
// }
|
||||
// contents = append(contents, utils.InterfaceToString(finalContent))
|
||||
@@ -861,12 +901,12 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
||||
// if len(contents) > 0 {
|
||||
// usage := Usage{
|
||||
// Prompt: prompt,
|
||||
// Content: strings.Join(contents, ""),
|
||||
// Prompt: strings.Join(contents, ""),
|
||||
// PromptTokens: 0,
|
||||
// CompletionTokens: 0,
|
||||
// TotalTokens: 0,
|
||||
// }
|
||||
// message.Content = usage.Content
|
||||
// message.Prompt = usage.Prompt
|
||||
// h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
|
||||
// }
|
||||
// } else {
|
||||
@@ -879,16 +919,16 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("解析响应失败:%v", body)
|
||||
// }
|
||||
// content := respVo.Choices[0].Message.Content
|
||||
// content := respVo.Choices[0].Message.Prompt
|
||||
// if strings.HasPrefix(req.Model, "o1-") {
|
||||
// content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
|
||||
// content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Prompt)
|
||||
// }
|
||||
// pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
|
||||
// "type": "text",
|
||||
// "content": content,
|
||||
// })
|
||||
// respVo.Usage.Prompt = prompt
|
||||
// respVo.Usage.Content = content
|
||||
// respVo.Usage.Prompt = content
|
||||
// h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
|
||||
// }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user