mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 18:53:43 +08:00
支持 TOKEN 设置最大调用次数
This commit is contained in:
@@ -16,6 +16,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const ErrorMsg = "抱歉,AI 助手开小差了,我马上找人去盘它。"
|
||||
|
||||
// ChatHandle 处理聊天 WebSocket 请求
|
||||
func (s *Server) ChatHandle(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
@@ -24,16 +26,19 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
sessionId := c.Query("sessionId")
|
||||
role := c.Query("role")
|
||||
logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
|
||||
roleKey := c.Query("role")
|
||||
session := s.ChatSession[sessionId]
|
||||
logger.Infof("New websocket connected, IP: %s, Token: %s", c.Request.RemoteAddr, session.Token)
|
||||
client := NewWsClient(ws)
|
||||
if !s.ChatRoles[role].Enable { // 角色未启用
|
||||
var roles = GetChatRoles()
|
||||
var chatRole = roles[roleKey]
|
||||
if !chatRole.Enable { // 角色未启用
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// 发送打招呼信息
|
||||
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: s.ChatRoles[role].HelloMsg, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: chatRole.HelloMsg, IsHelloMsg: true}, client)
|
||||
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client)
|
||||
go func() {
|
||||
for {
|
||||
@@ -46,7 +51,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
|
||||
logger.Info("Receive a message: ", string(message))
|
||||
// TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
|
||||
err = s.sendMessage(sessionId, role, string(message), client)
|
||||
err = s.sendMessage(session, chatRole, string(message), client)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
@@ -55,7 +60,17 @@ func (s *Server) ChatHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
|
||||
func (s *Server) sendMessage(sessionId string, role string, text string, ws Client) error {
|
||||
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, text string, ws Client) error {
|
||||
token, err := GetToken(session.Token)
|
||||
if err != nil {
|
||||
replyError(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!")
|
||||
return err
|
||||
}
|
||||
|
||||
if token.MaxCalls > 0 && token.RemainingCalls <= 0 {
|
||||
replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用!")
|
||||
return nil
|
||||
}
|
||||
var r = types.ApiRequest{
|
||||
Model: s.Config.Chat.Model,
|
||||
Temperature: s.Config.Chat.Temperature,
|
||||
@@ -63,11 +78,11 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
Stream: true,
|
||||
}
|
||||
var context []types.Message
|
||||
var key = sessionId + role
|
||||
var key = session.SessionId + role.Name
|
||||
if v, ok := s.ChatContext[key]; ok && s.Config.Chat.EnableContext {
|
||||
context = v
|
||||
} else {
|
||||
context = s.ChatRoles[role].Context
|
||||
context = role.Context
|
||||
}
|
||||
|
||||
if s.DebugMode {
|
||||
@@ -130,7 +145,7 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
|
||||
// 如果三次请求都失败的话,则返回对应的错误信息
|
||||
if err != nil {
|
||||
replyError(ws)
|
||||
replyError(ws, ErrorMsg)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -155,7 +170,7 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
err = json.Unmarshal([]byte(line[6:]), &responseBody)
|
||||
if err != nil {
|
||||
logger.Error(line)
|
||||
replyError(ws)
|
||||
replyError(ws, ErrorMsg)
|
||||
break
|
||||
}
|
||||
// 初始化 role
|
||||
@@ -176,7 +191,10 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
}, ws)
|
||||
}
|
||||
}
|
||||
|
||||
// 当前 Token 调用次数减 1
|
||||
if token.MaxCalls > 0 {
|
||||
token.RemainingCalls -= 1
|
||||
}
|
||||
// 追加历史消息
|
||||
context = append(context, types.Message{
|
||||
Role: "user",
|
||||
@@ -190,9 +208,9 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
|
||||
return nil
|
||||
}
|
||||
|
||||
func replyError(ws Client) {
|
||||
func replyError(ws Client, message string) {
|
||||
replyMessage(types.WsMessage{Type: types.WsStart}, ws)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: "抱歉,AI 助手开小差了,我马上找人去盘它。"}, ws)
|
||||
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: message}, ws)
|
||||
replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user