rename Token to User, the chat history function is ready

This commit is contained in:
RockYang
2023-03-28 16:03:41 +08:00
parent 95f9dfa9cb
commit ebc2041e8a
10 changed files with 261 additions and 141 deletions

View File

@@ -12,6 +12,7 @@ import (
"net/http"
"net/url"
"openai/types"
"openai/utils"
"strings"
"time"
)
@@ -28,7 +29,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
sessionId := c.Query("sessionId")
roleKey := c.Query("role")
session := s.ChatSession[sessionId]
logger.Infof("New websocket connected, IP: %s, Token: %s", c.Request.RemoteAddr, session.Token)
logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username)
client := NewWsClient(ws)
var roles = GetChatRoles()
var chatRole = roles[roleKey]
@@ -36,10 +37,11 @@ func (s *Server) ChatHandle(c *gin.Context) {
c.Abort()
return
}
// 发送打招呼
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client)
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: chatRole.HelloMsg, IsHelloMsg: true}, client)
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client)
// 加载历史消息,如果历史消息为空则发送打招呼
_, err = GetChatHistory(session.Username, roleKey)
if err != nil {
replyMessage(client, chatRole.HelloMsg, true)
}
go func() {
for {
_, message, err := client.Receive()
@@ -60,15 +62,15 @@ func (s *Server) ChatHandle(c *gin.Context) {
}
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, text string, ws Client) error {
token, err := GetToken(session.Token)
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, prompt string, ws Client) error {
user, err := GetUser(session.Username)
if err != nil {
replyError(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!")
replyMessage(ws, "当前 user 无效,请使用合法的 user 登录!", false)
return err
}
if token.MaxCalls > 0 && token.RemainingCalls <= 0 {
replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用或者联系管理员!")
if user.MaxCalls > 0 && user.RemainingCalls <= 0 {
replyMessage(ws, "当前 user 点数已经用尽,请充值后再使用或者联系管理员!", false)
return nil
}
var r = types.ApiRequest{
@@ -91,7 +93,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
r.Messages = append(context, types.Message{
Role: "user",
Content: text,
Content: prompt,
})
requestBody, err := json.Marshal(r)
@@ -107,7 +109,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
}
request.Header.Add("Content-Type", "application/json")
var retryCount = 3
var retryCount = 5
var response *http.Response
var failedKey = ""
var failedProxyURL = ""
@@ -145,10 +147,11 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
// 如果三次请求都失败的话,则返回对应的错误信息
if err != nil {
replyError(ws, ErrorMsg)
replyMessage(ws, ErrorMsg, false)
return err
}
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var responseBody = types.ApiResponse{}
@@ -161,58 +164,59 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
}
if line == "" {
replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
break
} else if len(line) < 20 {
continue
}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil {
logger.Error(line)
replyError(ws, ErrorMsg)
if err != nil { // 数据解析出错
logger.Error(err, line)
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
break
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: false}, ws)
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: false})
continue
} else if responseBody.Choices[0].FinishReason != "" { // 输出完成或者输出中断了
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: false}, ws)
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
break
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, content)
replyMessage(types.WsMessage{
replyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: responseBody.Choices[0].Delta.Content,
IsHelloMsg: false,
}, ws)
})
}
}
// 当前 Token 调用次数减 1
if token.MaxCalls > 0 {
token.RemainingCalls -= 1
_ = PutToken(*token)
_ = response.Body.Close() // 关闭资源
// 当前 Username 调用次数减 1
if user.MaxCalls > 0 {
user.RemainingCalls -= 1
_ = PutUser(*user)
}
// 追加历史消息
context = append(context, types.Message{
Role: "user",
Content: text,
})
// 追加上下文消息
useMsg := types.Message{Role: "user", Content: prompt}
context = append(context, useMsg)
message.Content = strings.Join(contents, "")
context = append(context, message)
// 保存上下文
s.ChatContext[key] = context
_ = response.Body.Close() // 关闭资源
return nil
}
func replyError(ws Client, message string) {
replyMessage(types.WsMessage{Type: types.WsStart}, ws)
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: message}, ws)
replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
// 追加历史消息
if user.EnableHistory {
err = AppendChatHistory(user.Name, role.Key, useMsg)
if err != nil {
return err
}
err = AppendChatHistory(user.Name, role.Key, message)
}
return err
}
// 随机获取一个 API Key如果请求失败则更换 API Key 重试
@@ -267,8 +271,8 @@ func (s *Server) getProxyURL(failedProxyURL string) string {
return ""
}
// 回复客户端消息
func replyMessage(message types.WsMessage, client Client) {
// 回复客户片段端消息
func replyChunkMessage(client Client, message types.WsMessage) {
msg, err := json.Marshal(message)
if err != nil {
logger.Errorf("Error for decoding json data: %v", err.Error())
@@ -279,3 +283,61 @@ func replyMessage(message types.WsMessage, client Client) {
logger.Errorf("Error for reply message: %v", err.Error())
}
}
// 回复客户端一条完整的消息
func replyMessage(ws Client, message string, isHelloMsg bool) {
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: isHelloMsg})
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message, IsHelloMsg: isHelloMsg})
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: isHelloMsg})
}
func (s *Server) GetChatHistoryHandle(c *gin.Context) {
sessionId := c.GetHeader(types.TokenName)
var data struct {
Role string `json:"role"`
}
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
return
}
session := s.ChatSession[sessionId]
history, err := GetChatHistory(session.Username, data.Role)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "No history message"})
return
}
var messages = make([]types.HistoryMessage, 0)
role, err := GetChatRole(data.Role)
if err == nil {
// 先将打招呼的消息追加上去
messages = append(messages, types.HistoryMessage{
Type: "reply",
Id: utils.RandString(32),
Icon: role.Icon,
Content: role.HelloMsg,
})
for _, v := range history {
if v.Role == "user" {
messages = append(messages, types.HistoryMessage{
Type: "prompt",
Id: utils.RandString(32),
Icon: "images/avatar/user.png",
Content: v.Content,
})
} else if v.Role == "assistant" {
messages = append(messages, types.HistoryMessage{
Type: "reply",
Id: utils.RandString(32),
Icon: role.Icon,
Content: v.Content,
})
}
}
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: messages})
}