geekai/server/chat_handler.go

194 lines
4.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package server
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"math/rand"
"net/http"
"openai/types"
"strings"
"time"
)
// ChatHandle handles a chat WebSocket request.
//
// It upgrades the HTTP connection to a WebSocket and starts a goroutine that
// reads client messages until a receive error occurs, at which point the
// client is closed. Every received message is passed to sendMessage, using
// the "token" query parameter as the key of the caller's chat history.
//
// NOTE(review): CheckOrigin accepts every origin and the history key is a
// client-supplied query parameter — confirm this is acceptable for the
// deployment.
func (s *Server) ChatHandle(c *gin.Context) {
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// A failed upgrade only concerns this request (Upgrade has already
		// written an HTTP error response to the client), so log it instead of
		// using logger.Fatal, which conventionally exits the whole process.
		logger.Error(err)
		return
	}
	token := c.Query("token")
	logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
	client := NewWsClient(ws)
	go func() {
		for {
			_, message, err := client.Receive()
			if err != nil {
				logger.Error(err)
				client.Close()
				return
			}
			logger.Info(string(message))
			// TODO: only the context of the current session is kept; the user's
			// complete chat history is not persisted yet and should be later.
			if err := s.sendMessage(token, string(message), client); err != nil {
				logger.Error(err)
			}
		}
	}()
}
// sendMessage sends text to the chat completion API configured in
// s.Config.Chat and streams the reply to the WebSocket client ws.
//
// The request is made with Stream: true. Each response line of the form
// "data: {...}" is decoded into a types.ApiResponse; lines shorter than 20
// bytes (blank separators, "data: [DONE]") are skipped. The client receives:
//   - a WsStart message when the assistant role first appears,
//   - a WsMiddle message for every content delta,
//   - a WsEnd message when the stream finishes, is interrupted, fails to
//     read, or the API answers with a non-200 status.
//
// userId keys the conversation history in s.History. When
// s.Config.Chat.EnableContext is true, the stored history is sent together
// with the new message. After a reply whose role was received, the user
// message and the assembled assistant reply are appended to that history.
//
// Up to three HTTP requests are sent, avoiding the key whose request just
// failed. When no key is currently eligible, the function waits one second at
// a time, for at most ten seconds in total, so it always terminates.
//
// NOTE(review): s.History is read and written here from one goroutine per
// WebSocket connection (see ChatHandle). Unless access is serialized
// elsewhere, this is a data race on the map; consider guarding it with a
// mutex on Server. The history also grows without bound.
func (s *Server) sendMessage(userId string, text string, ws Client) error {
	const (
		maxAttempts = 3  // HTTP requests actually sent
		maxWaits    = 10 // one-second waits for an eligible API key, in total
	)

	var r = types.ApiRequest{
		Model:       s.Config.Chat.Model,
		Temperature: s.Config.Chat.Temperature,
		MaxTokens:   s.Config.Chat.MaxTokens,
		Stream:      true,
	}

	var history []types.Message
	if v, ok := s.History[userId]; ok && s.Config.Chat.EnableContext {
		history = v
	} else {
		history = make([]types.Message, 0)
	}
	userMessage := types.Message{
		Role:    "user",
		Content: text,
	}
	// Build the request messages in a new slice so that nothing is written
	// into spare capacity shared with the slice stored in s.History.
	r.Messages = make([]types.Message, 0, len(history)+1)
	r.Messages = append(r.Messages, history...)
	r.Messages = append(r.Messages, userMessage)

	requestBody, err := json.Marshal(r)
	if err != nil {
		return err
	}

	var (
		response  *http.Response
		failedKey string
		attempts  int
		waits     int
	)
	for attempts < maxAttempts {
		apiKey := s.getApiKey(failedKey)
		if apiKey == "" {
			if waits >= maxWaits {
				err = fmt.Errorf("no API key available after waiting %d seconds", maxWaits)
				break
			}
			waits++
			logger.Info("Too many requests, all Api Key is not available")
			time.Sleep(time.Second)
			continue
		}
		attempts++

		// Build a fresh request per attempt: a failed Do may already have
		// consumed the body reader of the previous request.
		request, reqErr := http.NewRequest(http.MethodPost, s.Config.Chat.ApiURL, bytes.NewReader(requestBody))
		if reqErr != nil {
			return reqErr
		}
		request.Header.Set("Content-Type", "application/json")
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))

		// Never write the full secret to the log.
		maskedKey := "****"
		if len(apiKey) > 8 {
			maskedKey = apiKey[:4] + "****" + apiKey[len(apiKey)-4:]
		}
		logger.Infof("Use API KEY: %s", maskedKey)

		response, err = s.Client.Do(request)
		if err == nil {
			break
		}
		logger.Error(err)
		failedKey = apiKey
	}
	// Exiting the loop without a successful Do always leaves err non-nil.
	if err != nil {
		return err
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the returned error message.
		errBody, _ := io.ReadAll(io.LimitReader(response.Body, 4096))
		replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
		return fmt.Errorf("chat API returned %s: %s", response.Status, strings.TrimSpace(string(errBody)))
	}

	var message = types.Message{}
	var reply strings.Builder
	reader := bufio.NewReader(response.Body)
	for {
		line, readErr := reader.ReadString('\n')
		if readErr != nil && readErr != io.EOF {
			logger.Error(readErr)
			// Still tell the client the stream is over so it does not wait forever.
			replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
			break
		}
		if line == "" {
			replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
			break
		}
		if len(line) < 20 {
			continue
		}
		// Decode into a fresh value each time: json.Unmarshal decodes into
		// existing slice elements, so fields missing from this chunk would
		// otherwise keep their values from the previous one.
		var chunk types.ApiResponse
		if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &chunk); err != nil {
			logger.Error(line)
			continue
		}
		if len(chunk.Choices) == 0 {
			continue
		}
		choice := chunk.Choices[0]
		// The chunk that announces the assistant role starts the reply; its
		// content, if any, is not forwarded.
		if choice.Delta.Role != "" && message.Role == "" {
			message.Role = choice.Delta.Role
			replyMessage(types.WsMessage{Type: types.WsStart}, ws)
			continue
		}
		// The output finished or was interrupted.
		if choice.FinishReason != "" {
			replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
			break
		}
		reply.WriteString(choice.Delta.Content)
		replyMessage(types.WsMessage{
			Type:    types.WsMiddle,
			Content: choice.Delta.Content,
		}, ws)
	}

	// Record the exchange only when the reply carried a role. A message with
	// an empty role would be replayed as context on the next request and is
	// presumably rejected by the API, breaking every later request of this user.
	if message.Role != "" {
		message.Content = reply.String()
		history = append(history, userMessage, message)
		s.History[userId] = history
	}
	return nil
}
// getApiKey returns a randomly chosen API key from s.Config.Chat.ApiKeys, or
// "" when no key is currently eligible.
//
// failedKey (the key whose request just failed, or "") is never returned, so
// that a retry switches to a different key. Keys are also rate limited via
// s.ApiKeyAccessStat, whose values are compared against time.Now().Unix() and
// are therefore read as the Unix second of each key's last use. A key is
// skipped until minGapSeconds whole seconds have passed since then, i.e. at
// most 12 uses per key per minute.
//
// NOTE(review): ApiKeyAccessStat is not updated by the code visible here.
// Confirm that callers record each use; otherwise the rate limit never applies.
func (s *Server) getApiKey(failedKey string) string {
	const minGapSeconds = 5

	now := time.Now()
	keys := make([]string, 0, len(s.Config.Chat.ApiKeys))
	for _, key := range s.Config.Chat.ApiKeys {
		if key == failedKey {
			continue
		}
		// A key absent from the map reads as 0, i.e. never used.
		if now.Unix()-s.ApiKeyAccessStat[key] < minGapSeconds {
			continue
		}
		keys = append(keys, key)
	}
	if len(keys) == 0 {
		return ""
	}
	// Kept for Go versions before 1.20, where the global source is not seeded
	// automatically.
	rand.Seed(now.UnixNano())
	return keys[rand.Intn(len(keys))]
}
// replyMessage JSON-encodes message and sends it to client.
//
// Failures are logged and otherwise ignored, so a broken connection does not
// abort the caller's stream processing. Sending is only implemented for
// *WsClient. Any other Client implementation (or a nil client) is logged and
// skipped instead of panicking on the type assertion.
func replyMessage(message types.WsMessage, client Client) {
	wsClient, ok := client.(*WsClient)
	if !ok {
		logger.Errorf("Error for reply message: unsupported client type %T", client)
		return
	}
	msg, err := json.Marshal(message)
	if err != nil {
		logger.Errorf("Error for encoding json data: %v", err)
		return
	}
	if err := wsClient.Send(msg); err != nil {
		logger.Errorf("Error for reply message: %v", err)
	}
}