diff --git a/README.md b/README.md index 6ab730c2..61622b89 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,10 @@ ## TODOLIST -* [ ] 使用 level DB 保存用户聊天的上下文 +* [ ] 使用 level DB 保存用户聊天记录 * [x] 用户聊天鉴权,设置口令模式 * [ ] 定期清理不在线的会话 sessionID 和聊天上下文记录 -* [ ] 给 Token 设置调用次数 +* [x] 给 Token 设置调用次数 * [x] OpenAI API 负载均衡,限制每个 API Key 每分钟之内调用次数不超过 15次,防止被封 * [x] 角色设定,预设一些角色,比如程序员,客服,作家,老师,艺术家... * [x] markdown 语法解析和代码高亮 diff --git a/main.go b/main.go index d436b959..77c66080 100644 --- a/main.go +++ b/main.go @@ -19,12 +19,6 @@ var configFile string var debugMode bool func main() { - defer func() { - if err := recover(); err != nil { - logger.Error(err) - } - }() - // create config dir configDir, _ := homedir.Expand("~/.config/chat-gpt") _, err := os.Stat(configDir) diff --git a/server/config_handler.go b/server/config_handler.go index 95142d56..7253fcbd 100644 --- a/server/config_handler.go +++ b/server/config_handler.go @@ -109,13 +109,42 @@ func (s *Server) AddToken(c *gin.Context) { return } - err = PutToken(types.Token{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls}) + token := types.Token{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls} + err = PutToken(token) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) return } - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: token}) +} + +// BatchAddToken 批量生成 Token +func (s *Server) BatchAddToken(c *gin.Context) { + var data struct { + Number int `json:"number"` + MaxCalls int `json:"max_calls"` + } + err := json.NewDecoder(c.Request.Body).Decode(&data) + if err != nil || data.MaxCalls <= 0 { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) + return + } + + var tokens = make([]string, 0) + for i := 0; i < data.Number; i++ { + name := utils.RandString(12) + _, err := GetToken(name) + for err == nil { + name = utils.RandString(12) + } + err = PutToken(types.Token{Name: name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls}) + if err == nil { + tokens = append(tokens, name) + } + } + + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: tokens}) } func (s *Server) SetToken(c *gin.Context) { @@ -263,8 +292,7 @@ func (s *Server) UpdateChatRole(c *gin.Context) { return } - roles := GetChatRoles() - role := roles[key] + role, err := GetChatRole(key) if role.Key == "" { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Role key not exists"}) return @@ -293,7 +321,7 @@ func (s *Server) UpdateChatRole(c *gin.Context) { } // 保存到 leveldb - err = PutChatRole(role) + err = PutChatRole(*role) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save config"}) return diff --git a/server/server.go b/server/server.go index 5c34d665..1bdb2a8b 100644 --- a/server/server.go +++ b/server/server.go @@ -14,6 +14,7 @@ import ( "openai/utils" "os" "path/filepath" + "runtime/debug" "strings" ) @@ -76,6 +77,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) { } engine.Use(sessionMiddleware(s.Config)) engine.Use(AuthorizeMiddleware(s)) + engine.Use(Recover) engine.GET("/hello", Hello) engine.GET("/api/session/get", s.GetSessionHandle) @@ -84,6 +86,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) { engine.POST("/api/config/set", s.ConfigSetHandle) engine.GET("/api/config/chat-roles/get", s.GetChatRoleList) engine.POST("api/config/token/add", s.AddToken) + engine.POST("api/config/token/batch-add", s.BatchAddToken) engine.POST("api/config/token/set", s.SetToken) engine.POST("api/config/token/remove", s.RemoveToken) engine.POST("api/config/apikey/add", s.AddApiKey) @@ -115,6 +118,19 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) { } +func Recover(c *gin.Context) { + defer func() { + if r := recover(); r != nil { + log.Printf("panic: %v\n", r) + debug.PrintStack() + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg}) + c.Abort() + } + }() + //加载完 defer recover,继续后续接口调用 + c.Next() +} + func sessionMiddleware(config *types.Config) gin.HandlerFunc { // encrypt the cookie store := cookie.NewStore([]byte(config.Session.SecretKey)) @@ -195,9 +211,9 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc { return } - tokenName := c.GetHeader(types.TokenName) + sessionId := c.GetHeader(types.TokenName) session := sessions.Default(c) - userInfo := session.Get(tokenName) + userInfo := session.Get(sessionId) if userInfo != nil { c.Set(types.SessionKey, userInfo) c.Next()