mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-10 11:13:42 +08:00
rename Token to User, the chat history function is ready
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"openai/types"
|
||||
"openai/utils"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -28,7 +29,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
sessionId := c.Query("sessionId")
|
||||
roleKey := c.Query("role")
|
||||
session := s.ChatSession[sessionId]
|
||||
logger.Infof("New websocket connected, IP: %s, Token: %s", c.Request.RemoteAddr, session.Token)
|
||||
logger.Infof("New websocket connected, IP: %s, Username: %s", c.Request.RemoteAddr, session.Username)
|
||||
client := NewWsClient(ws)
|
||||
var roles = GetChatRoles()
|
||||
var chatRole = roles[roleKey]
|
||||
@@ -36,10 +37,11 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// 发送打招呼信息
|
||||
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: chatRole.HelloMsg, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client)
|
||||
// 加载历史消息,如果历史消息为空则发送打招呼消息
|
||||
_, err = GetChatHistory(session.Username, roleKey)
|
||||
if err != nil {
|
||||
replyMessage(client, chatRole.HelloMsg, true)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
_, message, err := client.Receive()
|
||||
@@ -60,15 +62,15 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
|
||||
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, text string, ws Client) error {
|
||||
token, err := GetToken(session.Token)
|
||||
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, prompt string, ws Client) error {
|
||||
user, err := GetUser(session.Username)
|
||||
if err != nil {
|
||||
replyError(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!")
|
||||
replyMessage(ws, "当前 user 无效,请使用合法的 user 登录!", false)
|
||||
return err
|
||||
}
|
||||
|
||||
if token.MaxCalls > 0 && token.RemainingCalls <= 0 {
|
||||
replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用或者联系管理员!")
|
||||
if user.MaxCalls > 0 && user.RemainingCalls <= 0 {
|
||||
replyMessage(ws, "当前 user 点数已经用尽,请充值后再使用或者联系管理员!", false)
|
||||
return nil
|
||||
}
|
||||
var r = types.ApiRequest{
|
||||
@@ -91,7 +93,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
|
||||
|
||||
r.Messages = append(context, types.Message{
|
||||
Role: "user",
|
||||
Content: text,
|
||||
Content: prompt,
|
||||
})
|
||||
|
||||
requestBody, err := json.Marshal(r)
|
||||
@@ -107,7 +109,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
|
||||
}
|
||||
|
||||
request.Header.Add("Content-Type", "application/json")
|
||||
var retryCount = 3
|
||||
var retryCount = 5
|
||||
var response *http.Response
|
||||
var failedKey = ""
|
||||
var failedProxyURL = ""
|
||||
@@ -145,10 +147,11 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
|
||||
|
||||
// 如果三次请求都失败的话,则返回对应的错误信息
|
||||
if err != nil {
|
||||
replyError(ws, ErrorMsg)
|
||||
replyMessage(ws, ErrorMsg, false)
|
||||
return err
|
||||
}
|
||||
|
||||
// 循环读取 Chunk 消息
|
||||
var message = types.Message{}
|
||||
var contents = make([]string, 0)
|
||||
var responseBody = types.ApiResponse{}
|
||||
@@ -161,58 +164,59 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
|
||||
}
|
||||
|
||||
if line == "" {
|
||||
replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
|
||||
break
|
||||
} else if len(line) < 20 {
|
||||
continue
|
||||
}
|
||||
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil {
|
||||
logger.Error(line)
|
||||
replyError(ws, ErrorMsg)
|
||||
if err != nil { // 数据解析出错
|
||||
logger.Error(err, line)
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
|
||||
break
|
||||
}
|
||||
// 初始化 role
|
||||
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
|
||||
message.Role = responseBody.Choices[0].Delta.Role
|
||||
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: false}, ws)
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: false})
|
||||
continue
|
||||
} else if responseBody.Choices[0].FinishReason != "" { // 输出完成或者输出中断了
|
||||
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: false}, ws)
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: false})
|
||||
break
|
||||
} else {
|
||||
content := responseBody.Choices[0].Delta.Content
|
||||
contents = append(contents, content)
|
||||
replyMessage(types.WsMessage{
|
||||
replyChunkMessage(ws, types.WsMessage{
|
||||
Type: types.WsMiddle,
|
||||
Content: responseBody.Choices[0].Delta.Content,
|
||||
IsHelloMsg: false,
|
||||
}, ws)
|
||||
})
|
||||
}
|
||||
}
|
||||
// 当前 Token 调用次数减 1
|
||||
if token.MaxCalls > 0 {
|
||||
token.RemainingCalls -= 1
|
||||
_ = PutToken(*token)
|
||||
_ = response.Body.Close() // 关闭资源
|
||||
|
||||
// 当前 Username 调用次数减 1
|
||||
if user.MaxCalls > 0 {
|
||||
user.RemainingCalls -= 1
|
||||
_ = PutUser(*user)
|
||||
}
|
||||
// 追加历史消息
|
||||
context = append(context, types.Message{
|
||||
Role: "user",
|
||||
Content: text,
|
||||
})
|
||||
// 追加上下文消息
|
||||
useMsg := types.Message{Role: "user", Content: prompt}
|
||||
context = append(context, useMsg)
|
||||
message.Content = strings.Join(contents, "")
|
||||
context = append(context, message)
|
||||
// 保存上下文
|
||||
s.ChatContext[key] = context
|
||||
_ = response.Body.Close() // 关闭资源
|
||||
return nil
|
||||
}
|
||||
|
||||
func replyError(ws Client, message string) {
|
||||
replyMessage(types.WsMessage{Type: types.WsStart}, ws)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: message}, ws)
|
||||
replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
|
||||
// 追加历史消息
|
||||
if user.EnableHistory {
|
||||
err = AppendChatHistory(user.Name, role.Key, useMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = AppendChatHistory(user.Name, role.Key, message)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 随机获取一个 API Key,如果请求失败,则更换 API Key 重试
|
||||
@@ -267,8 +271,8 @@ func (s *Server) getProxyURL(failedProxyURL string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 回复客户端消息
|
||||
func replyMessage(message types.WsMessage, client Client) {
|
||||
// 回复客户片段端消息
|
||||
func replyChunkMessage(client Client, message types.WsMessage) {
|
||||
msg, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logger.Errorf("Error for decoding json data: %v", err.Error())
|
||||
@@ -279,3 +283,61 @@ func replyMessage(message types.WsMessage, client Client) {
|
||||
logger.Errorf("Error for reply message: %v", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 回复客户端一条完整的消息
|
||||
func replyMessage(ws Client, message string, isHelloMsg bool) {
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsStart, IsHelloMsg: isHelloMsg})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message, IsHelloMsg: isHelloMsg})
|
||||
replyChunkMessage(ws, types.WsMessage{Type: types.WsEnd, IsHelloMsg: isHelloMsg})
|
||||
}
|
||||
|
||||
func (s *Server) GetChatHistoryHandle(c *gin.Context) {
|
||||
sessionId := c.GetHeader(types.TokenName)
|
||||
var data struct {
|
||||
Role string `json:"role"`
|
||||
}
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
|
||||
return
|
||||
}
|
||||
|
||||
session := s.ChatSession[sessionId]
|
||||
history, err := GetChatHistory(session.Username, data.Role)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "No history message"})
|
||||
return
|
||||
}
|
||||
|
||||
var messages = make([]types.HistoryMessage, 0)
|
||||
role, err := GetChatRole(data.Role)
|
||||
if err == nil {
|
||||
// 先将打招呼的消息追加上去
|
||||
messages = append(messages, types.HistoryMessage{
|
||||
Type: "reply",
|
||||
Id: utils.RandString(32),
|
||||
Icon: role.Icon,
|
||||
Content: role.HelloMsg,
|
||||
})
|
||||
|
||||
for _, v := range history {
|
||||
if v.Role == "user" {
|
||||
messages = append(messages, types.HistoryMessage{
|
||||
Type: "prompt",
|
||||
Id: utils.RandString(32),
|
||||
Icon: "images/avatar/user.png",
|
||||
Content: v.Content,
|
||||
})
|
||||
} else if v.Role == "assistant" {
|
||||
messages = append(messages, types.HistoryMessage{
|
||||
Type: "reply",
|
||||
Id: utils.RandString(32),
|
||||
Icon: role.Icon,
|
||||
Content: v.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: messages})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user