diff --git a/server/chat_handler.go b/server/chat_handler.go index 9699176a..fbf65a6d 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - "io" "math/rand" "net/http" "net/url" @@ -92,29 +91,31 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro logger.Infof("会话上下文:%+v", context) } - r.Messages = append(context, types.Message{ - Role: "user", - Content: prompt, - }) - - requestBody, err := json.Marshal(r) - if err != nil { - return err - } - // 创建 HttpClient 请求对象 var client *http.Client - request, err := http.NewRequest(http.MethodPost, s.Config.Chat.ApiURL, bytes.NewBuffer(requestBody)) - if err != nil { - return err - } - - request.Header.Add("Content-Type", "application/json") var retryCount = 5 var response *http.Response + var apiKey string var failedKey = "" var failedProxyURL = "" for retryCount > 0 { + r.Messages = append(context, types.Message{ + Role: "user", + Content: prompt, + }) + + requestBody, err := json.Marshal(r) + if err != nil { + return err + } + + request, err := http.NewRequest(http.MethodPost, s.Config.Chat.ApiURL, bytes.NewBuffer(requestBody)) + if err != nil { + return err + } + + request.Header.Add("Content-Type", "application/json") + proxyURL := s.getProxyURL(failedProxyURL) if proxyURL == "" { client = &http.Client{} @@ -127,7 +128,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro }, } } - apiKey := s.getApiKey(failedKey) + apiKey = s.getApiKey(failedKey) if apiKey == "" { logger.Info("Too many requests, all Api Key is not available") time.Sleep(time.Second) @@ -139,7 +140,12 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro if err == nil { break } else { - logger.Error(err) + // 上下文超出长度了 + if strings.Contains(err.Error(), "This model's maximum context length is 4097 tokens") { + logger.Info("会话上下文长度超出限制, Username: %s", user.Name) + replyMessage(ws, "温馨提示:会话上下文长度超出限制,已为您重置会话上下文!", false) + context = role.Context + } failedKey = apiKey failedProxyURL = proxyURL } @@ -160,16 +166,23 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, pro reader := bufio.NewReader(response.Body) for { line, err := reader.ReadString('\n') - if err != nil && err != io.EOF { + if err != nil { logger.Error(err) break } - - if line == "" { - replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd}) - break - } else if len(line) < 20 { + if len(line) < 20 { continue + } else if strings.Contains(line, "This key is associated with a deactivated account") { + logger.Infof("API Key %s is deactivated", apiKey) + // 移除当前 API key + for i, v := range s.Config.Chat.ApiKeys { + if v == apiKey { + s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys[:i], s.Config.Chat.ApiKeys[i+1:]...) + } + } + + // 重发当前消息 + return s.sendMessage(session, role, prompt, ws) } err = json.Unmarshal([]byte(line[6:]), &responseBody)