From 41e4b1c7ac3d4ec0babc91ea72324f52cce3e8fa Mon Sep 17 00:00:00 2001 From: GeekMaster Date: Mon, 26 May 2025 18:26:36 +0800 Subject: [PATCH] =?UTF-8?q?SSE=20=E6=9B=BF=E6=8D=A2=20websocket?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/handler/chat_handler.go | 325 ++++++++++++++++++- api/handler/chat_openai_handler.go | 480 +++++++++++++++-------------- api/handler/ws_handler.go | 153 --------- api/main.go | 5 +- web/package.json | 2 +- web/src/App.vue | 44 +-- web/src/store/sharedata.js | 110 +++---- web/src/views/ChatPlus.vue | 412 ++++++++++++------------- 8 files changed, 808 insertions(+), 723 deletions(-) delete mode 100644 api/handler/ws_handler.go diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index 7261f14e..ea8ccc33 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -8,6 +8,7 @@ package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( + "bufio" "bytes" "context" "encoding/json" @@ -32,10 +33,30 @@ import ( "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" + req2 "github.com/imroc/req/v3" "github.com/sashabaranov/go-openai" "gorm.io/gorm" ) +const ( + ChatEventStart = "start" + ChatEventEnd = "end" + ChatEventError = "error" + ChatEventMessageDelta = "message_delta" + ChatEventTitle = "title" +) + +type ChatInput struct { + UserId uint `json:"user_id"` + RoleId int `json:"role_id"` + ModelId int `json:"model_id"` + ChatId string `json:"chat_id"` + Content string `json:"content"` + Tools []int `json:"tools"` + Stream bool `json:"stream"` + Files []vo.File `json:"files"` +} + type ChatHandler struct { BaseHandler redis *redis.Client @@ -58,7 +79,89 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag } } -func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error { +// Chat 处理聊天请求 +func (h *ChatHandler) Chat(c *gin.Context) { + var data ChatInput + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + // 设置SSE响应头 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + ctx, cancel := context.WithCancel(c.Request.Context()) + defer cancel() + + // 验证聊天角色 + var chatRole model.ChatRole + err := h.DB.First(&chatRole, data.RoleId).Error + if err != nil || !chatRole.Enable { + pushMessage(c, ChatEventError, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!") + return + } + + // 如果角色绑定了模型ID,使用角色的模型ID + if chatRole.ModelId > 0 { + data.ModelId = int(chatRole.ModelId) + } + + // 获取模型信息 + var chatModel model.ChatModel + err = h.DB.Where("id", data.ModelId).First(&chatModel).Error + if err != nil || !chatModel.Enabled { + pushMessage(c, ChatEventError, "当前AI模型暂未启用,请更换模型后再发起对话!") + return + } + + session := &types.ChatSession{ + ClientIP: c.ClientIP(), + UserId: data.UserId, + ChatId: data.ChatId, + Tools: data.Tools, + Stream: data.Stream, + Model: types.ChatModel{ + KeyId: data.ModelId, + }, + } + + // 使用旧的聊天数据覆盖模型和角色ID + var chat model.ChatItem + h.DB.Where("chat_id", data.ChatId).First(&chat) + if chat.Id > 0 { + chatModel.Id = chat.ModelId + data.RoleId = int(chat.RoleId) + } + + // 复制模型数据 + err = utils.CopyObject(chatModel, &session.Model) + if err != nil { + logger.Error(err, chatModel) + } + session.Model.Id = chatModel.Id + + // 发送消息 + err = h.sendMessage(ctx, session, chatRole, data.Content, c) + if err != nil { + pushMessage(c, ChatEventError, err.Error()) + return + } + + pushMessage(c, ChatEventEnd, "对话完成") +} + +func pushMessage(c *gin.Context, msgType string, content interface{}) { + c.SSEvent("message", map[string]interface{}{ + "type": msgType, + "body": content, + }) + c.Writer.Flush() +} + +func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, c *gin.Context) error { var user model.User res := h.DB.Model(&model.User{}).First(&user, session.UserId) if res.Error != nil { @@ -254,7 +357,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio logger.Debugf("%+v", req.Messages) - return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, ws) + return h.sendOpenAiMessage(req, userVo, ctx, session, role, prompt, c) } // Tokens 统计 token 数量 @@ -584,3 +687,221 @@ func (h *ChatHandler) TextToSpeech(c *gin.Context) { // 直接写入完整的音频数据到响应 c.Writer.Write(audioBytes) } + +// OPenAI 消息发送实现 +func (h *ChatHandler) sendOpenAiMessage( + req types.ApiRequest, + userVo vo.User, + ctx context.Context, + session *types.ChatSession, + role model.ChatRole, + prompt string, + c *gin.Context) error { + promptCreatedAt := time.Now() // 记录提问时间 + start := time.Now() + var apiKey = model.ApiKey{} + response, err := h.doRequest(ctx, req, session, &apiKey) + logger.Info("HTTP请求完成,耗时:", time.Since(start)) + if err != nil { + if strings.Contains(err.Error(), "context canceled") { + return fmt.Errorf("用户取消了请求:%s", prompt) + } else if strings.Contains(err.Error(), "no available key") { + return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") + } + return err + } else { + defer response.Body.Close() + } + + if response.StatusCode != 200 { + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body)) + } + + contentType := response.Header.Get("Content-Type") + if strings.Contains(contentType, "text/event-stream") { + replyCreatedAt := time.Now() // 记录回复时间 + // 循环读取 Chunk 消息 + var message = types.Message{Role: "assistant"} + var contents = make([]string, 0) + var function model.Function + var toolCall = false + var arguments = make([]string, 0) + var reasoning = false + + pushMessage(c, ChatEventStart, "开始响应") + scanner := bufio.NewScanner(response.Body) + for scanner.Scan() { + line := scanner.Text() + if !strings.Contains(line, "data:") || len(line) < 30 { + continue + } + var responseBody = types.ApiResponse{} + err = json.Unmarshal([]byte(line[6:]), &responseBody) + if err != nil { // 数据解析出错 + return errors.New(line) + } + if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行 + continue + } + if responseBody.Choices[0].Delta.Content == nil && + responseBody.Choices[0].Delta.ToolCalls == nil && + responseBody.Choices[0].Delta.ReasoningContent == "" { + continue + } + + if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 { + pushMessage(c, ChatEventError, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。") + break + } + + var tool types.ToolCall + if len(responseBody.Choices[0].Delta.ToolCalls) > 0 { + tool = responseBody.Choices[0].Delta.ToolCalls[0] + if toolCall && tool.Function.Name == "" { + arguments = append(arguments, tool.Function.Arguments) + continue + } + } + + // 兼容 Function Call + fun := responseBody.Choices[0].Delta.FunctionCall + if fun.Name != "" { + tool = *new(types.ToolCall) + tool.Function.Name = fun.Name + } else if toolCall { + arguments = append(arguments, fun.Arguments) + continue + } + + if !utils.IsEmptyValue(tool) { + res := h.DB.Where("name = ?", tool.Function.Name).First(&function) + if res.Error == nil { + toolCall = true + callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) + pushMessage(c, ChatEventMessageDelta, map[string]interface{}{ + "type": "text", + "content": callMsg, + }) + contents = append(contents, callMsg) + } + continue + } + + if responseBody.Choices[0].FinishReason == "tool_calls" || + responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕 + break + } + + // output stopped + if responseBody.Choices[0].FinishReason != "" { + break // 输出完成或者输出中断了 + } else { // 正常输出结果 + // 兼容思考过程 + if responseBody.Choices[0].Delta.ReasoningContent != "" { + reasoningContent := responseBody.Choices[0].Delta.ReasoningContent + if !reasoning { + reasoningContent = fmt.Sprintf("%s", reasoningContent) + reasoning = true + } + + pushMessage(c, ChatEventMessageDelta, map[string]interface{}{ + "type": "text", + "content": reasoningContent, + }) + contents = append(contents, reasoningContent) + } else if responseBody.Choices[0].Delta.Content != "" { + finalContent := responseBody.Choices[0].Delta.Content + if reasoning { + finalContent = fmt.Sprintf("%s", responseBody.Choices[0].Delta.Content) + reasoning = false + } + contents = append(contents, utils.InterfaceToString(finalContent)) + pushMessage(c, ChatEventMessageDelta, map[string]interface{}{ + "type": "text", + "content": finalContent, + }) + } + } + } // end for + + if err := scanner.Err(); err != nil { + if strings.Contains(err.Error(), "context canceled") { + logger.Info("用户取消了请求:", prompt) + } else { + logger.Error("信息读取出错:", err) + } + } + + if toolCall { // 调用函数完成任务 + params := make(map[string]any) + _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms) + logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params) + params["user_id"] = userVo.Id + var apiRes types.BizVo + r, err := req2.C().R().SetHeader("Body-Type", "application/json"). + SetHeader("Authorization", function.Token). + SetBody(params).Post(function.Action) + errMsg := "" + if err != nil { + errMsg = err.Error() + } else { + all, _ := io.ReadAll(r.Body) + err = json.Unmarshal(all, &apiRes) + if err != nil { + errMsg = err.Error() + } else if apiRes.Code != types.Success { + errMsg = apiRes.Message + } + } + + if errMsg != "" { + errMsg = "调用函数工具出错:" + errMsg + contents = append(contents, errMsg) + } else { + errMsg = utils.InterfaceToString(apiRes.Data) + contents = append(contents, errMsg) + } + pushMessage(c, ChatEventMessageDelta, map[string]interface{}{ + "type": "text", + "content": errMsg, + }) + } + + // 消息发送成功 + if len(contents) > 0 { + usage := Usage{ + Prompt: prompt, + Content: strings.Join(contents, ""), + PromptTokens: 0, + CompletionTokens: 0, + TotalTokens: 0, + } + message.Content = usage.Content + h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt) + } + } else { + var respVo OpenAIResVo + body, err := io.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("读取响应失败:%v", body) + } + err = json.Unmarshal(body, &respVo) + if err != nil { + return fmt.Errorf("解析响应失败:%v", body) + } + content := respVo.Choices[0].Message.Content + if strings.HasPrefix(req.Model, "o1-") { + content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content) + } + pushMessage(c, ChatEventMessageDelta, map[string]interface{}{ + "type": "text", + "content": content, + }) + respVo.Usage.Prompt = prompt + respVo.Usage.Content = content + h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now()) + } + + return nil +} diff --git a/api/handler/chat_openai_handler.go b/api/handler/chat_openai_handler.go index fea1a1e3..00c13093 100644 --- a/api/handler/chat_openai_handler.go +++ b/api/handler/chat_openai_handler.go @@ -1,253 +1,271 @@ package handler -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// // * Copyright 2023 The Geek-AI Authors. All rights reserved. +// // * Use of this source code is governed by a Apache-2.0 license +// // * that can be found in the LICENSE file. +// // * @Author yangjian102621@163.com +// // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -import ( - "bufio" - "context" - "encoding/json" - "errors" - "fmt" - "geekai/core/types" - "geekai/store/model" - "geekai/store/vo" - "geekai/utils" - "io" - "strings" - "time" +// import ( +// "bufio" +// "context" +// "encoding/json" +// "errors" +// "fmt" +// "geekai/core/types" +// "geekai/store/model" +// "geekai/store/vo" +// "geekai/utils" +// "io" +// "strings" +// "time" - req2 "github.com/imroc/req/v3" -) +// req2 "github.com/imroc/req/v3" +// ) -type Usage struct { - Prompt string `json:"prompt,omitempty"` - Content string `json:"content,omitempty"` - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} +// type Usage struct { +// Prompt string `json:"prompt,omitempty"` +// Content string `json:"content,omitempty"` +// PromptTokens int `json:"prompt_tokens"` +// CompletionTokens int `json:"completion_tokens"` +// TotalTokens int `json:"total_tokens"` +// } -type OpenAIResVo struct { - Id string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - SystemFingerprint string `json:"system_fingerprint"` - Choices []struct { - Index int `json:"index"` - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage Usage `json:"usage"` -} +// type OpenAIResVo struct { +// Id string `json:"id"` +// Object string `json:"object"` +// Created int `json:"created"` +// Model string `json:"model"` +// SystemFingerprint string `json:"system_fingerprint"` +// Choices []struct { +// Index int `json:"index"` +// Message struct { +// Role string `json:"role"` +// Content string `json:"content"` +// } `json:"message"` +// Logprobs interface{} `json:"logprobs"` +// FinishReason string `json:"finish_reason"` +// } `json:"choices"` +// Usage Usage `json:"usage"` +// } -// OPenAI 消息发送实现 -func (h *ChatHandler) sendOpenAiMessage( - req types.ApiRequest, - userVo vo.User, - ctx context.Context, - session *types.ChatSession, - role model.ChatRole, - prompt string, - ws *types.WsClient) error { - promptCreatedAt := time.Now() // 记录提问时间 - start := time.Now() - var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session, &apiKey) - logger.Info("HTTP请求完成,耗时:", time.Since(start)) - if err != nil { - if strings.Contains(err.Error(), "context canceled") { - return fmt.Errorf("用户取消了请求:%s", prompt) - } else if strings.Contains(err.Error(), "no available key") { - return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") - } - return err - } else { - defer response.Body.Close() - } +// // OPenAI 消息发送实现 +// func (h *ChatHandler) sendOpenAiMessage( +// req types.ApiRequest, +// userVo vo.User, +// ctx context.Context, +// session *types.ChatSession, +// role model.ChatRole, +// prompt string, +// messageChan chan interface{}) error { +// promptCreatedAt := time.Now() // 记录提问时间 +// start := time.Now() +// var apiKey = model.ApiKey{} +// response, err := h.doRequest(ctx, req, session, &apiKey) +// logger.Info("HTTP请求完成,耗时:", time.Since(start)) +// if err != nil { +// if strings.Contains(err.Error(), "context canceled") { +// return fmt.Errorf("用户取消了请求:%s", prompt) +// } else if strings.Contains(err.Error(), "no available key") { +// return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") +// } +// return err +// } else { +// defer response.Body.Close() +// } - if response.StatusCode != 200 { - body, _ := io.ReadAll(response.Body) - return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body)) - } +// if response.StatusCode != 200 { +// body, _ := io.ReadAll(response.Body) +// return fmt.Errorf("请求 OpenAI API 失败:%d, %v", response.StatusCode, string(body)) +// } - contentType := response.Header.Get("Content-Type") - if strings.Contains(contentType, "text/event-stream") { - replyCreatedAt := time.Now() // 记录回复时间 - // 循环读取 Chunk 消息 - var message = types.Message{Role: "assistant"} - var contents = make([]string, 0) - var function model.Function - var toolCall = false - var arguments = make([]string, 0) - var reasoning = false +// contentType := response.Header.Get("Content-Type") +// if strings.Contains(contentType, "text/event-stream") { +// replyCreatedAt := time.Now() // 记录回复时间 +// // 循环读取 Chunk 消息 +// var message = types.Message{Role: "assistant"} +// var contents = make([]string, 0) +// var function model.Function +// var toolCall = false +// var arguments = make([]string, 0) +// var reasoning = false - scanner := bufio.NewScanner(response.Body) - for scanner.Scan() { - line := scanner.Text() - if !strings.Contains(line, "data:") || len(line) < 30 { - continue - } - var responseBody = types.ApiResponse{} - err = json.Unmarshal([]byte(line[6:]), &responseBody) - if err != nil { // 数据解析出错 - return errors.New(line) - } - if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行 - continue - } - if responseBody.Choices[0].Delta.Content == nil && - responseBody.Choices[0].Delta.ToolCalls == nil && - responseBody.Choices[0].Delta.ReasoningContent == "" { - continue - } +// scanner := bufio.NewScanner(response.Body) +// for scanner.Scan() { +// line := scanner.Text() +// if !strings.Contains(line, "data:") || len(line) < 30 { +// continue +// } +// var responseBody = types.ApiResponse{} +// err = json.Unmarshal([]byte(line[6:]), &responseBody) +// if err != nil { // 数据解析出错 +// return errors.New(line) +// } +// if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行 +// continue +// } +// if responseBody.Choices[0].Delta.Content == nil && +// responseBody.Choices[0].Delta.ToolCalls == nil && +// responseBody.Choices[0].Delta.ReasoningContent == "" { +// continue +// } - if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 { - utils.SendChunkMsg(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。") - break - } +// if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 { +// messageChan <- map[string]interface{}{ +// "type": "text", +// "body": "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。", +// } +// break +// } - var tool types.ToolCall - if len(responseBody.Choices[0].Delta.ToolCalls) > 0 { - tool = responseBody.Choices[0].Delta.ToolCalls[0] - if toolCall && tool.Function.Name == "" { - arguments = append(arguments, tool.Function.Arguments) - continue - } - } +// var tool types.ToolCall +// if len(responseBody.Choices[0].Delta.ToolCalls) > 0 { +// tool = responseBody.Choices[0].Delta.ToolCalls[0] +// if toolCall && tool.Function.Name == "" { +// arguments = append(arguments, tool.Function.Arguments) +// continue +// } +// } - // 兼容 Function Call - fun := responseBody.Choices[0].Delta.FunctionCall - if fun.Name != "" { - tool = *new(types.ToolCall) - tool.Function.Name = fun.Name - } else if toolCall { - arguments = append(arguments, fun.Arguments) - continue - } +// // 兼容 Function Call +// fun := responseBody.Choices[0].Delta.FunctionCall +// if fun.Name != "" { +// tool = *new(types.ToolCall) +// tool.Function.Name = fun.Name +// } else if toolCall { +// arguments = append(arguments, fun.Arguments) +// continue +// } - if !utils.IsEmptyValue(tool) { - res := h.DB.Where("name = ?", tool.Function.Name).First(&function) - if res.Error == nil { - toolCall = true - callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) - utils.SendChunkMsg(ws, callMsg) - contents = append(contents, callMsg) - } - continue - } +// if !utils.IsEmptyValue(tool) { +// res := h.DB.Where("name = ?", tool.Function.Name).First(&function) +// if res.Error == nil { +// toolCall = true +// callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) +// messageChan <- map[string]interface{}{ +// "type": "text", +// "body": callMsg, +// } +// contents = append(contents, callMsg) +// } +// continue +// } - if responseBody.Choices[0].FinishReason == "tool_calls" || - responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕 - break - } +// if responseBody.Choices[0].FinishReason == "tool_calls" || +// responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕 +// break +// } - // output stopped - if responseBody.Choices[0].FinishReason != "" { - break // 输出完成或者输出中断了 - } else { // 正常输出结果 - // 兼容思考过程 - if responseBody.Choices[0].Delta.ReasoningContent != "" { - reasoningContent := responseBody.Choices[0].Delta.ReasoningContent - if !reasoning { - reasoningContent = fmt.Sprintf("%s", reasoningContent) - reasoning = true - } +// // output stopped +// if responseBody.Choices[0].FinishReason != "" { +// break // 输出完成或者输出中断了 +// } else { // 正常输出结果 +// // 兼容思考过程 +// if responseBody.Choices[0].Delta.ReasoningContent != "" { +// reasoningContent := responseBody.Choices[0].Delta.ReasoningContent +// if !reasoning { +// reasoningContent = fmt.Sprintf("%s", reasoningContent) +// reasoning = true +// } - utils.SendChunkMsg(ws, reasoningContent) - contents = append(contents, reasoningContent) - } else if responseBody.Choices[0].Delta.Content != "" { - finalContent := responseBody.Choices[0].Delta.Content - if reasoning { - finalContent = fmt.Sprintf("%s", responseBody.Choices[0].Delta.Content) - reasoning = false - } - contents = append(contents, utils.InterfaceToString(finalContent)) - utils.SendChunkMsg(ws, finalContent) - } - } - } // end for +// messageChan <- map[string]interface{}{ +// "type": "text", +// "body": reasoningContent, +// } +// contents = append(contents, reasoningContent) +// } else if responseBody.Choices[0].Delta.Content != "" { +// finalContent := responseBody.Choices[0].Delta.Content +// if reasoning { +// finalContent = fmt.Sprintf("%s", responseBody.Choices[0].Delta.Content) +// reasoning = false +// } +// contents = append(contents, utils.InterfaceToString(finalContent)) +// messageChan <- map[string]interface{}{ +// "type": "text", +// "body": finalContent, +// } +// } +// } +// } // end for - if err := scanner.Err(); err != nil { - if strings.Contains(err.Error(), "context canceled") { - logger.Info("用户取消了请求:", prompt) - } else { - logger.Error("信息读取出错:", err) - } - } +// if err := scanner.Err(); err != nil { +// if strings.Contains(err.Error(), "context canceled") { +// logger.Info("用户取消了请求:", prompt) +// } else { +// logger.Error("信息读取出错:", err) +// } +// } - if toolCall { // 调用函数完成任务 - params := make(map[string]any) - _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms) - logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params) - params["user_id"] = userVo.Id - var apiRes types.BizVo - r, err := req2.C().R().SetHeader("Body-Type", "application/json"). - SetHeader("Authorization", function.Token). - SetBody(params).Post(function.Action) - errMsg := "" - if err != nil { - errMsg = err.Error() - } else { - all, _ := io.ReadAll(r.Body) - err = json.Unmarshal(all, &apiRes) - if err != nil { - errMsg = err.Error() - } else if apiRes.Code != types.Success { - errMsg = apiRes.Message - } - } +// if toolCall { // 调用函数完成任务 +// params := make(map[string]any) +// _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms) +// logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params) +// params["user_id"] = userVo.Id +// var apiRes types.BizVo +// r, err := req2.C().R().SetHeader("Body-Type", "application/json"). +// SetHeader("Authorization", function.Token). +// SetBody(params).Post(function.Action) +// errMsg := "" +// if err != nil { +// errMsg = err.Error() +// } else { +// all, _ := io.ReadAll(r.Body) +// err = json.Unmarshal(all, &apiRes) +// if err != nil { +// errMsg = err.Error() +// } else if apiRes.Code != types.Success { +// errMsg = apiRes.Message +// } +// } - if errMsg != "" { - errMsg = "调用函数工具出错:" + errMsg - contents = append(contents, errMsg) - } else { - errMsg = utils.InterfaceToString(apiRes.Data) - contents = append(contents, errMsg) - } - utils.SendChunkMsg(ws, errMsg) - } +// if errMsg != "" { +// errMsg = "调用函数工具出错:" + errMsg +// contents = append(contents, errMsg) +// } else { +// errMsg = utils.InterfaceToString(apiRes.Data) +// contents = append(contents, errMsg) +// } +// messageChan <- map[string]interface{}{ +// "type": "text", +// "body": errMsg, +// } +// } - // 消息发送成功 - if len(contents) > 0 { - usage := Usage{ - Prompt: prompt, - Content: strings.Join(contents, ""), - PromptTokens: 0, - CompletionTokens: 0, - TotalTokens: 0, - } - message.Content = usage.Content - h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt) - } - } else { // 非流式输出 - var respVo OpenAIResVo - body, err := io.ReadAll(response.Body) - if err != nil { - return fmt.Errorf("读取响应失败:%v", body) - } - err = json.Unmarshal(body, &respVo) - if err != nil { - return fmt.Errorf("解析响应失败:%v", body) - } - content := respVo.Choices[0].Message.Content - if strings.HasPrefix(req.Model, "o1-") { - content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content) - } - utils.SendChunkMsg(ws, content) - respVo.Usage.Prompt = prompt - respVo.Usage.Content = content - h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now()) - } +// // 消息发送成功 +// if len(contents) > 0 { +// usage := Usage{ +// Prompt: prompt, +// Content: strings.Join(contents, ""), +// PromptTokens: 0, +// CompletionTokens: 0, +// TotalTokens: 0, +// } +// message.Content = usage.Content +// h.saveChatHistory(req, usage, message, session, role, userVo, promptCreatedAt, replyCreatedAt) +// } +// } else { // 非流式输出 +// var respVo OpenAIResVo +// body, err := io.ReadAll(response.Body) +// if err != nil { +// return fmt.Errorf("读取响应失败:%v", body) +// } +// err = json.Unmarshal(body, &respVo) +// if err != nil { +// return fmt.Errorf("解析响应失败:%v", body) +// } +// content := respVo.Choices[0].Message.Content +// if strings.HasPrefix(req.Model, "o1-") { +// content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content) +// } +// messageChan <- map[string]interface{}{ +// "type": "text", +// "body": content, +// } +// respVo.Usage.Prompt = prompt +// respVo.Usage.Content = content +// h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, session, role, userVo, promptCreatedAt, time.Now()) +// } - return nil -} +// return nil +// } diff --git a/api/handler/ws_handler.go b/api/handler/ws_handler.go deleted file mode 100644 index 958a1266..00000000 --- a/api/handler/ws_handler.go +++ /dev/null @@ -1,153 +0,0 @@ -package handler - -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// * Copyright 2023 The Geek-AI Authors. All rights reserved. -// * Use of this source code is governed by a Apache-2.0 license -// * that can be found in the LICENSE file. -// * @Author yangjian102621@163.com -// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -import ( - "context" - "geekai/core" - "geekai/core/types" - "geekai/service" - "geekai/store/model" - "geekai/utils" - "net/http" - "strings" - - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" - "gorm.io/gorm" -) - -// Websocket 连接处理 handler - -type WebsocketHandler struct { - BaseHandler - wsService *service.WebsocketService - chatHandler *ChatHandler -} - -func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler { - return &WebsocketHandler{ - BaseHandler: BaseHandler{App: app, DB: db}, - chatHandler: chatHandler, - wsService: s, - } -} - -func (h *WebsocketHandler) Client(c *gin.Context) { - clientProtocols := c.GetHeader("Sec-WebSocket-Protocol") - ws, err := (&websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { return true }, - Subprotocols: strings.Split(clientProtocols, ","), - }).Upgrade(c.Writer, c.Request, nil) - if err != nil { - logger.Error(err) - c.Abort() - return - } - - clientId := c.Query("client_id") - client := types.NewWsClient(ws, clientId) - userId := h.GetLoginUserId(c) - if userId == 0 { - _ = client.Send([]byte("Invalid user_id")) - c.Abort() - return - } - var user model.User - if err := h.DB.Where("id", userId).First(&user).Error; err != nil { - _ = client.Send([]byte("Invalid user_id")) - c.Abort() - return - } - - h.wsService.Clients.Put(clientId, client) - logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) - go func() { - for { - _, msg, err := client.Receive() - if err != nil { - logger.Debugf("close connection: %s", client.Conn.RemoteAddr()) - client.Close() - h.wsService.Clients.Delete(clientId) - break - } - - var message types.InputMessage - err = utils.JsonDecode(string(msg), &message) - if err != nil { - continue - } - - logger.Debugf("Receive a message:%+v", message) - if message.Type == types.MsgTypePing { - utils.SendChannelMsg(client, types.ChPing, "pong") - continue - } - - // 当前只处理聊天消息,其他消息全部丢弃 - var chatMessage types.ChatMessage - err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage) - if err != nil || message.Channel != types.ChChat { - logger.Warnf("invalid message body:%+v", message.Body) - continue - } - var chatRole model.ChatRole - err = h.DB.First(&chatRole, chatMessage.RoleId).Error - if err != nil || !chatRole.Enable { - utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!") - continue - } - // if the role bind a model_id, use role's bind model_id - if chatRole.ModelId > 0 { - chatMessage.ModelId = int(chatRole.ModelId) - } - // get model info - var chatModel model.ChatModel - err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error - if err != nil || !chatModel.Enabled { - utils.SendAndFlush(client, "当前AI模型暂未启用,请更换模型后再发起对话!!!") - continue - } - - session := &types.ChatSession{ - ClientIP: c.ClientIP(), - UserId: userId, - } - - // use old chat data override the chat model and role ID - var chat model.ChatItem - h.DB.Where("chat_id", chatMessage.ChatId).First(&chat) - if chat.Id > 0 { - chatModel.Id = chat.ModelId - chatMessage.RoleId = int(chat.RoleId) - } - - session.ChatId = chatMessage.ChatId - session.Tools = chatMessage.Tools - session.Stream = chatMessage.Stream - session.Model.KeyId = chatMessage.ModelId - // 复制模型数据 - err = utils.CopyObject(chatModel, &session.Model) - if err != nil { - logger.Error(err, chatModel) - } - session.Model.Id = chatModel.Id - ctx, cancel := context.WithCancel(context.Background()) - h.chatHandler.ReqCancelFunc.Put(clientId, cancel) - err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client) - if err != nil { - logger.Error(err) - utils.SendAndFlush(client, err.Error()) - } else { - utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd}) - logger.Infof("回答完毕: %v", message.Body) - } - - } - }() -} diff --git a/api/main.go b/api/main.go index 47910b97..3c6db9d2 100644 --- a/api/main.go +++ b/api/main.go @@ -243,6 +243,7 @@ func main() { }), fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) { group := s.Engine.Group("/api/chat/") + group.Any("message", h.Chat) group.GET("list", h.List) group.GET("detail", h.Detail) group.POST("update", h.Update) @@ -515,10 +516,6 @@ func main() { group.Any("sse", h.PostTest, h.SseTest) }), fx.Provide(service.NewWebsocketService), - fx.Provide(handler.NewWebsocketHandler), - fx.Invoke(func(s *core.AppServer, h *handler.WebsocketHandler) { - s.Engine.Any("/api/ws", h.Client) - }), fx.Provide(handler.NewPromptHandler), fx.Invoke(func(s *core.AppServer, h *handler.PromptHandler) { group := s.Engine.Group("/api/prompt") diff --git a/web/package.json b/web/package.json index 0ab84fd6..8c40a18c 100644 --- a/web/package.json +++ b/web/package.json @@ -45,8 +45,8 @@ "vue": "^3.2.13", "vue-router": "^4.0.15", "unplugin-auto-import": "^0.18.5", + "@microsoft/fetch-event-source": "^2.0.1", "vue-waterfall-plugin-next": "^2.6.5" - }, "devDependencies": { "@vitejs/plugin-vue": "^5.2.4", diff --git a/web/src/App.vue b/web/src/App.vue index cb690973..2875a867 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -5,13 +5,12 @@