refactor AI chat message struct, allow users to set whether the AI responds in stream, compatible with the GPT-o1 model

This commit is contained in:
RockYang 2024-09-14 17:06:13 +08:00
parent 131efd6ba5
commit 0385e60ce1
18 changed files with 245 additions and 245 deletions

View File

@ -205,6 +205,7 @@ func needLogin(c *gin.Context) bool {
c.Request.URL.Path == "/api/chat/detail" || c.Request.URL.Path == "/api/chat/detail" ||
c.Request.URL.Path == "/api/chat/list" || c.Request.URL.Path == "/api/chat/list" ||
c.Request.URL.Path == "/api/app/list" || c.Request.URL.Path == "/api/app/list" ||
c.Request.URL.Path == "/api/app/type/list" ||
c.Request.URL.Path == "/api/app/list/user" || c.Request.URL.Path == "/api/app/list/user" ||
c.Request.URL.Path == "/api/model/list" || c.Request.URL.Path == "/api/model/list" ||
c.Request.URL.Path == "/api/mj/imgWall" || c.Request.URL.Path == "/api/mj/imgWall" ||

View File

@ -57,6 +57,7 @@ type ChatSession struct {
ClientIP string `json:"client_ip"` // 客户端 IP ClientIP string `json:"client_ip"` // 客户端 IP
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段 ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
Model ChatModel `json:"model"` // GPT 模型 Model ChatModel `json:"model"` // GPT 模型
Start int64 `json:"start"` // 开始请求时间戳
Tools []int `json:"tools"` // 工具函数列表 Tools []int `json:"tools"` // 工具函数列表
Stream bool `json:"stream"` // 是否采用流式输出 Stream bool `json:"stream"` // 是否采用流式输出
} }

View File

@ -26,10 +26,9 @@ type ReplyMessage struct {
type WsMsgType string type WsMsgType string
const ( const (
WsStart = WsMsgType("start") WsContent = WsMsgType("content") // 输出内容
WsMiddle = WsMsgType("middle") WsEnd = WsMsgType("end")
WsEnd = WsMsgType("end") WsErr = WsMsgType("error")
WsErr = WsMsgType("error")
) )
// InputMessage 对话输入消息结构 // InputMessage 对话输入消息结构

View File

@ -6,6 +6,7 @@ import (
"geekai/store/vo" "geekai/store/vo"
"geekai/utils" "geekai/utils"
"geekai/utils/resp" "geekai/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -22,7 +23,7 @@ func NewChatAppTypeHandler(app *core.AppServer, db *gorm.DB) *ChatAppTypeHandler
func (h *ChatAppTypeHandler) List(c *gin.Context) { func (h *ChatAppTypeHandler) List(c *gin.Context) {
var items []model.AppType var items []model.AppType
var appTypes = make([]vo.AppType, 0) var appTypes = make([]vo.AppType, 0)
err := h.DB.Order("sort_num ASC").Find(&items).Error err := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items).Error
if err != nil { if err != nil {
resp.ERROR(c, err.Error()) resp.ERROR(c, err.Error())
return return

View File

@ -202,15 +202,16 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
} }
var req = types.ApiRequest{ var req = types.ApiRequest{
Model: session.Model.Value, Model: session.Model.Value,
Temperature: session.Model.Temperature,
} }
// 兼容 GPT-O1 模型 // 兼容 GPT-O1 模型
if strings.HasPrefix(session.Model.Value, "o1-") { if strings.HasPrefix(session.Model.Value, "o1-") {
req.MaxCompletionTokens = session.Model.MaxTokens utils.ReplyContent(ws, "AI 正在思考...\n")
req.Stream = false req.Stream = false
session.Start = time.Now().Unix()
} else { } else {
req.MaxTokens = session.Model.MaxTokens req.MaxTokens = session.Model.MaxTokens
req.Temperature = session.Model.Temperature
req.Stream = session.Stream req.Stream = session.Stream
} }
@ -449,7 +450,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
if err != nil { if err != nil {
return nil, err return nil, err
} }
logger.Debugf(utils.JsonEncode(req)) logger.Debugf("对话请求消息体:%+v", req)
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL) apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 创建 HttpClient 请求对象 // 创建 HttpClient 请求对象
@ -499,14 +500,6 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
} }
} }
type Usage struct {
Prompt string
Content string
PromptTokens int
CompletionTokens int
TotalTokens int
}
func (h *ChatHandler) saveChatHistory( func (h *ChatHandler) saveChatHistory(
req types.ApiRequest, req types.ApiRequest,
usage Usage, usage Usage,
@ -517,12 +510,8 @@ func (h *ChatHandler) saveChatHistory(
userVo vo.User, userVo vo.User,
promptCreatedAt time.Time, promptCreatedAt time.Time,
replyCreatedAt time.Time) { replyCreatedAt time.Time) {
if message.Role == "" {
message.Role = "assistant"
}
message.Content = usage.Content
useMsg := types.Message{Role: "user", Content: usage.Prompt}
useMsg := types.Message{Role: "user", Content: usage.Prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文 // 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.SysConfig.EnableContext { if h.App.SysConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, useMsg) // 提问消息
@ -573,7 +562,7 @@ func (h *ChatHandler) saveChatHistory(
RoleId: role.Id, RoleId: role.Id,
Type: types.ReplyMsg, Type: types.ReplyMsg,
Icon: role.Icon, Icon: role.Icon,
Content: message.Content, Content: usage.Content,
Tokens: replyTokens, Tokens: replyTokens,
TotalTokens: totalTokens, TotalTokens: totalTokens,
UseContext: true, UseContext: true,

View File

@ -23,7 +23,15 @@ import (
"time" "time"
) )
type respVo struct { type Usage struct {
Prompt string `json:"prompt,omitempty"`
Content string `json:"content,omitempty"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIResVo struct {
Id string `json:"id"` Id string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
Created int `json:"created"` Created int `json:"created"`
@ -38,11 +46,7 @@ type respVo struct {
Logprobs interface{} `json:"logprobs"` Logprobs interface{} `json:"logprobs"`
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
} `json:"choices"` } `json:"choices"`
Usage struct { Usage Usage `json:"usage"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
} }
// OPenAI 消息发送实现 // OPenAI 消息发送实现
@ -73,19 +77,19 @@ func (h *ChatHandler) sendOpenAiMessage(
if response.StatusCode != 200 { if response.StatusCode != 200 {
body, _ := io.ReadAll(response.Body) body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, body) return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body))
} }
contentType := response.Header.Get("Content-Type") contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") { if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间 replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息 // 循环读取 Chunk 消息
var message = types.Message{} var message = types.Message{Role: "assistant"}
var contents = make([]string, 0) var contents = make([]string, 0)
var function model.Function var function model.Function
var toolCall = false var toolCall = false
var arguments = make([]string, 0) var arguments = make([]string, 0)
scanner := bufio.NewScanner(response.Body) scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 { if !strings.Contains(line, "data:") || len(line) < 30 {
@ -132,8 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if res.Error == nil { if res.Error == nil {
toolCall = true toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart}) utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: callMsg})
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: callMsg})
contents = append(contents, callMsg) contents = append(contents, callMsg)
} }
continue continue
@ -150,12 +153,8 @@ func (h *ChatHandler) sendOpenAiMessage(
} else { } else {
content := responseBody.Choices[0].Delta.Content content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content)) contents = append(contents, utils.InterfaceToString(content))
if isNew {
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(ws, types.ReplyMessage{ utils.ReplyChunkMessage(ws, types.ReplyMessage{
Type: types.WsMiddle, Type: types.WsContent,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
}) })
} }
@ -188,13 +187,13 @@ func (h *ChatHandler) sendOpenAiMessage(
if errMsg != "" || apiRes.Code != types.Success { if errMsg != "" || apiRes.Code != types.Success {
msg := "调用函数工具出错:" + apiRes.Message + errMsg msg := "调用函数工具出错:" + apiRes.Message + errMsg
utils.ReplyChunkMessage(ws, types.ReplyMessage{ utils.ReplyChunkMessage(ws, types.ReplyMessage{
Type: types.WsMiddle, Type: types.WsContent,
Content: msg, Content: msg,
}) })
contents = append(contents, msg) contents = append(contents, msg)
} else { } else {
utils.ReplyChunkMessage(ws, types.ReplyMessage{ utils.ReplyChunkMessage(ws, types.ReplyMessage{
Type: types.WsMiddle, Type: types.WsContent,
Content: apiRes.Data, Content: apiRes.Data,
}) })
contents = append(contents, utils.InterfaceToString(apiRes.Data)) contents = append(contents, utils.InterfaceToString(apiRes.Data))
@ -210,10 +209,27 @@ func (h *ChatHandler) sendOpenAiMessage(
CompletionTokens: 0, CompletionTokens: 0,
TotalTokens: 0, TotalTokens: 0,
} }
message.Content = usage.Content
h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) h.saveChatHistory(req, usage, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
} }
} else { // 非流式输出 } else { // 非流式输出
var respVo OpenAIResVo
body, err := io.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("读取响应失败:%v", body)
}
err = json.Unmarshal(body, &respVo)
if err != nil {
return fmt.Errorf("解析响应失败:%v", body)
}
content := respVo.Choices[0].Message.Content
if strings.HasPrefix(req.Model, "o1-") {
content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
}
utils.ReplyMessage(ws, content)
respVo.Usage.Prompt = prompt
respVo.Usage.Content = content
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
} }
return nil return nil

View File

@ -86,6 +86,8 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
utils.ReplyErrorMessage(client, err.Error()) utils.ReplyErrorMessage(client, err.Error())
} else {
utils.ReplyMessage(client, types.ReplyMessage{Type: types.WsEnd})
} }
} }
@ -148,7 +150,6 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
if strings.Contains(contentType, "text/event-stream") { if strings.Contains(contentType, "text/event-stream") {
// 循环读取 Chunk 消息 // 循环读取 Chunk 消息
scanner := bufio.NewScanner(response.Body) scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 { if !strings.Contains(line, "data:") || len(line) < 30 {
@ -169,12 +170,8 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
break break
} }
if isNew {
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(client, types.ReplyMessage{ utils.ReplyChunkMessage(client, types.ReplyMessage{
Type: types.WsMiddle, Type: types.WsContent,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
}) })
} // end for } // end for

View File

@ -512,6 +512,11 @@ func main() {
group.POST("enable", h.Enable) group.POST("enable", h.Enable)
group.POST("sort", h.Sort) group.POST("sort", h.Sort)
}), }),
fx.Provide(handler.NewChatAppTypeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.ChatAppTypeHandler) {
group := s.Engine.Group("/api/app/type")
group.GET("list", h.List)
}),
fx.Provide(handler.NewTestHandler), fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) { fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
group := s.Engine.Group("/api/test") group := s.Engine.Group("/api/test")

View File

@ -33,11 +33,14 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) {
// ReplyMessage 回复客户端一条完整的消息 // ReplyMessage 回复客户端一条完整的消息
func ReplyMessage(ws *types.WsClient, message interface{}) { func ReplyMessage(ws *types.WsClient, message interface{}) {
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsStart}) ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsMiddle, Content: message})
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd}) ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd})
} }
func ReplyContent(ws *types.WsClient, message interface{}) {
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
}
// ReplyErrorMessage 向客户端发送错误消息 // ReplyErrorMessage 向客户端发送错误消息
func ReplyErrorMessage(ws *types.WsClient, message interface{}) { func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message}) ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message})

View File

@ -132,12 +132,13 @@ const content =ref(processPrompt(props.data.content))
const files = ref([]) const files = ref([])
onMounted(() => { onMounted(() => {
// if (!finalTokens.value) { processFiles()
// httpPost("/api/chat/tokens", {text: props.data.content, model: props.data.model}).then(res => { })
// finalTokens.value = res.data;
// }).catch(() => { const processFiles = () => {
// }) if (!props.data.content) {
// } return
}
const linkRegex = /(https?:\/\/\S+)/g; const linkRegex = /(https?:\/\/\S+)/g;
const links = props.data.content.match(linkRegex); const links = props.data.content.match(linkRegex);
@ -159,8 +160,7 @@ onMounted(() => {
} }
content.value = md.render(content.value.trim()) content.value = md.render(content.value.trim())
}) }
const isExternalImg = (link, files) => { const isExternalImg = (link, files) => {
return isImage(link) && !files.find(file => file.url === link) return isImage(link) && !files.find(file => file.url === link)
} }

View File

@ -15,7 +15,9 @@
<el-radio value="chat">对话样式</el-radio> <el-radio value="chat">对话样式</el-radio>
</el-radio-group> </el-radio-group>
</el-form-item> </el-form-item>
<el-form-item label="流式输出:">
<el-switch v-model="data.stream" @change="(val) => {store.setChatStream(val)}" />
</el-form-item>
</el-form> </el-form>
</div> </div>
</el-dialog> </el-dialog>
@ -28,6 +30,7 @@ const store = useSharedStore();
const data = ref({ const data = ref({
style: store.chatListStyle, style: store.chatListStyle,
stream: store.chatStream,
}) })
// eslint-disable-next-line no-undef // eslint-disable-next-line no-undef
const props = defineProps({ const props = defineProps({

View File

@ -4,7 +4,8 @@ import Storage from 'good-storage'
export const useSharedStore = defineStore('shared', { export const useSharedStore = defineStore('shared', {
state: () => ({ state: () => ({
showLoginDialog: false, showLoginDialog: false,
chatListStyle: Storage.get("chat_list_style","chat") chatListStyle: Storage.get("chat_list_style","chat"),
chatStream: Storage.get("chat_stream",true),
}), }),
getters: {}, getters: {},
actions: { actions: {
@ -14,6 +15,10 @@ export const useSharedStore = defineStore('shared', {
setChatListStyle(value) { setChatListStyle(value) {
this.chatListStyle = value; this.chatListStyle = value;
Storage.set("chat_list_style", value); Storage.set("chat_list_style", value);
},
setChatStream(value) {
this.chatStream = value;
Storage.set("chat_stream", value);
} }
} }
}); });

View File

@ -9,8 +9,6 @@
* Util lib functions * Util lib functions
*/ */
import {showConfirmDialog} from "vant"; import {showConfirmDialog} from "vant";
import {httpDownload} from "@/utils/http";
import {showMessageError} from "@/utils/dialog";
// generate a random string // generate a random string
export function randString(length) { export function randString(length) {
@ -183,6 +181,10 @@ export function isImage(url) {
} }
export function processContent(content) { export function processContent(content) {
if (!content) {
return ""
}
// 如果是图片链接地址,则直接替换成图片标签 // 如果是图片链接地址,则直接替换成图片标签
const linkRegex = /(https?:\/\/\S+)/g; const linkRegex = /(https?:\/\/\S+)/g;
const links = content.match(linkRegex); const links = content.match(linkRegex);

View File

@ -106,7 +106,7 @@
<el-dropdown-menu class="tools-dropdown"> <el-dropdown-menu class="tools-dropdown">
<el-checkbox-group v-model="toolSelected"> <el-checkbox-group v-model="toolSelected">
<el-dropdown-item v-for="item in tools" :key="item.id"> <el-dropdown-item v-for="item in tools" :key="item.id">
<el-checkbox :value="item.id" :label="item.label" @change="changeTool" /> <el-checkbox :value="item.id" :label="item.label" />
<el-tooltip :content="item.description" placement="right"> <el-tooltip :content="item.description" placement="right">
<el-icon><InfoFilled /></el-icon> <el-icon><InfoFilled /></el-icon>
</el-tooltip> </el-tooltip>
@ -271,6 +271,12 @@ watch(() => store.chatListStyle, (newValue) => {
const tools = ref([]) const tools = ref([])
const toolSelected = ref([]) const toolSelected = ref([])
const loadHistory = ref(false) const loadHistory = ref(false)
const stream = ref(store.chatStream)
watch(() => store.chatStream, (newValue) => {
stream.value = newValue
});
// ID // ID
if (router.currentRoute.value.query.role_id) { if (router.currentRoute.value.query.role_id) {
@ -491,16 +497,6 @@ const newChat = () => {
connect() connect()
} }
//
const changeTool = () => {
if (!isLogin.value) {
return;
}
loadHistory.value = false
socket.value.close()
}
// //
const loadChat = function (chat) { const loadChat = function (chat) {
if (!isLogin.value) { if (!isLogin.value) {
@ -598,6 +594,7 @@ const lineBuffer = ref(''); // 输出缓冲行
const socket = ref(null); const socket = ref(null);
const canSend = ref(true); const canSend = ref(true);
const sessionId = ref("") const sessionId = ref("")
const isNewMsg = ref(true)
const connect = function () { const connect = function () {
const chatRole = getRoleById(roleId.value); const chatRole = getRoleById(roleId.value);
// WebSocket // WebSocket
@ -612,8 +609,7 @@ const connect = function () {
} }
loading.value = true loading.value = true
const toolIds = toolSelected.value.join(',') const _socket = new WebSocket(host + `/api/chat/new?session_id=${sessionId.value}&role_id=${roleId.value}&chat_id=${chatId.value}&model_id=${modelID.value}&token=${getUserToken()}`);
const _socket = new WebSocket(host + `/api/chat/new?session_id=${sessionId.value}&role_id=${roleId.value}&chat_id=${chatId.value}&model_id=${modelID.value}&token=${getUserToken()}&tools=${toolIds}`);
_socket.addEventListener('open', () => { _socket.addEventListener('open', () => {
enableInput() enableInput()
if (loadHistory.value) { if (loadHistory.value) {
@ -629,15 +625,22 @@ const connect = function () {
reader.readAsText(event.data, "UTF-8"); reader.readAsText(event.data, "UTF-8");
reader.onload = () => { reader.onload = () => {
const data = JSON.parse(String(reader.result)); const data = JSON.parse(String(reader.result));
if (data.type === 'start') { if (data.type === 'error') {
ElMessage.error(data.message)
return
}
if (isNewMsg.value && data.type !== 'end') {
const prePrompt = chatData.value[chatData.value.length-1]?.content const prePrompt = chatData.value[chatData.value.length-1]?.content
chatData.value.push({ chatData.value.push({
type: "reply", type: "reply",
id: randString(32), id: randString(32),
icon: chatRole['icon'], icon: chatRole['icon'],
prompt:prePrompt, prompt:prePrompt,
content: "", content: data.content,
}); });
isNewMsg.value = false
lineBuffer.value = data.content;
} else if (data.type === 'end') { // } else if (data.type === 'end') { //
// //
if (newChatItem.value !== null) { if (newChatItem.value !== null) {
@ -663,6 +666,7 @@ const connect = function () {
nextTick(() => { nextTick(() => {
document.getElementById('chat-box').scrollTo(0, document.getElementById('chat-box').scrollHeight) document.getElementById('chat-box').scrollTo(0, document.getElementById('chat-box').scrollHeight)
}) })
isNewMsg.value = true
}).catch(() => { }).catch(() => {
}) })
@ -688,6 +692,7 @@ const connect = function () {
_socket.addEventListener('close', () => { _socket.addEventListener('close', () => {
disableInput(true) disableInput(true)
loadHistory.value = false
connect() connect()
}); });
@ -775,7 +780,7 @@ const sendMessage = function () {
showHello.value = false showHello.value = false
disableInput(false) disableInput(false)
socket.value.send(JSON.stringify({type: "chat", content: content})); socket.value.send(JSON.stringify({tools: toolSelected.value, content: content, stream: stream.value}));
tmpChatTitle.value = content tmpChatTitle.value = content
prompt.value = '' prompt.value = ''
files.value = [] files.value = []
@ -813,7 +818,7 @@ const loadChatHistory = function (chatId) {
chatData.value = [] chatData.value = []
httpGet('/api/chat/history?chat_id=' + chatId).then(res => { httpGet('/api/chat/history?chat_id=' + chatId).then(res => {
const data = res.data const data = res.data
if (!data || data.length === 0) { // if ((!data || data.length === 0) && chatData.value.length === 0) { //
const _role = getRoleById(roleId.value) const _role = getRoleById(roleId.value)
chatData.value.push({ chatData.value.push({
chat_id: chatId, chat_id: chatId,
@ -852,7 +857,7 @@ const stopGenerate = function () {
// //
const reGenerate = function (prompt) { const reGenerate = function (prompt) {
disableInput(false) disableInput(false)
const text = '重新生成下面问题的答案' + prompt; const text = '重新回答下述问题' + prompt;
// //
chatData.value.push({ chatData.value.push({
type: "prompt", type: "prompt",
@ -860,7 +865,7 @@ const reGenerate = function (prompt) {
icon: loginUser.value.avatar, icon: loginUser.value.avatar,
content: text content: text
}); });
socket.value.send(JSON.stringify({type: "chat", content: prompt})); socket.value.send(JSON.stringify({tools: toolSelected.value, content: text, stream: stream.value}));
} }
const chatName = ref('') const chatName = ref('')

View File

@ -231,10 +231,7 @@ const connect = (userId) => {
reader.onload = () => { reader.onload = () => {
const data = JSON.parse(String(reader.result)) const data = JSON.parse(String(reader.result))
switch (data.type) { switch (data.type) {
case "start": case "content":
text.value = ""
break
case "middle":
text.value += data.content text.value += data.content
html.value = md.render(processContent(text.value)) html.value = md.render(processContent(text.value))
break break

View File

@ -2,46 +2,38 @@
<div class="admin-login"> <div class="admin-login">
<div class="main"> <div class="main">
<div class="contain"> <div class="contain">
<div class="logo"> <div class="logo" @click="router.push('/')">
<el-image :src="logo" fit="cover" @click="router.push('/')"/> <el-image :src="logo" fit="cover"/>
</div> </div>
<div class="header">{{ title }}</div> <h1 class="header">{{ title }}</h1>
<div class="content"> <div class="content">
<div class="block"> <el-input v-model="username" placeholder="请输入用户名" size="large"
<el-input placeholder="请输入用户名" size="large" v-model="username" autocomplete="off" autofocus autocomplete="off" autofocus @keyup.enter="login">
@keyup="keyupHandle"> <template #prefix>
<template #prefix> <el-icon>
<el-icon> <UserFilled/>
<UserFilled/> </el-icon>
</el-icon> </template>
</template> </el-input>
</el-input>
</div>
<div class="block"> <el-input v-model="password" placeholder="请输入密码" size="large"
<el-input placeholder="请输入密码" size="large" v-model="password" show-password autocomplete="off" show-password autocomplete="off" @keyup.enter="login">
@keyup="keyupHandle"> <template #prefix>
<template #prefix> <el-icon>
<el-icon> <Lock/>
<Lock/> </el-icon>
</el-icon> </template>
</template> </el-input>
</el-input>
</div>
<el-row class="btn-row"> <el-row class="btn-row">
<el-button class="login-btn" size="large" type="primary" @click="login">登录</el-button> <el-button class="login-btn" size="large" type="primary" @click="login">登录</el-button>
</el-row> </el-row>
</div> </div>
</div> </div>
<captcha v-if="enableVerify" @success="doLogin" ref="captchaRef"/> <captcha v-if="enableVerify" @success="doLogin" ref="captchaRef"/>
<footer-bar class="footer"/>
<footer class="footer">
<footer-bar/>
</footer>
</div> </div>
</div> </div>
</template> </template>
@ -80,12 +72,6 @@ getSystemInfo().then(res => {
ElMessage.error("加载系统配置失败: " + e.message) ElMessage.error("加载系统配置失败: " + e.message)
}) })
const keyupHandle = (e) => {
if (e.key === 'Enter') {
login();
}
}
const login = function () { const login = function () {
if (username.value === '') { if (username.value === '') {
return ElMessage.error('请输入用户名'); return ElMessage.error('请输入用户名');

View File

@ -225,7 +225,7 @@ const newChat = (item) => {
} }
showPicker.value = false showPicker.value = false
const options = item.selectedOptions const options = item.selectedOptions
router.push(`/mobile/chat/session?title=新对话&role_id=${options[0].value}&model_id=${options[1].value}&chat_id=0}`) router.push(`/mobile/chat/session?title=新对话&role_id=${options[0].value}&model_id=${options[1].value}`)
} }
const changeChat = (chat) => { const changeChat = (chat) => {

View File

@ -123,7 +123,7 @@
</template> </template>
<script setup> <script setup>
import {nextTick, onMounted, onUnmounted, ref} from "vue"; import {nextTick, onMounted, onUnmounted, ref, watch} from "vue";
import {showImagePreview, showNotify, showToast} from "vant"; import {showImagePreview, showNotify, showToast} from "vant";
import {onBeforeRouteLeave, useRouter} from "vue-router"; import {onBeforeRouteLeave, useRouter} from "vue-router";
import {processContent, randString, renderInputText, UUID} from "@/utils/libs"; import {processContent, randString, renderInputText, UUID} from "@/utils/libs";
@ -135,7 +135,8 @@ import ChatReply from "@/components/mobile/ChatReply.vue";
import {getSessionId, getUserToken} from "@/store/session"; import {getSessionId, getUserToken} from "@/store/session";
import {checkSession} from "@/store/cache"; import {checkSession} from "@/store/cache";
import Clipboard from "clipboard"; import Clipboard from "clipboard";
import {showLoginDialog} from "@/utils/dialog"; import { showMessageError} from "@/utils/dialog";
import {useSharedStore} from "@/store/sharedata";
const winHeight = ref(0) const winHeight = ref(0)
const navBarRef = ref(null) const navBarRef = ref(null)
@ -167,49 +168,51 @@ if (chatId.value) {
title.value = res.data.title title.value = res.data.title
modelId.value = res.data.model_id modelId.value = res.data.model_id
roleId.value = res.data.role_id roleId.value = res.data.role_id
loadModels()
}).catch(() => { }).catch(() => {
loadModels()
}) })
} else { } else {
title.value = "新建对话" title.value = "新建对话"
chatId.value = UUID()
} }
// const loadModels = () => {
httpGet('/api/model/list').then(res => { //
models.value = res.data httpGet('/api/model/list').then(res => {
if (!modelId.value) { models.value = res.data
modelId.value = models.value[0].id if (!modelId.value) {
} modelId.value = models.value[0].id
for (let i = 0; i < models.value.length; i++) {
models.value[i].text = models.value[i].name
models.value[i].mValue = models.value[i].value
models.value[i].value = models.value[i].id
}
modelValue.value = getModelName(modelId.value)
//
httpGet(`/api/app/list/user`).then((res) => {
roles.value = res.data;
if (!roleId.value) {
roleId.value = roles.value[0]['id']
} }
// build data for role picker for (let i = 0; i < models.value.length; i++) {
for (let i = 0; i < roles.value.length; i++) { models.value[i].text = models.value[i].name
roles.value[i].text = roles.value[i].name models.value[i].mValue = models.value[i].value
roles.value[i].value = roles.value[i].id models.value[i].value = models.value[i].id
roles.value[i].helloMsg = roles.value[i].hello_msg
} }
modelValue.value = getModelName(modelId.value)
role.value = getRoleById(roleId.value) //
columns.value = [roles.value, models.value] httpGet(`/api/app/list/user`,{id: roleId.value}).then((res) => {
// roles.value = res.data;
if (!chatId.value) { if (!roleId.value) {
connect(chatId.value, roleId.value, modelId.value) roleId.value = roles.value[0]['id']
} }
}).catch((e) => { // build data for role picker
showNotify({type: "danger", message: '获取聊天角色失败: ' + e.messages}) for (let i = 0; i < roles.value.length; i++) {
roles.value[i].text = roles.value[i].name
roles.value[i].value = roles.value[i].id
roles.value[i].helloMsg = roles.value[i].hello_msg
}
role.value = getRoleById(roleId.value)
columns.value = [roles.value, models.value]
selectedValues.value = [roleId.value, modelId.value]
connect()
}).catch((e) => {
showNotify({type: "danger", message: '获取聊天角色失败: ' + e.messages})
})
}).catch(e => {
showNotify({type: "danger", message: "加载模型失败: " + e.message})
}) })
}).catch(e => { }
showNotify({type: "danger", message: "加载模型失败: " + e.message})
})
const url = ref(location.protocol + '//' + location.host + '/mobile/chat/export?chat_id=' + chatId.value) const url = ref(location.protocol + '//' + location.host + '/mobile/chat/export?chat_id=' + chatId.value)
@ -239,11 +242,12 @@ const newChat = (item) => {
roleId.value = options[0].value roleId.value = options[0].value
modelId.value = options[1].value modelId.value = options[1].value
modelValue.value = getModelName(modelId.value) modelValue.value = getModelName(modelId.value)
chatId.value = "" chatId.value = UUID()
chatData.value = [] chatData.value = []
role.value = getRoleById(roleId.value) role.value = getRoleById(roleId.value)
title.value = "新建对话" title.value = "新建对话"
connect(chatId.value, roleId.value, modelId.value) loadHistory.value = true
connect()
} }
const chatData = ref([]) const chatData = ref([])
@ -280,51 +284,60 @@ md.use(mathjaxPlugin)
const onLoad = () => { const onLoad = () => {
if (chatId.value) { // checkSession().then(() => {
checkSession().then(() => { // connect()
httpGet('/api/chat/history?chat_id=' + chatId.value).then(res => { // }).catch(() => {
// // })
finished.value = true; }
const data = res.data
if (data && data.length > 0) {
for (let i = 0; i < data.length; i++) {
if (data[i].type === "prompt") {
chatData.value.push(data[i]);
continue;
}
data[i].orgContent = data[i].content; const loadChatHistory = () => {
data[i].content = md.render(processContent(data[i].content)) httpGet('/api/chat/history?chat_id=' + chatId.value).then(res => {
chatData.value.push(data[i]); const role = getRoleById(roleId.value)
} //
finished.value = true;
nextTick(() => { const data = res.data
hl.configure({ignoreUnescapedHTML: true}) if (data.length === 0) {
const blocks = document.querySelector("#message-list-box").querySelectorAll('pre code'); chatData.value.push({
blocks.forEach((block) => { type: "reply",
hl.highlightElement(block) id: randString(32),
}) icon: role.icon,
content: role.hello_msg,
scrollListBox() orgContent: role.hello_msg,
})
}
connect(chatId.value, roleId.value, modelId.value);
}).catch(() => {
error.value = true
}) })
}).catch(() => { return
}
for (let i = 0; i < data.length; i++) {
if (data[i].type === "prompt") {
chatData.value.push(data[i]);
continue;
}
data[i].orgContent = data[i].content;
data[i].content = md.render(processContent(data[i].content))
chatData.value.push(data[i]);
}
nextTick(() => {
hl.configure({ignoreUnescapedHTML: true})
const blocks = document.querySelector("#message-list-box").querySelectorAll('pre code');
blocks.forEach((block) => {
hl.highlightElement(block)
})
scrollListBox()
}) })
}
}; }).catch(() => {
error.value = true
})
}
// websocket // websocket
onBeforeRouteLeave(() => { onBeforeRouteLeave(() => {
if (socket.value !== null) { if (socket.value !== null) {
activelyClose.value = true;
clearTimeout(heartbeatHandle.value)
socket.value.close(); socket.value.close();
} }
}) })
// socket // socket
@ -334,16 +347,15 @@ const showReGenerate = ref(false); // 重新生成
const previousText = ref(''); // const previousText = ref(''); //
const lineBuffer = ref(''); // const lineBuffer = ref(''); //
const socket = ref(null); const socket = ref(null);
const activelyClose = ref(false); // const canSend = ref(true)
const canSend = ref(true); const isNewMsg = ref(true)
const heartbeatHandle = ref(null) const loadHistory = ref(true)
const connect = function (chat_id, role_id, model_id) { const store = useSharedStore()
let isNewChat = false; const stream = ref(store.chatStream)
if (!chat_id) { watch(() => store.chatStream, (newValue) => {
isNewChat = true; stream.value = newValue
chat_id = UUID(); });
} const connect = function () {
// WebSocket // WebSocket
const _sessionId = getSessionId(); const _sessionId = getSessionId();
let host = process.env.VUE_APP_WS_HOST let host = process.env.VUE_APP_WS_HOST
@ -354,38 +366,15 @@ const connect = function (chat_id, role_id, model_id) {
host = 'ws://' + location.host; host = 'ws://' + location.host;
} }
} }
const _socket = new WebSocket(host + `/api/chat/new?session_id=${_sessionId}&role_id=${roleId.value}&chat_id=${chatId.value}&model_id=${modelId.value}&token=${getUserToken()}`);
//
const sendHeartbeat = () => {
if (socket.value !== null) {
new Promise((resolve) => {
socket.value.send(JSON.stringify({type: "heartbeat", content: "ping"}))
resolve("success")
}).then(() => {
heartbeatHandle.value = setTimeout(() => sendHeartbeat(), 5000)
});
}
}
const _socket = new WebSocket(host + `/api/chat/new?session_id=${_sessionId}&role_id=${role_id}&chat_id=${chat_id}&model_id=${model_id}&token=${getUserToken()}`);
_socket.addEventListener('open', () => { _socket.addEventListener('open', () => {
loading.value = false loading.value = false
previousText.value = ''; previousText.value = '';
canSend.value = true; canSend.value = true;
activelyClose.value = false;
if (isNewChat) { // if (loadHistory.value) { //
chatData.value.push({ loadChatHistory()
type: "reply",
id: randString(32),
icon: role.value.icon,
content: role.value.hello_msg,
orgContent: role.value.hello_msg,
})
} }
//
sendHeartbeat()
}); });
_socket.addEventListener('message', event => { _socket.addEventListener('message', event => {
@ -394,20 +383,27 @@ const connect = function (chat_id, role_id, model_id) {
reader.readAsText(event.data, "UTF-8"); reader.readAsText(event.data, "UTF-8");
reader.onload = () => { reader.onload = () => {
const data = JSON.parse(String(reader.result)); const data = JSON.parse(String(reader.result));
if (data.type === 'start') { if (data.type === 'error') {
showMessageError(data.message)
return
}
if (isNewMsg.value && data.type !== 'end') {
chatData.value.push({ chatData.value.push({
type: "reply", type: "reply",
id: randString(32), id: randString(32),
icon: role.value.icon, icon: role.value.icon,
content: "" content: data.content
}); });
if (isNewChat) { if (!title.value) {
title.value = previousText.value title.value = previousText.value
} }
lineBuffer.value = data.content;
isNewMsg.value = false
} else if (data.type === 'end') { // } else if (data.type === 'end') { //
enableInput() enableInput()
lineBuffer.value = ''; // lineBuffer.value = ''; //
isNewMsg.value = true
} else { } else {
lineBuffer.value += data.content; lineBuffer.value += data.content;
const reply = chatData.value[chatData.value.length - 1] const reply = chatData.value[chatData.value.length - 1]
@ -443,17 +439,11 @@ const connect = function (chat_id, role_id, model_id) {
}); });
_socket.addEventListener('close', () => { _socket.addEventListener('close', () => {
if (activelyClose.value || socket.value === null) { //
return;
}
// //
canSend.value = true; canSend.value = true
loadHistory.value = false
// //
checkSession().then(() => { connect()
connect(chat_id, role_id, model_id)
}).catch(() => {
showLoginDialog(router)
});
}); });
socket.value = _socket; socket.value = _socket;
@ -501,7 +491,7 @@ const sendMessage = () => {
}) })
disableInput(false) disableInput(false)
socket.value.send(JSON.stringify({type: "chat", content: prompt.value})); socket.value.send(JSON.stringify({stream: stream.value, content: prompt.value}));
previousText.value = prompt.value; previousText.value = prompt.value;
prompt.value = ''; prompt.value = '';
return true; return true;
@ -524,7 +514,7 @@ const reGenerate = () => {
icon: loginUser.value.avatar, icon: loginUser.value.avatar,
content: renderInputText(text) content: renderInputText(text)
}); });
socket.value.send(JSON.stringify({type: "chat", content: previousText.value})); socket.value.send(JSON.stringify({stream: stream.value, content: previousText.value}));
} }
const showShare = ref(false) const showShare = ref(false)