geekai/server/chat_handler.go
2023-03-18 20:20:00 +08:00

145 lines
3.2 KiB
Go

package server
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"math/rand"
"net/http"
"net/url"
"openai/types"
"time"
)
// ChatHandle 处理聊天 WebSocket 请求
func (s *Server) ChatHandle(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Fatal(err)
return
}
logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
client := NewWsClient(ws)
go func() {
for {
_, message, err := client.Receive()
if err != nil {
logger.Error(err)
client.Close()
return
}
logger.Info(string(message))
// TODO: 根据会话请求,传入不同的用户 ID
err = s.sendMessage("test", string(message), client)
if err != nil {
logger.Error(err)
}
}
}()
}
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (s *Server) sendMessage(userId string, text string, ws Client) error {
var r = types.ApiRequest{
Model: "gpt-3.5-turbo",
Temperature: 0.9,
MaxTokens: 1024,
Stream: true,
}
var history []types.Message
if v, ok := s.History[userId]; ok {
history = v
} else {
history = make([]types.Message, 0)
}
r.Messages = append(history, types.Message{
Role: "user",
Content: text,
})
logger.Info("上下文历史消息:%+v", s.History[userId])
requestBody, err := json.Marshal(r)
if err != nil {
return err
}
request, err := http.NewRequest(http.MethodPost, s.Config.OpenAi.ApiURL, bytes.NewBuffer(requestBody))
if err != nil {
return err
}
// TODO: API KEY 负载均衡
rand.Seed(time.Now().UnixNano())
index := rand.Intn(len(s.Config.OpenAi.ApiKeys))
request.Header.Add("Content-Type", "application/json")
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.Config.OpenAi.ApiKeys[index]))
uri := url.URL{}
proxy, _ := uri.Parse(s.Config.ProxyURL)
client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
response, err := client.Do(request)
var retryCount = 3
for err != nil {
if retryCount <= 0 {
return err
}
response, err = client.Do(request)
retryCount--
}
var message = types.Message{}
var contents = make([]string, 0)
var responseBody = types.ApiResponse{}
reader := bufio.NewReader(response.Body)
for {
line, err := reader.ReadString('\n')
if err != nil && err != io.EOF {
fmt.Println(err)
break
}
if line == "" {
break
} else if len(line) < 20 {
continue
}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil {
fmt.Println(err)
continue
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
continue
} else {
contents = append(contents, responseBody.Choices[0].Delta.Content)
}
// 推送消息到客户端
err = ws.(*WsClient).Send([]byte(responseBody.Choices[0].Delta.Content))
if err != nil {
logger.Error(err)
}
fmt.Print(responseBody.Choices[0].Delta.Content)
if responseBody.Choices[0].FinishReason != "" {
break
}
}
// 追加历史消息
history = append(history, message)
s.History[userId] = history
return nil
}