diff --git a/server/chat_handler.go b/server/chat_handler.go index 19929d41..8f4323a4 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -3,6 +3,7 @@ package server import ( "bufio" "bytes" + "context" "encoding/json" "errors" "fmt" @@ -59,12 +60,13 @@ func (s *Server) ChatHandle(c *gin.Context) { delete(s.ChatClients, sessionId) return } - logger.Info("Receive a message: ", string(message)) //replyMessage(client, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!", false) //replyMessage(client, "![](images/wx.png)", true) - // TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录 - err = s.sendMessage(session, chatRole, string(message), client, false) + ctx, cancel := context.WithCancel(context.Background()) + s.ReqCancelFunc[sessionId] = cancel + // 回复消息 + err = s.sendMessage(ctx, session, chatRole, string(message), client, false) if err != nil { logger.Error(err) } @@ -73,7 +75,13 @@ func (s *Server) ChatHandle(c *gin.Context) { } // 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端 -func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, prompt string, ws Client, resetContext bool) error { +func (s *Server) sendMessage(ctx context.Context, session types.ChatSession, role types.ChatRole, prompt string, ws Client, resetContext bool) error { + cancel := s.ReqCancelFunc[session.SessionId] + defer func() { + cancel() + delete(s.ReqCancelFunc, session.SessionId) + }() + user, err := GetUser(session.Username) if err != nil { replyMessage(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!", false) @@ -98,38 +106,38 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro replyMessage(ws, "![](images/start.png)", true) return nil } - var r = types.ApiRequest{ + var req = types.ApiRequest{ Model: s.Config.Chat.Model, Temperature: s.Config.Chat.Temperature, MaxTokens: s.Config.Chat.MaxTokens, Stream: true, } - var context []types.Message + var chatCtx []types.Message var ctxKey = fmt.Sprintf("%s-%s", session.SessionId, role.Key) if v, ok := s.ChatContexts[ctxKey]; ok && s.Config.Chat.EnableContext { - context = v.Messages + chatCtx = v.Messages } else { - context = role.Context + chatCtx = role.Context } if s.DebugMode { - logger.Infof("会话上下文:%+v", context) + logger.Infof("会话上下文:%+v", chatCtx) } + req.Messages = append(chatCtx, types.Message{ + Role: "user", + Content: prompt, + }) + // 创建 HttpClient 请求对象 var client *http.Client - var retryCount = 5 + var retryCount = 5 // 重试次数 var response *http.Response var apiKey string var failedKey = "" var failedProxyURL = "" for retryCount > 0 { - r.Messages = append(context, types.Message{ - Role: "user", - Content: prompt, - }) - - requestBody, err := json.Marshal(r) + requestBody, err := json.Marshal(req) if err != nil { return err } @@ -139,6 +147,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro return err } + request = request.WithContext(ctx) request.Header.Add("Content-Type", "application/json") proxyURL := s.getProxyURL(failedProxyURL) @@ -164,7 +173,10 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro response, err = client.Do(request) if err == nil { break + } else if strings.Contains(err.Error(), "context canceled") { + return errors.New("用户取消了请求:" + prompt) } else { + logger.Error("HTTP API 请求失败:" + err.Error()) failedKey = apiKey failedProxyURL = proxyURL } @@ -205,14 +217,14 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro _ = utils.SaveConfig(s.Config, s.ConfigPath) // 重发当前消息 - return s.sendMessage(session, role, prompt, ws, false) + return s.sendMessage(ctx, session, role, prompt, ws, false) // 上下文超出长度了 } else if strings.Contains(line, "This model's maximum context length is 4097 tokens") { logger.Infof("会话上下文长度超出限制, Username: %s", user.Name) // 重置上下文,重发当前消息 delete(s.ChatContexts, ctxKey) - return s.sendMessage(session, role, prompt, ws, true) + return s.sendMessage(ctx, session, role, prompt, ws, true) } else if !strings.Contains(line, "data:") { continue } @@ -246,7 +258,18 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro IsHelloMsg: false, }) } - } + + // 监控取消信号 + select { + case <-ctx.Done(): + // 结束输出 + replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false}) + _ = response.Body.Close() + return errors.New("用户取消了请求:" + prompt) + default: + continue + } + } // end for _ = response.Body.Close() // 关闭资源 // 消息发送成功 @@ -260,27 +283,26 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro if message.Role == "" { message.Role = "assistant" } - // 追加上下文消息 - useMsg := types.Message{Role: "user", Content: prompt} - context = append(context, useMsg) message.Content = strings.Join(contents, "") + useMsg := types.Message{Role: "user", Content: prompt} // 更新上下文消息 if s.Config.Chat.EnableContext { - context = append(context, message) + chatCtx = append(chatCtx, useMsg) // 提问消息 + chatCtx = append(chatCtx, message) // 回复消息 s.ChatContexts[ctxKey] = types.ChatContext{ - Messages: context, + Messages: chatCtx, LastAccessTime: time.Now().Unix(), } } // 追加历史消息 if user.EnableHistory { - err = AppendChatHistory(user.Name, role.Key, useMsg) + err = AppendChatHistory(user.Name, role.Key, useMsg) // 提问消息 if err != nil { return err } - err = AppendChatHistory(user.Name, role.Key, message) + err = AppendChatHistory(user.Name, role.Key, message) // 回复消息 } } @@ -431,3 +453,12 @@ func (s *Server) ClearHistoryHandle(c *gin.Context) { c.JSON(http.StatusOK, types.BizVo{Code: types.Success}) } + +// StopGenerateHandle 停止生成 +func (s *Server) StopGenerateHandle(c *gin.Context) { + sessionId := c.GetHeader(types.TokenName) + cancel := s.ReqCancelFunc[sessionId] + cancel() + delete(s.ReqCancelFunc, sessionId) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success}) +} diff --git a/server/server.go b/server/server.go index b0809c2e..d4e194d9 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "embed" "encoding/json" "github.com/gin-contrib/sessions" @@ -39,10 +40,11 @@ type Server struct { // 保存 Websocket 会话 Username, 每个 Username 只能连接一次 // 防止第三方直接连接 socket 调用 OpenAI API - ChatSession map[string]types.ChatSession //map[sessionId]User - ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内 - ChatClients map[string]*WsClient // Websocket 连接集合 - DebugMode bool // 是否开启调试模式 + ChatSession map[string]types.ChatSession //map[sessionId]User + ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内 + ChatClients map[string]*WsClient // Websocket 连接集合 + ReqCancelFunc map[string]context.CancelFunc // HttpClient 请求取消 handle function + DebugMode bool // 是否开启调试模式 } func NewServer(configPath string) (*Server, error) { @@ -67,6 +69,7 @@ func NewServer(configPath string) (*Server, error) { ChatContexts: make(map[string]types.ChatContext, 16), ChatSession: make(map[string]types.ChatSession), ChatClients: make(map[string]*WsClient), + ReqCancelFunc: make(map[string]context.CancelFunc), ApiKeyAccessStat: make(map[string]int64), }, nil } @@ -83,17 +86,18 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) { engine.Use(AuthorizeMiddleware(s)) engine.Use(Recover) - engine.POST("/test", s.TestHandle) - engine.GET("/api/session/get", s.GetSessionHandle) - engine.POST("/api/login", s.LoginHandle) - engine.POST("/api/logout", s.LogoutHandle) - engine.Any("/api/chat", s.ChatHandle) + engine.POST("test", s.TestHandle) + engine.GET("api/session/get", s.GetSessionHandle) + engine.POST("api/login", s.LoginHandle) + engine.POST("api/logout", s.LogoutHandle) + engine.Any("api/chat", s.ChatHandle) + engine.POST("api/chat/stop", s.StopGenerateHandle) engine.POST("api/chat/history", s.GetChatHistoryHandle) engine.POST("api/chat/history/clear", s.ClearHistoryHandle) - engine.POST("/api/config/set", s.ConfigSetHandle) - engine.GET("/api/config/chat-roles/get", s.GetChatRoleListHandle) - engine.GET("/api/config/chat-roles/add", s.AddChatRoleHandle) + engine.POST("api/config/set", s.ConfigSetHandle) + engine.GET("api/config/chat-roles/get", s.GetChatRoleListHandle) + engine.GET("api/config/chat-roles/add", s.AddChatRoleHandle) engine.POST("api/config/user/add", s.AddUserHandle) engine.POST("api/config/user/batch-add", s.BatchAddUserHandle) engine.POST("api/config/user/set", s.SetUserHandle) diff --git a/test/test.go b/test/test.go index a3ba164a..6b6992c2 100644 --- a/test/test.go +++ b/test/test.go @@ -1,10 +1,73 @@ package main import ( + "context" "fmt" + "io" + "net/http" "time" ) func main() { + + ctx, cancel := context.WithCancel(context.Background()) + + http.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) { + cancel() + _, _ = fmt.Fprintf(w, "请求取消!") + }) + + go func() { + err := http.ListenAndServe(":9999", nil) + if err != nil { + return + } + }() + + testHttpClient(ctx) +} + +// Http client 取消操作 +func testHttpClient(ctx context.Context) { + + req, err := http.NewRequest("GET", "http://localhost:2345", nil) + if err != nil { + fmt.Println(err) + return + } + + req = req.WithContext(ctx) + + client := &http.Client{} + + resp, err := client.Do(req) + if err != nil { + fmt.Println(err) + return + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + for { + time.Sleep(time.Second) + fmt.Println(time.Now()) + select { + case <-ctx.Done(): + fmt.Println("取消退出") + return + default: + continue + } + } + + if err != nil { + fmt.Println(err) + return + } + + fmt.Println(string(body)) + +} + +func testDate() { fmt.Println(time.Unix(1683336167, 0).Format("2006-01-02 15:04:05")) } diff --git a/web/.env.production b/web/.env.production index a21ed00e..3ff9dfd5 100644 --- a/web/.env.production +++ b/web/.env.production @@ -1,2 +1,2 @@ -VUE_APP_API_HOST=https://ai.r9it.com -VUE_APP_WS_HOST=wss://ai.r9it.com +VUE_APP_API_HOST=https://www.chat-plus.net +VUE_APP_WS_HOST=wss://www.chat-plus.net diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index 55feb2b6..cdb90c79 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -81,9 +81,26 @@ :icon="chat.icon" :content="chat.content"/> - +
+
+ + + + + 停止生成 + + + + + + + 重新生成 + +
+
+ this.$refs["text-input"].focus(), 100); @@ -494,6 +548,26 @@ export default defineComponent({ }).catch(() => { ElMessage.error("注销失败"); }) + }, + + // 停止生成 + stopGenerate: function () { + this.showStopGenerate = false; + httpPost("/api/chat/stop").then(() => { + console.log("stopped generate.") + this.sending = false; + if (this.canReGenerate) { + this.showReGenerate = true; + } + }) + }, + + // 重新生成 + reGenerate: function () { + this.sending = true; + this.showStopGenerate = true; + this.showReGenerate = false; + this.socket.send('重新生成上述问题的答案:' + this.previousText); } }, @@ -676,6 +750,24 @@ export default defineComponent({ } } + .re-generate { + position: relative; + display: flex; + justify-content: center; + + .btn-box { + position absolute + bottom 10px; + + .el-button { + .el-icon { + margin-right 5px; + } + } + + } + } + .chat-tool-box { padding 10px; border-top: 1px solid #2F3032