geekai/server/chat_handler.go

409 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package server
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"math/rand"
"net/http"
"net/url"
"openai/types"
"openai/utils"
"strings"
"time"
)
const ErrorMsg = "抱歉AI 助手开小差了,请马上联系管理员去盘它。"
// ChatHandle 处理聊天 WebSocket 请求
func (s *Server) ChatHandle(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Fatal(err)
return
}
sessionId := c.Query("sessionId")
roleKey := c.Query("role")
session := s.ChatSession[sessionId]
logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username)
client := NewWsClient(ws)
var roles = GetChatRoles()
var chatRole = roles[roleKey]
if !chatRole.Enable { // 角色未启用
c.Abort()
return
}
// 加载历史消息,如果历史消息为空则发送打招呼消息
_, err = GetChatHistory(session.Username, roleKey)
if err != nil {
replyMessage(client, chatRole.HelloMsg, true)
}
go func() {
for {
_, message, err := client.Receive()
if err != nil {
logger.Error(err)
client.Close()
return
}
logger.Info("Receive a message: ", string(message))
// TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
err = s.sendMessage(session, chatRole, string(message), client, false)
if err != nil {
logger.Error(err)
}
}
}()
}
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, prompt string, ws Client, resetContext bool) error {
user, err := GetUser(session.Username)
if err != nil {
replyMessage(ws, "当前 user 无效,请使用合法的 user 登录!", false)
return err
}
if user.MaxCalls > 0 && user.RemainingCalls <= 0 {
replyMessage(ws, "当前 TOKEN 点数已经用尽,加入我们的知识星球可以免费领取点卡!", false)
replyMessage(ws, "![](images/start.png)", true)
return nil
}
var r = types.ApiRequest{
Model: s.Config.Chat.Model,
Temperature: s.Config.Chat.Temperature,
MaxTokens: s.Config.Chat.MaxTokens,
Stream: true,
}
var context []types.Message
var ctxKey = fmt.Sprintf("%s-%s", session.SessionId, role.Key)
if v, ok := s.ChatContexts[ctxKey]; ok && s.Config.Chat.EnableContext {
context = v.Messages
} else {
context = role.Context
}
if s.DebugMode {
logger.Infof("会话上下文:%+v", context)
}
// 创建 HttpClient 请求对象
var client *http.Client
var retryCount = 5
var response *http.Response
var apiKey string
var failedKey = ""
var failedProxyURL = ""
for retryCount > 0 {
r.Messages = append(context, types.Message{
Role: "user",
Content: prompt,
})
requestBody, err := json.Marshal(r)
if err != nil {
return err
}
request, err := http.NewRequest(http.MethodPost, s.Config.Chat.ApiURL, bytes.NewBuffer(requestBody))
if err != nil {
return err
}
request.Header.Add("Content-Type", "application/json")
proxyURL := s.getProxyURL(failedProxyURL)
if proxyURL == "" {
client = &http.Client{}
} else { // 使用代理
uri := url.URL{}
proxy, _ := uri.Parse(proxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
}
apiKey = s.getApiKey(failedKey)
if apiKey == "" {
logger.Info("Too many requests, all Api Key is not available")
time.Sleep(time.Second)
continue
}
logger.Infof("Use API KEY: %s, PROXY: %s", apiKey, proxyURL)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
response, err = client.Do(request)
if err == nil {
break
} else {
failedKey = apiKey
failedProxyURL = proxyURL
}
retryCount--
}
// 如果三次请求都失败的话,则返回对应的错误信息
if err != nil {
replyMessage(ws, ErrorMsg, false)
replyMessage(ws, "![](images/wx.png)", true)
return err
}
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var responseBody = types.ApiResponse{}
reader := bufio.NewReader(response.Body)
for {
line, err := reader.ReadString('\n')
if err != nil {
logger.Error(err)
break
}
if len(line) < 20 {
continue
} else if strings.Contains(line, "This key is associated with a deactivated account") || // 账号被禁用
strings.Contains(line, "You exceeded your current quota") { // 当前 KEY 余额被用尽
logger.Infof("API Key %s is deactivated", apiKey)
// 移除当前 API key
for i, v := range s.Config.Chat.ApiKeys {
if v == apiKey {
s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys[:i], s.Config.Chat.ApiKeys[i+1:]...)
}
}
// 更新配置文档
_ = utils.SaveConfig(s.Config, s.ConfigPath)
// 重发当前消息
return s.sendMessage(session, role, prompt, ws, false)
// 上下文超出长度了
} else if strings.Contains(line, "This model's maximum context length is 4097 tokens") {
logger.Info("会话上下文长度超出限制, Username: %s", user.Name)
// 重置上下文,重发当前消息
delete(s.ChatContexts, ctxKey)
return s.sendMessage(session, role, prompt, ws, true)
} else if !strings.Contains(line, "data:") {
continue
}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
logger.Error(err, line)
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
replyMessage(ws, ErrorMsg, false)
replyMessage(ws, "![](images/wx.png)", true)
break
}
if resetContext {
replyMessage(ws, "温馨提示:当前会话上下文长度超出限制,已为您重置会话上下文!", false)
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: false})
continue
} else if responseBody.Choices[0].FinishReason != "" { // 输出完成或者输出中断了
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
break
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, content)
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: responseBody.Choices[0].Delta.Content,
IsHelloMsg: false,
})
}
}
_ = response.Body.Close() // 关闭资源
// 消息发送成功
if len(contents) > 0 {
// 当前 Username 调用次数减 1
if user.MaxCalls > 0 {
user.RemainingCalls -= 1
_ = PutUser(*user)
}
if message.Role == "" {
message.Role = "assistant"
}
// 追加上下文消息
useMsg := types.Message{Role: "user", Content: prompt}
context = append(context, useMsg)
message.Content = strings.Join(contents, "")
// 更新上下文消息
if s.Config.Chat.EnableContext {
context = append(context, message)
s.ChatContexts[ctxKey] = types.ChatContext{
Messages: context,
LastAccessTime: time.Now().Unix(),
}
}
// 追加历史消息
if user.EnableHistory {
err = AppendChatHistory(user.Name, role.Key, useMsg)
if err != nil {
return err
}
err = AppendChatHistory(user.Name, role.Key, message)
}
}
return err
}
// 随机获取一个 API Key如果请求失败则更换 API Key 重试
func (s *Server) getApiKey(failedKey string) string {
var keys = make([]string, 0)
for _, v := range s.Config.Chat.ApiKeys {
// 过滤掉刚刚失败的 Key
if v == failedKey {
continue
}
// 获取 API Key 的上次调用时间,控制调用频率
var lastAccess int64
if t, ok := s.ApiKeyAccessStat[v]; ok {
lastAccess = t
}
// 保持每分钟访问不超过 15 次
if time.Now().Unix()-lastAccess <= 4 {
continue
}
keys = append(keys, v)
}
rand.NewSource(time.Now().UnixNano())
if len(keys) > 0 {
key := keys[rand.Intn(len(keys))]
s.ApiKeyAccessStat[key] = time.Now().Unix()
return key
}
return ""
}
// 获取一个可用的代理
func (s *Server) getProxyURL(failedProxyURL string) string {
if len(s.Config.ProxyURL) == 0 {
return ""
}
if len(s.Config.ProxyURL) == 1 || failedProxyURL == "" {
return s.Config.ProxyURL[0]
}
for i, v := range s.Config.ProxyURL {
if failedProxyURL == v {
if i == len(s.Config.ProxyURL)-1 {
return s.Config.ProxyURL[0]
} else {
return s.Config.ProxyURL[i+1]
}
}
}
return ""
}
// 回复客户片段端消息
func replyChunkMessage(client Client, message types.WsMessage) {
msg, err := json.Marshal(message)
if err != nil {
logger.Errorf("Error for decoding json data: %v", err.Error())
return
}
err = client.(*WsClient).Send(msg)
if err != nil {
logger.Errorf("Error for reply message: %v", err.Error())
}
}
// 回复客户端一条完整的消息
func replyMessage(ws Client, message string, isHelloMsg bool) {
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: isHelloMsg})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message, IsHelloMsg: isHelloMsg})
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: isHelloMsg})
}
func (s *Server) GetChatHistoryHandle(c *gin.Context) {
sessionId := c.GetHeader(types.TokenName)
var data struct {
Role string `json:"role"`
}
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
return
}
session := s.ChatSession[sessionId]
history, err := GetChatHistory(session.Username, data.Role)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "No history message"})
return
}
var messages = make([]types.HistoryMessage, 0)
role, err := GetChatRole(data.Role)
if err == nil {
// 先将打招呼的消息追加上去
messages = append(messages, types.HistoryMessage{
Type: "reply",
Id: utils.RandString(32),
Icon: role.Icon,
Content: role.HelloMsg,
})
for _, v := range history {
if v.Role == "user" {
messages = append(messages, types.HistoryMessage{
Type: "prompt",
Id: utils.RandString(32),
Icon: "images/avatar/user.png",
Content: v.Content,
})
} else if v.Role == "assistant" {
messages = append(messages, types.HistoryMessage{
Type: "reply",
Id: utils.RandString(32),
Icon: role.Icon,
Content: v.Content,
})
}
}
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: messages})
}
// ClearHistoryHandle 清空聊天记录
func (s *Server) ClearHistoryHandle(c *gin.Context) {
sessionId := c.GetHeader(types.TokenName)
var data struct {
Role string `json:"role"`
}
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
return
}
session := s.ChatSession[sessionId]
err = ClearChatHistory(session.Username, data.Role)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to remove data from DB"})
return
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
}