geekai/server/chat_handler.go

163 lines
4.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package server
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"math/rand"
"net/http"
"openai/types"
"strings"
"time"
)
// ChatHandle upgrades the request to a WebSocket connection and serves chat
// messages over it.
//
// Each message received from the client is forwarded to the chat API through
// sendMessage, which streams the reply back over the same connection. The
// read loop runs in its own goroutine and closes the client on the first
// receive error.
func (s *Server) ChatHandle(c *gin.Context) {
	// NOTE(review): CheckOrigin accepts every origin, so any web page can open
	// a socket to this endpoint; restrict it if cross-site use is not intended.
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// logger.Fatal conventionally exits the process; a single failed
		// handshake must not bring down the whole server, so log and return.
		logger.Error(err)
		return
	}
	logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
	client := NewWsClient(ws)
	go func() {
		for {
			_, message, err := client.Receive()
			if err != nil {
				logger.Error(err)
				client.Close()
				return
			}
			logger.Info(string(message))
			// TODO: pass a per-session user ID; every connection currently
			// shares the conversation history stored under "test".
			if err := s.sendMessage("test", string(message), client); err != nil {
				logger.Error(err)
			}
		}
	}()
}
// sendMessage sends text to the chat completion API on behalf of userId and
// relays the streamed reply to ws.
//
// The request is built from s.Config.Chat. When EnableContext is set, the
// conversation stored in s.History[userId] is sent along as context. Up to
// three attempts are made, each using an API key taken from a random
// permutation of s.Config.Chat.ApiKeys, so a failing key is not reused while
// untried keys remain. A transport error or a non-200 status counts as a
// failed attempt.
//
// While streaming, the client receives:
//   - a WsStart message once the reply's role arrives,
//   - one WsMiddle message per non-empty content delta,
//   - a WsEnd message when the stream ends (finish_reason, "[DONE]", EOF or a
//     read error).
//
// The exchange (user message plus reply) is then stored in s.History[userId].
//
// It returns an error if no API key is configured, the request cannot be
// built, or every attempt fails. Problems after streaming has started are
// logged, not returned.
//
// NOTE(review): s.History is read and written here without synchronization,
// and ChatHandle calls this from one goroutine per connection, so concurrent
// connections race on the map. It should be guarded by a mutex on Server.
func (s *Server) sendMessage(userId string, text string, ws Client) error {
	var r = types.ApiRequest{
		Model:       s.Config.Chat.Model,
		Temperature: s.Config.Chat.Temperature,
		MaxTokens:   s.Config.Chat.MaxTokens,
		Stream:      true,
	}

	userMessage := types.Message{Role: "user", Content: text}
	var history []types.Message
	if v, ok := s.History[userId]; ok && s.Config.Chat.EnableContext {
		history = v
	}
	// The three-index slice caps capacity at len(history), so this append
	// allocates a new array instead of writing into the one held by s.History.
	r.Messages = append(history[:len(history):len(history)], userMessage)

	requestBody, err := json.Marshal(r)
	if err != nil {
		return err
	}

	keys := s.Config.Chat.ApiKeys
	if len(keys) == 0 {
		return fmt.Errorf("no API keys configured")
	}
	// A local generator avoids reseeding the shared global source on every call.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	order := rng.Perm(len(keys))

	// TODO: remove keys that keep failing from the list.
	const maxAttempts = 3
	var response *http.Response
	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		apiKey := keys[order[attempt%len(keys)]]
		// Never write the full secret to the logs.
		masked := "***"
		if len(apiKey) > 8 {
			masked = apiKey[:3] + "***" + apiKey[len(apiKey)-4:]
		}
		logger.Infof("Use API KEY: %s", masked)

		// Build a fresh request per attempt: Do consumes the body, so reusing
		// a single request would send an empty body on retry.
		request, err := http.NewRequest(http.MethodPost, s.Config.Chat.ApiURL, bytes.NewReader(requestBody))
		if err != nil {
			return err
		}
		request.Header.Set("Content-Type", "application/json")
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))

		resp, err := s.Client.Do(request)
		if err != nil {
			logger.Error(err)
			lastErr = err
			continue
		}
		if resp.StatusCode != http.StatusOK {
			// Keep a short excerpt of the error body for diagnostics, then
			// release the connection before trying another key.
			excerpt, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
			_ = resp.Body.Close()
			lastErr = fmt.Errorf("chat API returned %s: %s", resp.Status, strings.TrimSpace(string(excerpt)))
			logger.Error(lastErr)
			continue
		}
		response = resp
		break
	}
	if response == nil {
		return lastErr
	}
	defer response.Body.Close()

	message := types.Message{}
	var reply strings.Builder
	reader := bufio.NewReader(response.Body)
	for {
		line, err := reader.ReadString('\n')
		if err != nil && err != io.EOF {
			logger.Error(err)
			// Tell the client the reply is over so it does not wait forever.
			replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
			break
		}
		if line == "" {
			// EOF with nothing buffered: the stream closed without a finish_reason.
			replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
			break
		}

		// Only "data:" lines carry payloads. Blank separator lines and other
		// lines are skipped.
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if payload == "[DONE]" {
			replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
			break
		}

		// Decode into a fresh value: json.Unmarshal merges into existing data,
		// so reusing one struct would carry stale fields between chunks.
		var chunk types.ApiResponse
		if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
			logger.Errorf("Error for decoding stream chunk %q: %v", payload, err)
			continue
		}
		if len(chunk.Choices) == 0 {
			continue
		}
		choice := chunk.Choices[0]

		// The first chunk carrying a role starts the reply.
		if message.Role == "" && choice.Delta.Role != "" {
			message.Role = choice.Delta.Role
			replyMessage(types.WsMessage{Type: types.WsStart}, ws)
		}
		// Relay content before checking finish_reason so that a chunk carrying
		// both does not lose its text.
		if content := choice.Delta.Content; content != "" {
			reply.WriteString(content)
			replyMessage(types.WsMessage{Type: types.WsMiddle, Content: content}, ws)
		}
		if choice.FinishReason != "" { // the reply finished or was cut off
			replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
			break
		}
	}

	// Nothing came back at all: skip recording the exchange rather than store
	// an assistant entry with an empty role, which is unlikely to be accepted
	// when replayed as context.
	if message.Role == "" && reply.Len() == 0 {
		return nil
	}
	if message.Role == "" {
		message.Role = "assistant"
	}
	message.Content = reply.String()
	// NOTE(review): with EnableContext on, the stored history grows without
	// bound; long conversations will eventually exceed the model's context.
	s.History[userId] = append(history, userMessage, message)
	return nil
}
// replyMessage JSON-encodes message and sends it to client.
//
// client must be a *WsClient. Any other implementation is logged and the
// message is dropped instead of panicking. Encoding and send failures are
// logged rather than returned.
func replyMessage(message types.WsMessage, client Client) {
	wsClient, ok := client.(*WsClient)
	if !ok {
		logger.Errorf("Error for reply message: unsupported client type %T", client)
		return
	}
	msg, err := json.Marshal(message)
	if err != nil {
		logger.Errorf("Error for encoding json data: %v", err)
		return
	}
	if err := wsClient.Send(msg); err != nil {
		logger.Errorf("Error for reply message: %v", err)
	}
}