支持 TOKEN 设置最大调用次数

This commit is contained in:
RockYang
2023-03-27 21:45:02 +08:00
parent a6bab7b12d
commit 5f702d92dc
7 changed files with 192 additions and 82 deletions

View File

@@ -37,10 +37,9 @@ type Server struct {
// 保存 Websocket 会话 Token, 每个 Token 只能连接一次
// 防止第三方直接连接 socket 调用 OpenAI API
WsSession map[string]string
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
DebugMode bool // 是否开启调试模式
ChatRoles map[string]types.ChatRole // 保存预设角色信息
ChatSession map[string]types.ChatSession
ApiKeyAccessStat map[string]int64 // 记录每个 API Key 的最后访问之间,保持在 15/min 之内
DebugMode bool // 是否开启调试模式
}
func NewServer(configPath string) (*Server, error) {
@@ -49,18 +48,22 @@ func NewServer(configPath string) (*Server, error) {
if err != nil {
return nil, err
}
roles := GetChatRoles()
if roles == nil {
if len(roles) == 0 { // 初始化默认聊天角色到 leveldb
roles = types.GetDefaultChatRole()
for _, v := range roles {
err := PutChatRole(v)
if err != nil {
return nil, err
}
}
}
return &Server{
Config: config,
ConfigPath: configPath,
ChatContext: make(map[string][]types.Message, 16),
WsSession: make(map[string]string),
ChatSession: make(map[string]types.ChatSession),
ApiKeyAccessStat: make(map[string]int64),
ChatRoles: roles,
}, nil
}
@@ -81,6 +84,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
engine.POST("/api/config/set", s.ConfigSetHandle)
engine.GET("/api/config/chat-roles/get", s.GetChatRoles)
engine.POST("api/config/token/add", s.AddToken)
engine.POST("api/config/token/set", s.SetToken)
engine.POST("api/config/token/remove", s.RemoveToken)
engine.POST("api/config/apikey/add", s.AddApiKey)
engine.POST("api/config/apikey/remove", s.RemoveApiKey)
@@ -182,10 +186,8 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
// WebSocket 连接请求验证
if c.Request.URL.Path == "/api/chat" {
tokenName := c.Query("token")
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
// 每个令牌只能连接一次
//delete(s.WsSession, tokenName)
sessionId := c.Query("sessionId")
if session, ok := s.ChatSession[sessionId]; ok && session.ClientIP == c.ClientIP() {
c.Next()
} else {
c.Abort()
@@ -210,9 +212,9 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
}
func (s *Server) GetSessionHandle(c *gin.Context) {
tokenName := c.GetHeader(types.TokenName)
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: addr})
sessionId := c.GetHeader(types.TokenName)
if session, ok := s.ChatSession[sessionId]; ok && session.ClientIP == c.ClientIP() {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
} else {
c.JSON(http.StatusOK, types.BizVo{
Code: types.NotAuthorized,
@@ -243,7 +245,7 @@ func (s *Server) LoginHandle(c *gin.Context) {
logger.Error("Error for save session: ", err)
}
// 记录客户端 IP 地址
s.WsSession[sessionId] = c.ClientIP()
s.ChatSession[sessionId] = types.ChatSession{ClientIP: c.ClientIP(), Token: token, SessionId: sessionId}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: sessionId})
}