package server

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"math/rand"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"

	"openai/types"
)

// maxChatRetries is the number of additional attempts made after the first
// failed request to the chat completion endpoint.
const maxChatRetries = 3

// ChatHandle upgrades the incoming HTTP request to a WebSocket connection and
// starts a goroutine that reads messages from it. Each received message is
// forwarded to sendMessage; the goroutine ends (and closes the client) on the
// first receive error.
func (s *Server) ChatHandle(c *gin.Context) {
	upgrader := websocket.Upgrader{
		// NOTE(review): every origin is accepted, which disables the
		// cross-site WebSocket check. Confirm this is intended.
		CheckOrigin: func(r *http.Request) bool { return true },
	}
	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// A failed handshake concerns one client only, so it is reported at
		// error level instead of the previous fatal level.
		logger.Error(err)
		return
	}

	logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
	client := NewWsClient(ws)
	go func() {
		for {
			_, message, err := client.Receive()
			if err != nil {
				logger.Error(err)
				client.Close()
				return
			}

			logger.Info(string(message))
			// TODO: pass the real user ID taken from the session request.
			if err := s.sendMessage("test", string(message), client); err != nil {
				logger.Error(err)
			}
		}
	}()
}

// sendMessage sends text, together with the stored history for userId, to the
// configured chat completion endpoint and pushes every streamed content
// fragment to the WebSocket client as it arrives.
//
// When the stream produced a non-empty reply, both the user message and the
// assembled reply are appended to s.History[userId].
//
// It returns an error when ws is not a *WsClient, when no API key is
// configured, when the proxy URL cannot be parsed, when the request still
// fails after maxChatRetries retries, or when the endpoint answers with a
// status other than 200. Errors while reading or pushing individual fragments
// are logged and do not abort the call.
//
// NOTE(review): s.History is read and written here without synchronization
// while ChatHandle runs one goroutine per connection. Guard it where Server is
// declared.
func (s *Server) sendMessage(userId string, text string, ws Client) error {
	wsClient, ok := ws.(*WsClient)
	if !ok {
		return fmt.Errorf("unsupported client type %T", ws)
	}

	keys := s.Config.OpenAi.ApiKeys
	if len(keys) == 0 {
		return fmt.Errorf("no OpenAI API key configured")
	}

	logger.Infof("上下文历史消息:%+v", s.History[userId])

	// The full slice expression caps the capacity so that append copies
	// instead of writing into the backing array of the stored history.
	stored := s.History[userId]
	history := append(stored[:len(stored):len(stored)], types.Message{
		Role:    "user",
		Content: text,
	})

	requestBody, err := json.Marshal(types.ApiRequest{
		Model:       "gpt-3.5-turbo",
		Temperature: 0.9,
		MaxTokens:   1024,
		Stream:      true,
		Messages:    history,
	})
	if err != nil {
		return err
	}

	client, err := newChatHTTPClient(s.Config.ProxyURL)
	if err != nil {
		return err
	}
	// The client is created per call, so its idle connections must be
	// released here or they would accumulate.
	defer client.CloseIdleConnections()

	// TODO: balance load across API keys instead of picking one at random.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	apiKey := keys[rng.Intn(len(keys))]

	var response *http.Response
	for attempt := 0; ; attempt++ {
		// A request body can be consumed by a failed attempt, so every
		// attempt gets a fresh request over the same payload.
		request, err := http.NewRequest(http.MethodPost, s.Config.OpenAi.ApiURL, bytes.NewReader(requestBody))
		if err != nil {
			return err
		}
		request.Header.Set("Content-Type", "application/json")
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))

		response, err = client.Do(request)
		if err == nil {
			break
		}
		if attempt >= maxChatRetries {
			return err
		}
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		// Best effort: the excerpt only enriches the error message, so a
		// failure to read it is deliberately ignored.
		detail, _ := io.ReadAll(io.LimitReader(response.Body, 4096))
		return fmt.Errorf("chat completion request failed: %s: %s", response.Status, bytes.TrimSpace(detail))
	}

	var reply types.Message
	var contents []string
	reader := bufio.NewReader(response.Body)
	for {
		line, readErr := reader.ReadString('\n')
		if readErr != nil && readErr != io.EOF {
			logger.Error(readErr)
			break
		}
		if line == "" {
			// Only reachable together with io.EOF: the stream is finished.
			break
		}

		payload, ok := sseDataPayload(line)
		if !ok {
			continue
		}
		if payload == "[DONE]" {
			break
		}

		// Decode into a fresh value so that no field of an earlier fragment
		// can leak into this one.
		var chunk types.ApiResponse
		if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
			logger.Error(err)
			continue
		}
		if len(chunk.Choices) == 0 {
			continue
		}
		choice := chunk.Choices[0]

		// The first fragment that carries a role only initializes the reply.
		if choice.Delta.Role != "" && reply.Role == "" {
			reply.Role = choice.Delta.Role
			continue
		}
		contents = append(contents, choice.Delta.Content)

		// Push the fragment to the client.
		if err := wsClient.Send([]byte(choice.Delta.Content)); err != nil {
			logger.Error(err)
		}

		if choice.FinishReason != "" {
			break
		}
	}

	content := strings.Join(contents, "")
	if content == "" {
		// Nothing was received, so the exchange is not recorded.
		return nil
	}
	if reply.Role == "" {
		// Assumes the endpoint follows the OpenAI chat format, where replies
		// carry the "assistant" role.
		reply.Role = "assistant"
	}
	reply.Content = content

	// Record both sides of the exchange for the next request.
	s.History[userId] = append(history, reply)
	return nil
}

// newChatHTTPClient returns an HTTP client that routes its requests through
// proxyURL. An empty proxyURL yields a client that uses no proxy.
func newChatHTTPClient(proxyURL string) (*http.Client, error) {
	transport := &http.Transport{}
	if proxyURL != "" {
		proxy, err := url.Parse(proxyURL)
		if err != nil {
			return nil, fmt.Errorf("parsing proxy URL %q: %w", proxyURL, err)
		}
		transport.Proxy = http.ProxyURL(proxy)
	}
	return &http.Client{Transport: transport}, nil
}

// sseDataPayload extracts the payload of a "data:" line of an event stream.
// The boolean is false for any other line and for an empty payload.
func sseDataPayload(line string) (string, bool) {
	line = strings.TrimSpace(line)
	if !strings.HasPrefix(line, "data:") {
		return "", false
	}
	payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
	return payload, payload != ""
}