支持 TOKEN 设置最大调用次数

This commit is contained in:
RockYang
2023-03-27 21:45:02 +08:00
parent a6bab7b12d
commit 5f702d92dc
7 changed files with 192 additions and 82 deletions

View File

@@ -16,6 +16,8 @@ import (
"time"
)
const ErrorMsg = "抱歉AI 助手开小差了,我马上找人去盘它。"
// ChatHandle 处理聊天 WebSocket 请求
func (s *Server) ChatHandle(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
@@ -24,16 +26,19 @@ func (s *Server) ChatHandle(c *gin.Context) {
return
}
sessionId := c.Query("sessionId")
role := c.Query("role")
logger.Infof("New websocket connected, IP: %s", c.Request.RemoteAddr)
roleKey := c.Query("role")
session := s.ChatSession[sessionId]
logger.Infof("New websocket connected, IP: %s, Token: %s", c.Request.RemoteAddr, session.Token)
client := NewWsClient(ws)
if !s.ChatRoles[role].Enable { // 角色未启用
var roles = GetChatRoles()
var chatRole = roles[roleKey]
if !chatRole.Enable { // 角色未启用
c.Abort()
return
}
// 发送打招呼信息
replyMessage(types.WsMessage{Type: types.WsStart, IsHelloMsg: true}, client)
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: s.ChatRoles[role].HelloMsg, IsHelloMsg: true}, client)
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: chatRole.HelloMsg, IsHelloMsg: true}, client)
replyMessage(types.WsMessage{Type: types.WsEnd, IsHelloMsg: true}, client)
go func() {
for {
@@ -46,7 +51,7 @@ func (s *Server) ChatHandle(c *gin.Context) {
logger.Info("Receive a message: ", string(message))
// TODO: 当前只保持当前会话的上下文,部保存用户的所有的聊天历史记录,后期要考虑保存所有的历史记录
err = s.sendMessage(sessionId, role, string(message), client)
err = s.sendMessage(session, chatRole, string(message), client)
if err != nil {
logger.Error(err)
}
@@ -55,7 +60,17 @@ func (s *Server) ChatHandle(c *gin.Context) {
}
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (s *Server) sendMessage(sessionId string, role string, text string, ws Client) error {
func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, text string, ws Client) error {
token, err := GetToken(session.Token)
if err != nil {
replyError(ws, "当前 TOKEN 无效,请使用合法的 TOKEN 登录!")
return err
}
if token.MaxCalls > 0 && token.RemainingCalls <= 0 {
replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用!")
return nil
}
var r = types.ApiRequest{
Model: s.Config.Chat.Model,
Temperature: s.Config.Chat.Temperature,
@@ -63,11 +78,11 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
Stream: true,
}
var context []types.Message
var key = sessionId + role
var key = session.SessionId + role.Name
if v, ok := s.ChatContext[key]; ok && s.Config.Chat.EnableContext {
context = v
} else {
context = s.ChatRoles[role].Context
context = role.Context
}
if s.DebugMode {
@@ -130,7 +145,7 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
// 如果三次请求都失败的话,则返回对应的错误信息
if err != nil {
replyError(ws)
replyError(ws, ErrorMsg)
return err
}
@@ -155,7 +170,7 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil {
logger.Error(line)
replyError(ws)
replyError(ws, ErrorMsg)
break
}
// 初始化 role
@@ -176,7 +191,10 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
}, ws)
}
}
// 当前 Token 调用次数减 1
if token.MaxCalls > 0 {
token.RemainingCalls -= 1
}
// 追加历史消息
context = append(context, types.Message{
Role: "user",
@@ -190,9 +208,9 @@ func (s *Server) sendMessage(sessionId string, role string, text string, ws Clie
return nil
}
func replyError(ws Client) {
func replyError(ws Client, message string) {
replyMessage(types.WsMessage{Type: types.WsStart}, ws)
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: "抱歉AI 助手开小差了,我马上找人去盘它。"}, ws)
replyMessage(types.WsMessage{Type: types.WsMiddle, Content: message}, ws)
replyMessage(types.WsMessage{Type: types.WsEnd}, ws)
}

View File

@@ -96,6 +96,7 @@ func (s *Server) AddToken(c *gin.Context) {
return
}
// 参数处理
var name = data["name"]
var maxCalls = data["max_calls"]
if name == "" || maxCalls == "" {
@@ -112,8 +113,9 @@ func (s *Server) AddToken(c *gin.Context) {
return
}
var tokens = GetTokens()
if utils.ContainToken(tokens, name) {
// 检查当前要添加的 token 是否已经存在
_, err = GetToken(name)
if err == nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + name + " already exists"})
return
}
@@ -127,6 +129,50 @@ func (s *Server) AddToken(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
}
func (s *Server) SetToken(c *gin.Context) {
var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
logger.Errorf("Error decode json data: %s", err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
// 参数处理
var name = data["name"]
var maxCalls = data["max_calls"]
if name == "" || maxCalls == "" {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
return
}
token, err := GetToken(name)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token not found"})
return
}
n, err := strconv.Atoi(maxCalls)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{
Code: types.InvalidParams,
Message: "enable_auth must be a int parameter",
})
return
}
token.RemainingCalls += n - token.MaxCalls
token.MaxCalls = n
err = PutToken(token)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
return
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
}
// RemoveToken 删除 Token
func (s *Server) RemoveToken(c *gin.Context) {
var data map[string]string
@@ -205,22 +251,23 @@ func (s *Server) ListApiKeys(c *gin.Context) {
}
func (s *Server) GetChatRoles(c *gin.Context) {
//var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
//var roles = make([]interface{}, 0)
//for _, k := range rolesOrder {
// if v, ok := s.Config.ChatRoles[k]; ok && v.Enable {
// roles = append(roles, struct {
// Key string `json:"key"`
// Name string `json:"name"`
// Icon string `json:"icon"`
// }{
// Key: v.Key,
// Name: v.Name,
// Icon: v.Icon,
// })
// }
//}
//c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: roles})
var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
var res = make([]interface{}, 0)
var roles = GetChatRoles()
for _, k := range rolesOrder {
if v, ok := roles[k]; ok && v.Enable {
res = append(res, struct {
Key string `json:"key"`
Name string `json:"name"`
Icon string `json:"icon"`
}{
Key: v.Key,
Name: v.Name,
Icon: v.Icon,
})
}
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: res})
}
// UpdateChatRole 更新某个聊天角色信息,这里只允许更改名称以及启用和禁用角色操作
@@ -238,39 +285,43 @@ func (s *Server) UpdateChatRole(c *gin.Context) {
return
}
//role := s.Config.ChatRoles[key]
//if enable, ok := data["enable"]; ok {
// v, err := strconv.ParseBool(enable)
// if err != nil {
// c.JSON(http.StatusOK, types.BizVo{
// Code: types.InvalidParams,
// Message: "enable must be a bool parameter",
// })
// return
// }
// role.Enable = v
//}
roles := GetChatRoles()
role := roles[key]
if role.Key == "" {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Role key not exists"})
return
}
//if name, ok := data["name"]; ok {
// role.Name = name
//}
//if helloMsg, ok := data["hello_msg"]; ok {
// role.HelloMsg = helloMsg
//}
//if icon, ok := data["icon"]; ok {
// role.Icon = icon
//}
//
//s.Config.ChatRoles[key] = role
//
//// 保存配置文件
//err = types.SaveConfig(s.Config, s.ConfigPath)
//if err != nil {
// c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config file"})
// return
//}
//
//c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
if enable, ok := data["enable"]; ok {
v, err := strconv.ParseBool(enable)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{
Code: types.InvalidParams,
Message: "enable must be a bool parameter",
})
return
}
role.Enable = v
}
if name, ok := data["name"]; ok {
role.Name = name
}
if helloMsg, ok := data["hello_msg"]; ok {
role.HelloMsg = helloMsg
}
if icon, ok := data["icon"]; ok {
role.Icon = icon
}
// 保存到 leveldb
err = PutChatRole(role)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config"})
return
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: role})
}
// AddProxy 添加一个代理

View File

@@ -43,6 +43,16 @@ func PutToken(token types.Token) error {
return db.Put(key, token)
}
func GetToken(name string) (types.Token, error) {
key := TokenPrefix + name
token, err := db.Get(key)
if err != nil {
return types.Token{}, err
}
return token.(types.Token), nil
}
func RemoveToken(token string) error {
key := TokenPrefix + token
return db.Delete(key)
@@ -51,7 +61,22 @@ func RemoveToken(token string) error {
// GetChatRoles 获取聊天角色
// chat/roles
func GetChatRoles() map[string]types.ChatRole {
return nil
items := db.Search(ChatRolePrefix)
var roles = make(map[string]types.ChatRole)
for _, v := range items {
var role types.ChatRole
err := json.Unmarshal([]byte(v), &role)
if err != nil {
continue
}
roles[role.Key] = role
}
return roles
}
func PutChatRole(role types.ChatRole) error {
key := ChatRolePrefix + role.Key
return db.Put(key, role)
}
// GetChatHistory 获取聊天历史记录

View File

@@ -37,10 +37,9 @@ type Server struct {
// 保存 Websocket 会话 Token, 每个 Token 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API
WsSession map[string]string
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
DebugMode bool // 是否开启调试模式
ChatRoles map[string]types.ChatRole // 保存预设角色信息
ChatSession map[string]types.ChatSession
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
DebugMode bool // 是否开启调试模式
}
func NewServer(configPath string) (*Server, error) {
@@ -49,18 +48,22 @@ func NewServer(configPath string) (*Server, error) {
if err != nil {
return nil, err
}
roles := GetChatRoles()
if roles == nil {
if len(roles) == 0 { // 初始化默认聊天角色到 leveldb
roles = types.GetDefaultChatRole()
for _, v := range roles {
err := PutChatRole(v)
if err != nil {
return nil, err
}
}
}
return &Server{
Config: config,
ConfigPath: configPath,
ChatContext: make(map[string][]types.Message, 16),
WsSession: make(map[string]string),
ChatSession: make(map[string]types.ChatSession),
ApiKeyAccessStat: make(map[string]int64),
ChatRoles: roles,
}, nil
}
@@ -81,6 +84,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
engine.POST("/api/config/set", s.ConfigSetHandle)
engine.GET("/api/config/chat-roles/get", s.GetChatRoles)
engine.POST("api/config/token/add", s.AddToken)
engine.POST("api/config/token/set", s.SetToken)
engine.POST("api/config/token/remove", s.RemoveToken)
engine.POST("api/config/apikey/add", s.AddApiKey)
engine.POST("api/config/apikey/remove", s.RemoveApiKey)
@@ -182,10 +186,8 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
// WebSocket 连接请求验证
if c.Request.URL.Path == "/api/chat" {
tokenName := c.Query("token")
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
// 每个令牌只能连接一次
//delete(s.WsSession, tokenName)
sessionId := c.Query("sessionId")
if session, ok := s.ChatSession[sessionId]; ok && session.ClientIP == c.ClientIP() {
c.Next()
} else {
c.Abort()
@@ -210,9 +212,9 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
}
func (s *Server) GetSessionHandle(c *gin.Context) {
tokenName := c.GetHeader(types.TokenName)
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: addr})
sessionId := c.GetHeader(types.TokenName)
if session, ok := s.ChatSession[sessionId]; ok && session.ClientIP == c.ClientIP() {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
} else {
c.JSON(http.StatusOK, types.BizVo{
Code: types.NotAuthorized,
@@ -243,7 +245,7 @@ func (s *Server) LoginHandle(c *gin.Context) {
logger.Error("Error for save session: ", err)
}
// 记录客户端 IP 地址
s.WsSession[sessionId] = c.ClientIP()
s.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), Token: token, SessionId: sessionId}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: sessionId})
}