SSE 替换 websocket

This commit is contained in:
GeekMaster
2025-05-26 18:26:36 +08:00
parent 76a3ada85f
commit 41e4b1c7ac
8 changed files with 808 additions and 723 deletions

View File

@@ -8,6 +8,7 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
@@ -32,10 +33,30 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
req2 "github.com/imroc/req/v3"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"gorm.io/gorm" "gorm.io/gorm"
) )
const (
ChatEventStart = "start"
ChatEventEnd = "end"
ChatEventError = "error"
ChatEventMessageDelta = "message_delta"
ChatEventTitle = "title"
)
type ChatInput struct {
UserId uint `json:"user_id"`
RoleId int `json:"role_id"`
ModelId int `json:"model_id"`
ChatId string `json:"chat_id"`
Content string `json:"content"`
Tools []int `json:"tools"`
Stream bool `json:"stream"`
Files []vo.File `json:"files"`
}
type ChatHandler struct { type ChatHandler struct {
BaseHandler BaseHandler
redis *redis.Client redis *redis.Client
@@ -58,7 +79,89 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag
} }
} }
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error { // Chat 处理聊天请求
func (h *ChatHandler) Chat(c *gin.Context) {
var data ChatInput
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
// 验证聊天角色
var chatRole model.ChatRole
err := h.DB.First(&chatRole, data.RoleId).Error
if err != nil || !chatRole.Enable {
pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!")
return
}
// 如果角色绑定了模型ID使用角色的模型ID
if chatRole.ModelId > 0 {
data.ModelId = int(chatRole.ModelId)
}
// 获取模型信息
var chatModel model.ChatModel
err = h.DB.Where("id", data.ModelId).First(&chatModel).Error
if err != nil || !chatModel.Enabled {
pushMessage(c, ChatEventError, "当前AI模型暂未启用请更换模型后再发起对话")
return
}
session := &types.ChatSession{
ClientIP: c.ClientIP(),
UserId: data.UserId,
ChatId: data.ChatId,
Tools: data.Tools,
Stream: data.Stream,
Model: types.ChatModel{
KeyId: data.ModelId,
},
}
// 使用旧的聊天数据覆盖模型和角色ID
var chat model.ChatItem
h.DB.Where("chat_id", data.ChatId).First(&chat)
if chat.Id > 0 {
chatModel.Id = chat.ModelId
data.RoleId = int(chat.RoleId)
}
// 复制模型数据
err = utils.CopyObject(chatModel, &session.Model)
if err != nil {
logger.Error(err, chatModel)
}
session.Model.Id = chatModel.Id
// 发送消息
err = h.sendMessage(ctx, session, chatRole, data.Content, c)
if err != nil {
pushMessage(c, ChatEventError, err.Error())
return
}
pushMessage(c, ChatEventEnd, "对话完成")
}
func pushMessage(c *gin.Context, msgType string, content interface{}) {
c.SSEvent("message", map[string]interface{}{
"type": msgType,
"body": content,
})
c.Writer.Flush()
}
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, c *gin.Context) error {
var user model.User var user model.User
res := h.DB.Model(&model.User{}).First(&user, session.UserId) res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil { if res.Error != nil {
@@ -254,7 +357,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
logger.Debugf("%+v", req.Messages) logger.Debugf("%+v", req.Messages)
return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, ws) return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, c)
} }
// Tokens 统计 token 数量 // Tokens 统计 token 数量
@@ -584,3 +687,221 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) {
// 直接写入完整的音频数据到响应 // 直接写入完整的音频数据到响应
c.Writer.Write(audioBytes) c.Writer.Write(audioBytes)
} }
// OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage(
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
c *gin.Context) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Since(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
if response.StatusCode != 200 {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{Role: "assistant"}
var contents = make([]string, 0)
var function model.Function
var toolCall = false
var arguments = make([]string, 0)
var reasoning = false
pushMessage(c, ChatEventStart, "开始响应")
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
return errors.New(line)
}
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].Delta.Content == nil &&
responseBody.Choices[0].Delta.ToolCalls == nil &&
responseBody.Choices[0].Delta.ReasoningContent == "" {
continue
}
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
pushMessage(c, ChatEventError, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
break
}
var tool types.ToolCall
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
tool = responseBody.Choices[0].Delta.ToolCalls[0]
if toolCall && tool.Function.Name == "" {
arguments = append(arguments, tool.Function.Arguments)
continue
}
}
// 兼容 Function Call
fun := responseBody.Choices[0].Delta.FunctionCall
if fun.Name != "" {
tool = *new(types.ToolCall)
tool.Function.Name = fun.Name
} else if toolCall {
arguments = append(arguments, fun.Arguments)
continue
}
if !utils.IsEmptyValue(tool) {
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil {
toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
"type": "text",
"content": callMsg,
})
contents = append(contents, callMsg)
}
continue
}
if responseBody.Choices[0].FinishReason == "tool_calls" ||
responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
break
}
// output stopped
if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else { // 正常输出结果
// 兼容思考过程
if responseBody.Choices[0].Delta.ReasoningContent != "" {
reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
if !reasoning {
reasoningContent = fmt.Sprintf("<think>%s", reasoningContent)
reasoning = true
}
pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
"type": "text",
"content": reasoningContent,
})
contents = append(contents, reasoningContent)
} else if responseBody.Choices[0].Delta.Content != "" {
finalContent := responseBody.Choices[0].Delta.Content
if reasoning {
finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Content)
reasoning = false
}
contents = append(contents, utils.InterfaceToString(finalContent))
pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
"type": "text",
"content": finalContent,
})
}
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
if toolCall { // 调用函数完成任务
params := make(map[string]any)
_ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
params["user_id"] = userVo.Id
var apiRes types.BizVo
r, err := req2.C().R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", function.Token).
SetBody(params).Post(function.Action)
errMsg := ""
if err != nil {
errMsg = err.Error()
} else {
all, _ := io.ReadAll(r.Body)
err = json.Unmarshal(all, &apiRes)
if err != nil {
errMsg = err.Error()
} else if apiRes.Code != types.Success {
errMsg = apiRes.Message
}
}
if errMsg != "" {
errMsg = "调用函数工具出错:" + errMsg
contents = append(contents, errMsg)
} else {
errMsg = utils.InterfaceToString(apiRes.Data)
contents = append(contents, errMsg)
}
pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
"type": "text",
"content": errMsg,
})
}
// 消息发送成功
if len(contents) > 0 {
usage := Usage{
Prompt: prompt,
Content: strings.Join(contents, ""),
PromptTokens: 0,
CompletionTokens: 0,
TotalTokens: 0,
}
message.Content = usage.Content
h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
var respVo OpenAIResVo
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("读取响应失败:%v", body)
}
err = json.Unmarshal(body, &respVo)
if err != nil {
return fmt.Errorf("解析响应失败:%v", body)
}
content := respVo.Choices[0].Message.Content
if strings.HasPrefix(req.Model, "o1-") {
content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
}
pushMessage(c, ChatEventMessageDelta, map[string]interface{}{
"type": "text",
"content": content,
})
respVo.Usage.Prompt = prompt
respVo.Usage.Content = content
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
}
return nil
}

View File

@@ -1,253 +1,271 @@
package handler package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved. // // * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license // // * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file. // // * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com // // * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( // import (
"bufio" // "bufio"
"context" // "context"
"encoding/json" // "encoding/json"
"errors" // "errors"
"fmt" // "fmt"
"geekai/core/types" // "geekai/core/types"
"geekai/store/model" // "geekai/store/model"
"geekai/store/vo" // "geekai/store/vo"
"geekai/utils" // "geekai/utils"
"io" // "io"
"strings" // "strings"
"time" // "time"
req2 "github.com/imroc/req/v3" // req2 "github.com/imroc/req/v3"
) // )
type Usage struct { // type Usage struct {
Prompt string `json:"prompt,omitempty"` // Prompt string `json:"prompt,omitempty"`
Content string `json:"content,omitempty"` // Content string `json:"content,omitempty"`
PromptTokens int `json:"prompt_tokens"` // PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"` // CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"` // TotalTokens int `json:"total_tokens"`
} // }
type OpenAIResVo struct { // type OpenAIResVo struct {
Id string `json:"id"` // Id string `json:"id"`
Object string `json:"object"` // Object string `json:"object"`
Created int `json:"created"` // Created int `json:"created"`
Model string `json:"model"` // Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"` // SystemFingerprint string `json:"system_fingerprint"`
Choices []struct { // Choices []struct {
Index int `json:"index"` // Index int `json:"index"`
Message struct { // Message struct {
Role string `json:"role"` // Role string `json:"role"`
Content string `json:"content"` // Content string `json:"content"`
} `json:"message"` // } `json:"message"`
Logprobs interface{} `json:"logprobs"` // Logprobs interface{} `json:"logprobs"`
FinishReason string `json:"finish_reason"` // FinishReason string `json:"finish_reason"`
} `json:"choices"` // } `json:"choices"`
Usage Usage `json:"usage"` // Usage Usage `json:"usage"`
} // }
// OPenAI 消息发送实现 // // OPenAI 消息发送实现
func (h *ChatHandler) sendOpenAiMessage( // func (h *ChatHandler) sendOpenAiMessage(
req types.ApiRequest, // req types.ApiRequest,
userVo vo.User, // userVo vo.User,
ctx context.Context, // ctx context.Context,
session *types.ChatSession, // session *types.ChatSession,
role model.ChatRole, // role model.ChatRole,
prompt string, // prompt string,
ws *types.WsClient) error { // messageChan chan interface{}) error {
promptCreatedAt := time.Now() // 记录提问时间 // promptCreatedAt := time.Now() // 记录提问时间
start := time.Now() // start := time.Now()
var apiKey = model.ApiKey{} // var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey) // response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Since(start)) // logger.Info("HTTP请求完成耗时", time.Since(start))
if err != nil { // if err != nil {
if strings.Contains(err.Error(), "context canceled") { // if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt) // return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") { // } else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员") // return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
} // }
return err // return err
} else { // } else {
defer response.Body.Close() // defer response.Body.Close()
} // }
if response.StatusCode != 200 { // if response.StatusCode != 200 {
body, _ := io.ReadAll(response.Body) // body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body)) // return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
} // }
contentType := response.Header.Get("Content-Type") // contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") { // if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间 // replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息 // // 循环读取 Chunk 消息
var message = types.Message{Role: "assistant"} // var message = types.Message{Role: "assistant"}
var contents = make([]string, 0) // var contents = make([]string, 0)
var function model.Function // var function model.Function
var toolCall = false // var toolCall = false
var arguments = make([]string, 0) // var arguments = make([]string, 0)
var reasoning = false // var reasoning = false
scanner := bufio.NewScanner(response.Body) // scanner := bufio.NewScanner(response.Body)
for scanner.Scan() { // for scanner.Scan() {
line := scanner.Text() // line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 { // if !strings.Contains(line, "data:") || len(line) < 30 {
continue // continue
} // }
var responseBody = types.ApiResponse{} // var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody) // err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错 // if err != nil { // 数据解析出错
return errors.New(line) // return errors.New(line)
} // }
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行 // if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue // continue
} // }
if responseBody.Choices[0].Delta.Content == nil && // if responseBody.Choices[0].Delta.Content == nil &&
responseBody.Choices[0].Delta.ToolCalls == nil && // responseBody.Choices[0].Delta.ToolCalls == nil &&
responseBody.Choices[0].Delta.ReasoningContent == "" { // responseBody.Choices[0].Delta.ReasoningContent == "" {
continue // continue
} // }
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 { // if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
utils.SendChunkMsg(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。") // messageChan <- map[string]interface{}{
break // "type": "text",
} // "body": "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。",
// }
// break
// }
var tool types.ToolCall // var tool types.ToolCall
if len(responseBody.Choices[0].Delta.ToolCalls) > 0 { // if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
tool = responseBody.Choices[0].Delta.ToolCalls[0] // tool = responseBody.Choices[0].Delta.ToolCalls[0]
if toolCall && tool.Function.Name == "" { // if toolCall && tool.Function.Name == "" {
arguments = append(arguments, tool.Function.Arguments) // arguments = append(arguments, tool.Function.Arguments)
continue // continue
} // }
} // }
// 兼容 Function Call // // 兼容 Function Call
fun := responseBody.Choices[0].Delta.FunctionCall // fun := responseBody.Choices[0].Delta.FunctionCall
if fun.Name != "" { // if fun.Name != "" {
tool = *new(types.ToolCall) // tool = *new(types.ToolCall)
tool.Function.Name = fun.Name // tool.Function.Name = fun.Name
} else if toolCall { // } else if toolCall {
arguments = append(arguments, fun.Arguments) // arguments = append(arguments, fun.Arguments)
continue // continue
} // }
if !utils.IsEmptyValue(tool) { // if !utils.IsEmptyValue(tool) {
res := h.DB.Where("name = ?", tool.Function.Name).First(&function) // res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil { // if res.Error == nil {
toolCall = true // toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) // callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.SendChunkMsg(ws, callMsg) // messageChan <- map[string]interface{}{
contents = append(contents, callMsg) // "type": "text",
} // "body": callMsg,
continue // }
} // contents = append(contents, callMsg)
// }
// continue
// }
if responseBody.Choices[0].FinishReason == "tool_calls" || // if responseBody.Choices[0].FinishReason == "tool_calls" ||
responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕 // responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
break // break
} // }
// output stopped // // output stopped
if responseBody.Choices[0].FinishReason != "" { // if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了 // break // 输出完成或者输出中断了
} else { // 正常输出结果 // } else { // 正常输出结果
// 兼容思考过程 // // 兼容思考过程
if responseBody.Choices[0].Delta.ReasoningContent != "" { // if responseBody.Choices[0].Delta.ReasoningContent != "" {
reasoningContent := responseBody.Choices[0].Delta.ReasoningContent // reasoningContent := responseBody.Choices[0].Delta.ReasoningContent
if !reasoning { // if !reasoning {
reasoningContent = fmt.Sprintf("<think>%s", reasoningContent) // reasoningContent = fmt.Sprintf("<think>%s", reasoningContent)
reasoning = true // reasoning = true
} // }
utils.SendChunkMsg(ws, reasoningContent) // messageChan <- map[string]interface{}{
contents = append(contents, reasoningContent) // "type": "text",
} else if responseBody.Choices[0].Delta.Content != "" { // "body": reasoningContent,
finalContent := responseBody.Choices[0].Delta.Content // }
if reasoning { // contents = append(contents, reasoningContent)
finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Content) // } else if responseBody.Choices[0].Delta.Content != "" {
reasoning = false // finalContent := responseBody.Choices[0].Delta.Content
} // if reasoning {
contents = append(contents, utils.InterfaceToString(finalContent)) // finalContent = fmt.Sprintf("</think>%s", responseBody.Choices[0].Delta.Content)
utils.SendChunkMsg(ws, finalContent) // reasoning = false
} // }
} // contents = append(contents, utils.InterfaceToString(finalContent))
} // end for // messageChan <- map[string]interface{}{
// "type": "text",
// "body": finalContent,
// }
// }
// }
// } // end for
if err := scanner.Err(); err != nil { // if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") { // if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt) // logger.Info("用户取消了请求:", prompt)
} else { // } else {
logger.Error("信息读取出错:", err) // logger.Error("信息读取出错:", err)
} // }
} // }
if toolCall { // 调用函数完成任务 // if toolCall { // 调用函数完成任务
params := make(map[string]any) // params := make(map[string]any)
_ = utils.JsonDecode(strings.Join(arguments, ""), &params) // _ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params) // logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
params["user_id"] = userVo.Id // params["user_id"] = userVo.Id
var apiRes types.BizVo // var apiRes types.BizVo
r, err := req2.C().R().SetHeader("Body-Type", "application/json"). // r, err := req2.C().R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", function.Token). // SetHeader("Authorization", function.Token).
SetBody(params).Post(function.Action) // SetBody(params).Post(function.Action)
errMsg := "" // errMsg := ""
if err != nil { // if err != nil {
errMsg = err.Error() // errMsg = err.Error()
} else { // } else {
all, _ := io.ReadAll(r.Body) // all, _ := io.ReadAll(r.Body)
err = json.Unmarshal(all, &apiRes) // err = json.Unmarshal(all, &apiRes)
if err != nil { // if err != nil {
errMsg = err.Error() // errMsg = err.Error()
} else if apiRes.Code != types.Success { // } else if apiRes.Code != types.Success {
errMsg = apiRes.Message // errMsg = apiRes.Message
} // }
} // }
if errMsg != "" { // if errMsg != "" {
errMsg = "调用函数工具出错:" + errMsg // errMsg = "调用函数工具出错:" + errMsg
contents = append(contents, errMsg) // contents = append(contents, errMsg)
} else { // } else {
errMsg = utils.InterfaceToString(apiRes.Data) // errMsg = utils.InterfaceToString(apiRes.Data)
contents = append(contents, errMsg) // contents = append(contents, errMsg)
} // }
utils.SendChunkMsg(ws, errMsg) // messageChan <- map[string]interface{}{
} // "type": "text",
// "body": errMsg,
// }
// }
// 消息发送成功 // // 消息发送成功
if len(contents) > 0 { // if len(contents) > 0 {
usage := Usage{ // usage := Usage{
Prompt: prompt, // Prompt: prompt,
Content: strings.Join(contents, ""), // Content: strings.Join(contents, ""),
PromptTokens: 0, // PromptTokens: 0,
CompletionTokens: 0, // CompletionTokens: 0,
TotalTokens: 0, // TotalTokens: 0,
} // }
message.Content = usage.Content // message.Content = usage.Content
h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt) // h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt)
} // }
} else { // 非流式输出 // } else { // 非流式输出
var respVo OpenAIResVo // var respVo OpenAIResVo
body, err := io.ReadAll(response.Body) // body, err := io.ReadAll(response.Body)
if err != nil { // if err != nil {
return fmt.Errorf("读取响应失败:%v", body) // return fmt.Errorf("读取响应失败:%v", body)
} // }
err = json.Unmarshal(body, &respVo) // err = json.Unmarshal(body, &respVo)
if err != nil { // if err != nil {
return fmt.Errorf("解析响应失败:%v", body) // return fmt.Errorf("解析响应失败:%v", body)
} // }
content := respVo.Choices[0].Message.Content // content := respVo.Choices[0].Message.Content
if strings.HasPrefix(req.Model, "o1-") { // if strings.HasPrefix(req.Model, "o1-") {
content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content) // content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
} // }
utils.SendChunkMsg(ws, content) // messageChan <- map[string]interface{}{
respVo.Usage.Prompt = prompt // "type": "text",
respVo.Usage.Content = content // "body": content,
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now()) // }
} // respVo.Usage.Prompt = prompt
// respVo.Usage.Content = content
// h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now())
// }
return nil // return nil
} // }

View File

@@ -1,153 +0,0 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
)
// Websocket 连接处理 handler
type WebsocketHandler struct {
BaseHandler
wsService *service.WebsocketService
chatHandler *ChatHandler
}
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler {
return &WebsocketHandler{
BaseHandler: BaseHandler{App: app, DB: db},
chatHandler: chatHandler,
wsService: s,
}
}
func (h *WebsocketHandler) Client(c *gin.Context) {
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
ws, err := (&websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: strings.Split(clientProtocols, ","),
}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
clientId := c.Query("client_id")
client := types.NewWsClient(ws, clientId)
userId := h.GetLoginUserId(c)
if userId == 0 {
_ = client.Send([]byte("Invalid user_id"))
c.Abort()
return
}
var user model.User
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
_ = client.Send([]byte("Invalid user_id"))
c.Abort()
return
}
h.wsService.Clients.Put(clientId, client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
client.Close()
h.wsService.Clients.Delete(clientId)
break
}
var message types.InputMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
logger.Debugf("Receive a message:%+v", message)
if message.Type == types.MsgTypePing {
utils.SendChannelMsg(client, types.ChPing, "pong")
continue
}
// 当前只处理聊天消息,其他消息全部丢弃
var chatMessage types.ChatMessage
err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage)
if err != nil || message.Channel != types.ChChat {
logger.Warnf("invalid message body:%+v", message.Body)
continue
}
var chatRole model.ChatRole
err = h.DB.First(&chatRole, chatMessage.RoleId).Error
if err != nil || !chatRole.Enable {
utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!")
continue
}
// if the role bind a model_id, use role's bind model_id
if chatRole.ModelId > 0 {
chatMessage.ModelId = int(chatRole.ModelId)
}
// get model info
var chatModel model.ChatModel
err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error
if err != nil || !chatModel.Enabled {
utils.SendAndFlush(client, "当前AI模型暂未启用请更换模型后再发起对话")
continue
}
session := &types.ChatSession{
ClientIP: c.ClientIP(),
UserId: userId,
}
// use old chat data override the chat model and role ID
var chat model.ChatItem
h.DB.Where("chat_id", chatMessage.ChatId).First(&chat)
if chat.Id > 0 {
chatModel.Id = chat.ModelId
chatMessage.RoleId = int(chat.RoleId)
}
session.ChatId = chatMessage.ChatId
session.Tools = chatMessage.Tools
session.Stream = chatMessage.Stream
session.Model.KeyId = chatMessage.ModelId
// 复制模型数据
err = utils.CopyObject(chatModel, &session.Model)
if err != nil {
logger.Error(err, chatModel)
}
session.Model.Id = chatModel.Id
ctx, cancel := context.WithCancel(context.Background())
h.chatHandler.ReqCancelFunc.Put(clientId, cancel)
err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client)
if err != nil {
logger.Error(err)
utils.SendAndFlush(client, err.Error())
} else {
utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
logger.Infof("回答完毕: %v", message.Body)
}
}
}()
}

View File

@@ -243,6 +243,7 @@ func main() {
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
group := s.Engine.Group("/api/chat/") group := s.Engine.Group("/api/chat/")
group.Any("message", h.Chat)
group.GET("list", h.List) group.GET("list", h.List)
group.GET("detail", h.Detail) group.GET("detail", h.Detail)
group.POST("update", h.Update) group.POST("update", h.Update)
@@ -515,10 +516,6 @@ func main() {
group.Any("sse", h.PostTest, h.SseTest) group.Any("sse", h.PostTest, h.SseTest)
}), }),
fx.Provide(service.NewWebsocketService), fx.Provide(service.NewWebsocketService),
fx.Provide(handler.NewWebsocketHandler),
fx.Invoke(func(s *core.AppServer, h *handler.WebsocketHandler) {
s.Engine.Any("/api/ws", h.Client)
}),
fx.Provide(handler.NewPromptHandler), fx.Provide(handler.NewPromptHandler),
fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) { fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) {
group := s.Engine.Group("/api/prompt") group := s.Engine.Group("/api/prompt")

View File

@@ -45,8 +45,8 @@
"vue": "^3.2.13", "vue": "^3.2.13",
"vue-router": "^4.0.15", "vue-router": "^4.0.15",
"unplugin-auto-import": "^0.18.5", "unplugin-auto-import": "^0.18.5",
"@microsoft/fetch-event-source": "^2.0.1",
"vue-waterfall-plugin-next": "^2.6.5" "vue-waterfall-plugin-next": "^2.6.5"
}, },
"devDependencies": { "devDependencies": {
"@vitejs/plugin-vue": "^5.2.4", "@vitejs/plugin-vue": "^5.2.4",

View File

@@ -5,13 +5,12 @@
</template> </template>
<script setup> <script setup>
import { checkSession, getClientId, getSystemInfo } from '@/store/cache' import { checkSession, getSystemInfo } from '@/store/cache'
import { getUserToken } from '@/store/session'
import { useSharedStore } from '@/store/sharedata' import { useSharedStore } from '@/store/sharedata'
import { showMessageInfo } from '@/utils/dialog' import { showMessageInfo } from '@/utils/dialog'
import { isChrome, isMobile } from '@/utils/libs' import { isChrome, isMobile } from '@/utils/libs'
import { ElConfigProvider } from 'element-plus' import { ElConfigProvider } from 'element-plus'
import { onMounted, ref, watch } from 'vue' import { onMounted } from 'vue'
const debounce = (fn, delay) => { const debounce = (fn, delay) => {
let timer let timer
@@ -56,46 +55,8 @@ onMounted(() => {
document.documentElement.setAttribute('data-theme', store.theme) document.documentElement.setAttribute('data-theme', store.theme)
}) })
watch(
() => store.isLogin,
(val) => {
if (val) {
connect()
}
}
)
const handler = ref(0)
// 初始化 websocket 连接
const connect = () => {
let host = import.meta.env.VITE_WS_HOST
if (host === '') {
if (location.protocol === 'https:') {
host = 'wss://' + location.host
} else {
host = 'ws://' + location.host
}
}
const clientId = getClientId()
const _socket = new WebSocket(host + `/api/ws?client_id=${clientId}`, ['token', getUserToken()])
_socket.addEventListener('open', () => {
console.log('WebSocket 已连接')
handler.value = setInterval(() => {
if (_socket.readyState === WebSocket.OPEN) {
_socket.send(JSON.stringify({ type: 'ping' }))
}
}, 5000)
})
_socket.addEventListener('close', () => {
clearInterval(handler.value)
connect()
})
store.setSocket(_socket)
}
// 打印 banner // 打印 banner
const banner = ` const banner = `
.oooooo. oooo .o. ooooo .oooooo. oooo .o. ooooo
d8P' 'Y8b 888 .888. 888 d8P' 'Y8b 888 .888. 888
888 .ooooo. .ooooo. 888 oooo .8"888. 888 888 .ooooo. .ooooo. 888 oooo .8"888. 888
@@ -103,7 +64,6 @@ const banner = `
888 ooooo 888ooo888 888ooo888 888888. .88ooo8888. 888 888 ooooo 888ooo888 888ooo888 888888. .88ooo8888. 888
'88. .88' 888 .o 888 .o 888 88b. .8' 888. 888 '88. .88' 888 .o 888 .o 888 88b. .8' 888. 888
Y8bood8P' Y8bod8P' Y8bod8P' o888o o888o o88o o8888o o888o Y8bood8P' Y8bod8P' Y8bod8P' o888o o888o o88o o8888o o888o
` `
console.log('%c' + banner + '', 'color: purple;font-size: 18px;') console.log('%c' + banner + '', 'color: purple;font-size: 18px;')

View File

@@ -1,11 +1,11 @@
import {defineStore} from "pinia"; import errorIcon from '@/assets/img/failed.png'
import Storage from "good-storage"; import loadingIcon from '@/assets/img/loading.gif'
import errorIcon from "@/assets/img/failed.png"; import Storage from 'good-storage'
import loadingIcon from "@/assets/img/loading.gif"; import { defineStore } from 'pinia'
let waterfallOptions = { let waterfallOptions = {
// 唯一key值 // 唯一key值
rowKey: "id", rowKey: 'id',
// 卡片之间的间隙 // 卡片之间的间隙
gutter: 10, gutter: 10,
// 是否有周围的gutter // 是否有周围的gutter
@@ -44,16 +44,16 @@ let waterfallOptions = {
}, },
}, },
// 动画效果 // 动画效果
animationEffect: "animate__fadeInUp", animationEffect: 'animate__fadeInUp',
// 动画时间 // 动画时间
animationDuration: 1000, animationDuration: 1000,
// 动画延迟 // 动画延迟
animationDelay: 300, animationDelay: 300,
animationCancel: false, animationCancel: false,
// 背景色 // 背景色
backgroundColor: "", backgroundColor: '',
// imgSelector // imgSelector
imgSelector: "img_thumb", imgSelector: 'img_thumb',
// 是否跨域 // 是否跨域
crossOrigin: true, crossOrigin: true,
// 加载配置 // 加载配置
@@ -61,102 +61,62 @@ let waterfallOptions = {
loading: loadingIcon, loading: loadingIcon,
error: errorIcon, error: errorIcon,
ratioCalculator: (width, height) => { ratioCalculator: (width, height) => {
const minRatio = 3 / 4; const minRatio = 3 / 4
const maxRatio = 4 / 3; const maxRatio = 4 / 3
const curRatio = height / width; const curRatio = height / width
if (curRatio < minRatio) { if (curRatio < minRatio) {
return minRatio; return minRatio
} else if (curRatio > maxRatio) { } else if (curRatio > maxRatio) {
return maxRatio; return maxRatio
} else { } else {
return curRatio; return curRatio
} }
}, },
}, },
// 是否懒加载 // 是否懒加载
lazyload: true, lazyload: true,
align: "center", align: 'center',
} }
export const useSharedStore = defineStore("shared", { export const useSharedStore = defineStore('shared', {
state: () => ({ state: () => ({
showLoginDialog: false, showLoginDialog: false,
chatListStyle: Storage.get("chat_list_style", "chat"), chatListStyle: Storage.get('chat_list_style', 'chat'),
chatStream: Storage.get("chat_stream", true), chatStream: Storage.get('chat_stream', true),
socket: { conn: null, handlers: {} }, theme: Storage.get('theme', 'light'),
theme: Storage.get("theme", "light"),
isLogin: false, isLogin: false,
chatListExtend: Storage.get("chat_list_extend", true), chatListExtend: Storage.get('chat_list_extend', true),
ttsModel: Storage.get("tts_model", ""), ttsModel: Storage.get('tts_model', ''),
waterfallOptions, waterfallOptions,
}), }),
getters: {}, getters: {},
actions: { actions: {
setShowLoginDialog(value) { setShowLoginDialog(value) {
this.showLoginDialog = value; this.showLoginDialog = value
}, },
setChatListStyle(value) { setChatListStyle(value) {
this.chatListStyle = value; this.chatListStyle = value
Storage.set("chat_list_style", value); Storage.set('chat_list_style', value)
}, },
setChatStream(value) { setChatStream(value) {
this.chatStream = value; this.chatStream = value
Storage.set("chat_stream", value); Storage.set('chat_stream', value)
},
setSocket(value) {
for (const key in this.socket.handlers) {
this.setMessageHandler(value, this.socket.handlers[key]);
}
this.socket.conn = value;
}, },
setChatListExtend(value) { setChatListExtend(value) {
this.chatListExtend = value; this.chatListExtend = value
Storage.set("chat_list_extend", value); Storage.set('chat_list_extend', value)
},
addMessageHandler(key, callback) {
if (!this.socket.handlers[key]) {
this.socket.handlers[key] = callback;
}
this.setMessageHandler(this.socket.conn, callback);
},
setMessageHandler(conn, callback) {
if (!conn) {
return;
}
conn.addEventListener("message", (event) => {
try {
if (event.data instanceof Blob) {
const reader = new FileReader();
reader.readAsText(event.data, "UTF-8");
reader.onload = () => {
callback(JSON.parse(String(reader.result)));
};
}
} catch (e) {
console.warn(e);
}
});
},
removeMessageHandler(key) {
if (this.socket.conn && this.socket.conn.readyState === WebSocket.OPEN) {
this.socket.conn.removeEventListener("message", this.socket.handlers[key]);
}
delete this.socket.handlers[key];
}, },
setTheme(theme) { setTheme(theme) {
this.theme = theme; this.theme = theme
document.documentElement.setAttribute("data-theme", theme); // 设置 HTML 的 data-theme 属性 document.documentElement.setAttribute('data-theme', theme) // 设置 HTML 的 data-theme 属性
Storage.set("theme", theme); Storage.set('theme', theme)
}, },
setIsLogin(value) { setIsLogin(value) {
this.isLogin = value; this.isLogin = value
}, },
setTtsModel(value) { setTtsModel(value) {
this.ttsModel = value; this.ttsModel = value
Storage.set("tts_model", value); Storage.set('tts_model', value)
}, },
}, },
}); })

View File

@@ -147,7 +147,7 @@
type="info" type="info"
style="margin-left: 8px; flex-shrink: 0" style="margin-left: 8px; flex-shrink: 0"
> >
{{ getSelectedModel()?.power }}算力 {{ getSelectedModel() && getSelectedModel().power }}算力
</el-tag> </el-tag>
</div> </div>
</el-button> </el-button>
@@ -320,7 +320,10 @@
<span class="tool-item-btn"> <span class="tool-item-btn">
<el-tooltip class="box-item" effect="dark" content="上传附件"> <el-tooltip class="box-item" effect="dark" content="上传附件">
<file-select :user-id="loginUser?.id" @selected="insertFile" /> <file-select
:user-id="loginUser && loginUser.id"
@selected="insertFile"
/>
</el-tooltip> </el-tooltip>
</span> </span>
</div> </div>
@@ -419,13 +422,15 @@ import {
Share, Share,
VideoPause, VideoPause,
} from '@element-plus/icons-vue' } from '@element-plus/icons-vue'
import { fetchEventSource } from '@microsoft/fetch-event-source'
import Clipboard from 'clipboard' import Clipboard from 'clipboard'
import { ElMessage, ElMessageBox } from 'element-plus' import { ElMessage, ElMessageBox } from 'element-plus'
import 'highlight.js/styles/a11y-dark.css' import 'highlight.js/styles/a11y-dark.css'
import MarkdownIt from 'markdown-it' import MarkdownIt from 'markdown-it'
import emoji from 'markdown-it-emoji' import emoji from 'markdown-it-emoji'
import { computed, nextTick, onMounted, onUnmounted, ref, watch } from 'vue' import { computed, nextTick, onMounted, ref, watch } from 'vue'
import { useRouter } from 'vue-router' import { useRouter } from 'vue-router'
import { getUserToken } from '../store/session'
const title = ref('GeekAI-智能助手') const title = ref('GeekAI-智能助手')
const logo = ref('') const logo = ref('')
@@ -699,153 +704,209 @@ onMounted(() => {
}) })
window.onresize = () => resizeElement() window.onresize = () => resizeElement()
store.addMessageHandler('chat', (data) => { })
// 丢去非本频道和本客户端的消息
if (data.channel !== 'chat' || data.clientId !== getClientId()) { // 初始化数据
return const initData = async () => {
try {
// 获取用户信息
const user = await checkSession()
loginUser.value = user
isLogin.value = true
// 获取角色列表
const roleRes = await httpGet('/api/app/list')
roles.value = roleRes.data
if (roles.value.length > 0) {
roleId.value = roles.value[0].id
} }
if (data.type === 'error') { // 获取模型列表
ElMessage.error(data.body) const modelRes = await httpGet('/api/model/list')
return models.value = modelRes.data
if (models.value.length > 0) {
modelID.value = models.value[0].id
} }
const chatRole = getRoleById(roleId.value) // 获取聊天列表
if (isNewMsg.value && data.type !== 'end') { const chatRes = await httpGet('/api/chat/list')
const prePrompt = chatData.value[chatData.value.length - 1]?.content allChats.value = chatRes.data
isNewMsg.value = false chatList.value = allChats.value
lineBuffer.value = data.body if (chatId.value) {
const reply = chatData.value[chatData.value.length - 1] loadChatHistory(chatId.value)
if (reply) { }
reply['content'] = lineBuffer.value } catch (error) {
} if (error.response?.status === 401) {
} else if (data.type === 'end') { isLogin.value = false
// 消息接收完毕 } else {
// 追加当前会话到会话列表 showMessageError('初始化数据失败:' + error.message)
if (newChatItem.value !== null) { }
newChatItem.value['title'] = tmpChatTitle.value }
newChatItem.value['chat_id'] = chatId.value }
chatList.value.unshift(newChatItem.value)
newChatItem.value = null // 只追加一次
}
enableInput() // 发送消息
lineBuffer.value = '' // 清空缓冲 const sendMessage = async function () {
if (!isLogin.value) {
console.log('未登录')
store.setShowLoginDialog(true)
return
}
// 获取 token if (canSend.value === false) {
const reply = chatData.value[chatData.value.length - 1] ElMessage.warning('AI 正在作答中,请稍后...')
httpPost('/api/chat/tokens', { return
text: '', }
model: getModelValue(modelID.value),
if (prompt.value.trim().length === 0 || canSend.value === false) {
showMessageError('请输入要发送的消息!')
return false
}
// 如果携带了文件,则串上文件地址
let content = prompt.value
if (files.value.length > 0) {
content += files.value.map((file) => file.url).join(' ')
}
// 追加消息
chatData.value.push({
type: 'prompt',
id: randString(32),
icon: loginUser.value.avatar,
content: content,
model: getModelValue(modelID.value),
created_at: new Date().getTime() / 1000,
})
// 添加空回复消息
const _role = getRoleById(roleId.value)
chatData.value.push({
chat_id: chatId,
role_id: roleId.value,
type: 'reply',
id: randString(32),
icon: _role['icon'],
content: '',
})
nextTick(() => {
document
.getElementById('chat-box')
.scrollTo(0, document.getElementById('chat-box').scrollHeight)
})
showHello.value = false
disableInput(false)
try {
await fetchEventSource('/api/chat/message', {
method: 'POST',
headers: {
Authorization: getUserToken(),
},
body: JSON.stringify({
user_id: loginUser.value.id,
role_id: roleId.value,
model_id: modelID.value,
chat_id: chatId.value, chat_id: chatId.value,
}) content: content,
.then((res) => { tools: toolSelected.value,
reply['created_at'] = new Date().getTime() stream: stream.value,
reply['tokens'] = res.data }),
openWhenHidden: true,
onopen(response) {
if (response.ok && response.status === 200) {
console.log('SSE connection opened')
} else {
throw new Error(`Failed to open SSE connection: ${response.status}`)
}
},
onmessage(msg) {
try {
const data = JSON.parse(msg.data)
if (data.type === 'error') {
ElMessage.error(data.body)
enableInput()
return
}
if (data.type === 'end') {
enableInput()
lineBuffer.value = '' // 清空缓冲
// 获取 token
const reply = chatData.value[chatData.value.length - 1]
httpPost('/api/chat/tokens', {
text: '',
model: getModelValue(modelID.value),
chat_id: chatId.value,
})
.then((res) => {
reply['created_at'] = new Date().getTime()
reply['tokens'] = res.data
// 将聊天框的滚动条滑动到最底部
nextTick(() => {
document
.getElementById('chat-box')
.scrollTo(0, document.getElementById('chat-box').scrollHeight)
})
})
.catch(() => {})
isNewMsg.value = true
return
}
if (data.type === 'text') {
if (isNewMsg.value) {
isNewMsg.value = false
lineBuffer.value = data.body
const reply = chatData.value[chatData.value.length - 1]
if (reply) {
reply['content'] = lineBuffer.value
}
} else {
lineBuffer.value += data.body
const reply = chatData.value[chatData.value.length - 1]
if (reply) {
reply['content'] = lineBuffer.value
}
}
}
// 将聊天框的滚动条滑动到最底部 // 将聊天框的滚动条滑动到最底部
nextTick(() => { nextTick(() => {
document document
.getElementById('chat-box') .getElementById('chat-box')
.scrollTo(0, document.getElementById('chat-box').scrollHeight) .scrollTo(0, document.getElementById('chat-box').scrollHeight)
localStorage.setItem('chat_id', chatId.value)
}) })
}) } catch (error) {
.catch(() => {}) console.error('Error processing message:', error)
isNewMsg.value = true enableInput()
} else if (data.type === 'text') { ElMessage.error('消息处理出错,请重试')
lineBuffer.value += data.body }
const reply = chatData.value[chatData.value.length - 1] },
if (reply) { onerror(err) {
reply['content'] = lineBuffer.value console.error('SSE Error:', err)
} enableInput()
} ElMessage.error('连接已断开,请重试')
// 将聊天框的滚动条滑动到最底部 },
nextTick(() => { onclose() {
document console.log('SSE connection closed')
.getElementById('chat-box') enableInput()
.scrollTo(0, document.getElementById('chat-box').scrollHeight) },
localStorage.setItem('chat_id', chatId.value)
}) })
}) } catch (error) {
console.error('Failed to send message:', error)
enableInput()
ElMessage.error('发送消息失败,请重试')
}
// 初始化模型分类和分组 tmpChatTitle.value = content
updateModelCategories() prompt.value = ''
updateGroupedModels() files.value = []
}) row.value = 1
return true
onUnmounted(() => {
store.removeMessageHandler('chat')
})
// 初始化数据
const initData = () => {
// 加载模型
httpGet('/api/model/list?type=chat')
.then((res) => {
models.value = res.data
if (!modelID.value) {
modelID.value = models.value[0].id
}
// 加载角色列表
httpGet(`/api/app/list/user`, { id: roleId.value })
.then((res) => {
roles.value = res.data
if (!roleId.value) {
roleId.value = roles.value[0]['id']
}
// 如果登录状态就创建对话连接
checkSession()
.then((user) => {
loginUser.value = user
isLogin.value = true
newChat()
})
.catch(() => {})
})
.catch((e) => {
ElMessage.error('获取聊天角色失败: ' + e.messages)
})
})
.catch((e) => {
ElMessage.error('加载模型失败: ' + e.message)
})
// 获取会话列表
httpGet('/api/chat/list')
.then((res) => {
if (res.data) {
chatList.value = res.data
allChats.value = res.data
}
})
.catch(() => {
ElMessage.error('加载会话列表失败!')
})
// 允许在输入框粘贴文件
inputRef.value.addEventListener('paste', (event) => {
const items = (event.clipboardData || window.clipboardData).items
for (let item of items) {
if (item.kind === 'file') {
const file = item.getAsFile()
const formData = new FormData()
formData.append('file', file)
loading.value = true
// 执行上传操作
httpPost('/api/upload', formData)
.then((res) => {
files.value.push(res.data)
ElMessage.success({ message: '上传成功', duration: 500 })
loading.value = false
})
.catch((e) => {
ElMessage.error('文件上传失败:' + e.message)
loading.value = false
})
break
}
}
})
} }
const getRoleById = function (rid) { const getRoleById = function (rid) {
@@ -859,7 +920,6 @@ const getRoleById = function (rid) {
const resizeElement = function () { const resizeElement = function () {
chatListHeight.value = window.innerHeight - 240 chatListHeight.value = window.innerHeight - 240
// chatBoxHeight.value = window.innerHeight;
mainWinHeight.value = window.innerHeight - 50 mainWinHeight.value = window.innerHeight - 50
chatBoxHeight.value = window.innerHeight - 101 - 82 - 38 chatBoxHeight.value = window.innerHeight - 101 - 82 - 38
} }
@@ -1041,85 +1101,6 @@ const autofillPrompt = (text) => {
inputRef.value.focus() inputRef.value.focus()
sendMessage() sendMessage()
} }
// 发送消息
const sendMessage = function () {
if (!isLogin.value) {
console.log('未登录')
store.setShowLoginDialog(true)
return
}
if (store.socket.conn.readyState !== WebSocket.OPEN) {
ElMessage.warning('连接断开,正在重连...')
return
}
if (canSend.value === false) {
ElMessage.warning('AI 正在作答中,请稍后...')
return
}
if (prompt.value.trim().length === 0 || canSend.value === false) {
showMessageError('请输入要发送的消息!')
return false
}
// 如果携带了文件,则串上文件地址
let content = prompt.value
if (files.value.length > 0) {
content += files.value.map((file) => file.url).join(' ')
}
// else if (files.value.length > 1) {
// showMessageError("当前只支持上传一个文件!");
// return false;
// }
// 追加消息
chatData.value.push({
type: 'prompt',
id: randString(32),
icon: loginUser.value.avatar,
content: content,
model: getModelValue(modelID.value),
created_at: new Date().getTime() / 1000,
})
// 添加空回复消息
const _role = getRoleById(roleId.value)
chatData.value.push({
chat_id: chatId,
role_id: roleId.value,
type: 'reply',
id: randString(32),
icon: _role['icon'],
content: '',
})
nextTick(() => {
document
.getElementById('chat-box')
.scrollTo(0, document.getElementById('chat-box').scrollHeight)
})
showHello.value = false
disableInput(false)
store.socket.conn.send(
JSON.stringify({
channel: 'chat',
type: 'text',
body: {
role_id: roleId.value,
model_id: modelID.value,
chat_id: chatId.value,
content: content,
tools: toolSelected.value,
stream: stream.value,
},
})
)
tmpChatTitle.value = content
prompt.value = ''
files.value = []
row.value = 1
return true
}
const clearAllChats = function () { const clearAllChats = function () {
ElMessageBox.confirm('清除所有对话?此操作不可撤销!', '警告', { ElMessageBox.confirm('清除所有对话?此操作不可撤销!', '警告', {
@@ -1186,6 +1167,7 @@ const loadChatHistory = function (chatId) {
}) })
} }
// 停止生成
const stopGenerate = function () { const stopGenerate = function () {
showStopGenerate.value = false showStopGenerate.value = false
httpGet('/api/chat/stop?session_id=' + getClientId()).then(() => { httpGet('/api/chat/stop?session_id=' + getClientId()).then(() => {