package server import ( "bufio" "bytes" "encoding/json" "fmt" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "io" "math/rand" "net/http" "net/url" "openai/types" "strings" "time" ) const ErrorMsg = "抱歉,AI 助手开小差了,我马上找人去盘它。" // ChatHandle 处理聊天 WebSocket 请求 func (s *Server) ChatHandle(c *gin.Context) { ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) if err != nil { logger.Fatal(err) return } sessionId := c.Query("sessionId") roleKey := c.Query("role") session := s.ChatSession[sessionId] logger.Infof("New websocket connected, IP: %s, Token: %s", c.Request.RemoteAddr, session.Token) client := NewWsClient(ws) var roles = GetChatRoles() var chatRole = roles[roleKey] if !chatRole.Enable { // 角色未启用 c.Abort() return } // 发送打招呼信息 replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client) replyMessage(types.WsMessage{Type: types.WsMiddle, Content: chatRole.HelloMsg, IsHelloMsg: true}, client) replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client) go func() { for { _, message, err := client.Receive() if err != nil { logger.Error(err) client.Close() return } logger.Info("Receive a message: ", string(message)) // TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录 err = s.sendMessage(session, chatRole, string(message), client) if err != nil { logger.Error(err) } } }() } // 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端 func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, text string, ws Client) error { token, err := GetToken(session.Token) if err != nil { replyError(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!") return err } if token.MaxCalls > 0 && token.RemainingCalls <= 0 { replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用!") return nil } var r = types.ApiRequest{ Model: s.Config.Chat.Model, Temperature: s.Config.Chat.Temperature, MaxTokens: s.Config.Chat.MaxTokens, Stream: true, } var context []types.Message var key = session.SessionId + role.Name if v, ok := s.ChatContext[key]; ok && s.Config.Chat.EnableContext { context = v } else { context = role.Context } if s.DebugMode { logger.Infof("会话上下文:%+v", context) } r.Messages = append(context, types.Message{ Role: "user", Content: text, }) requestBody, err := json.Marshal(r) if err != nil { return err } // 创建 HttpClient 请求对象 var client *http.Client request, err := http.NewRequest(http.MethodPost, s.Config.Chat.ApiURL, bytes.NewBuffer(requestBody)) if err != nil { return err } request.Header.Add("Content-Type", "application/json") var retryCount = 3 var response *http.Response var failedKey = "" var failedProxyURL = "" for retryCount > 0 { proxyURL := s.getProxyURL(failedProxyURL) if proxyURL == "" { client = &http.Client{} } else { // 使用代理 uri := url.URL{} proxy, _ := uri.Parse(proxyURL) client = &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyURL(proxy), }, } } apiKey := s.getApiKey(failedKey) if apiKey == "" { logger.Info("Too many requests, all Api Key is not available") time.Sleep(time.Second) continue } logger.Infof("Use API KEY: %s", apiKey) request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) response, err = client.Do(request) if err == nil { break } else { logger.Error(err) failedKey = apiKey failedProxyURL = proxyURL } retryCount-- } // 如果三次请求都失败的话,则返回对应的错误信息 if err != nil { replyError(ws, ErrorMsg) return err } var message = types.Message{} var contents = make([]string, 0) var responseBody = types.ApiResponse{} reader := bufio.NewReader(response.Body) for { line, err := reader.ReadString('\n') if err != nil && err != io.EOF { logger.Error(err) break } if line == "" { replyMessage(types.WsMessage{Type: types.WsEnd}, ws) break } else if len(line) < 20 { continue } err = json.Unmarshal([]byte(line[6:]), &responseBody) if err != nil { logger.Error(line) replyError(ws, ErrorMsg) break } // 初始化 role if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { message.Role = responseBody.Choices[0].Delta.Role replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: false}, ws) continue } else if responseBody.Choices[0].FinishReason != "" { // 输出完成或者输出中断了 replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: false}, ws) break } else { content := responseBody.Choices[0].Delta.Content contents = append(contents, content) replyMessage(types.WsMessage{ Type: types.WsMiddle, Content: responseBody.Choices[0].Delta.Content, IsHelloMsg: false, }, ws) } } // 当前 Token 调用次数减 1 if token.MaxCalls > 0 { token.RemainingCalls -= 1 } // 追加历史消息 context = append(context, types.Message{ Role: "user", Content: text, }) message.Content = strings.Join(contents, "") context = append(context, message) // 保存上下文 s.ChatContext[key] = context _ = response.Body.Close() // 关闭资源 return nil } func replyError(ws Client, message string) { replyMessage(types.WsMessage{Type: types.WsStart}, ws) replyMessage(types.WsMessage{Type: types.WsMiddle, Content: message}, ws) replyMessage(types.WsMessage{Type: types.WsEnd}, ws) } // 随机获取一个 API Key,如果请求失败,则更换 API Key 重试 func (s *Server) getApiKey(failedKey string) string { var keys = make([]string, 0) for _, v := range s.Config.Chat.ApiKeys { // 过滤掉刚刚失败的 Key if v == failedKey { continue } // 获取 API Key 的上次调用时间,控制调用频率 var lastAccess int64 if t, ok := s.ApiKeyAccessStat[v]; ok { lastAccess = t } // 保持每分钟访问不超过 15 次 if time.Now().Unix()-lastAccess <= 4 { continue } keys = append(keys, v) } rand.Seed(time.Now().UnixNano()) if len(keys) > 0 { key := keys[rand.Intn(len(keys))] s.ApiKeyAccessStat[key] = time.Now().Unix() return key } return "" } // 获取一个可用的代理 func (s *Server) getProxyURL(failedProxyURL string) string { if len(s.Config.ProxyURL) == 0 { return "" } if len(s.Config.ProxyURL) == 1 || failedProxyURL == "" { return s.Config.ProxyURL[0] } for i, v := range s.Config.ProxyURL { if failedProxyURL == v { if i == len(s.Config.ProxyURL)-1 { return s.Config.ProxyURL[0] } else { return s.Config.ProxyURL[i+1] } } } return "" } // 回复客户端消息 func replyMessage(message types.WsMessage, client Client) { msg, err := json.Marshal(message) if err != nil { logger.Errorf("Error for decoding json data: %v", err.Error()) return } err = client.(*WsClient).Send(msg) if err != nil { logger.Errorf("Error for reply message: %v", err.Error()) } }