From d85e91a8da23995137d97ac0d071d1de5f189e7f Mon Sep 17 00:00:00 2001 From: RockYang Date: Tue, 28 Mar 2023 10:17:36 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=20Token=20=E7=82=B9=E5=8D=A1?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/chat_handler.go | 3 +- server/config_handler.go | 62 +++++++++++++--------------------------- server/db.go | 30 ++++++++++++++++--- server/server.go | 6 ++-- types/config.go | 3 +- utils/config.go | 7 +++-- utils/leveldb.go | 10 ++----- 7 files changed, 59 insertions(+), 62 deletions(-) diff --git a/server/chat_handler.go b/server/chat_handler.go index b42678ea..1017e76b 100644 --- a/server/chat_handler.go +++ b/server/chat_handler.go @@ -68,7 +68,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex } if token.MaxCalls > 0 && token.RemainingCalls <= 0 { - replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用!") + replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用或者联系管理员!") return nil } var r = types.ApiRequest{ @@ -194,6 +194,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex // 当前 Token 调用次数减 1 if token.MaxCalls > 0 { token.RemainingCalls -= 1 + _ = PutToken(*token) } // 追加历史消息 context = append(context, types.Message{ diff --git a/server/config_handler.go b/server/config_handler.go index 37058abb..95142d56 100644 --- a/server/config_handler.go +++ b/server/config_handler.go @@ -88,7 +88,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) { // AddToken 添加 Token func (s *Server) AddToken(c *gin.Context) { - var data map[string]string + var data types.Token err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { logger.Errorf("Error decode json data: %s", err.Error()) @@ -97,30 +97,19 @@ func (s *Server) AddToken(c *gin.Context) { } // 参数处理 - var name = data["name"] - var maxCalls = data["max_calls"] - if name == "" || maxCalls == "" { + if data.Name == "" || data.MaxCalls < 0 { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) return } - n, err := strconv.Atoi(maxCalls) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{ - Code: types.InvalidParams, - Message: "enable_auth must be a int parameter", - }) - return - } - // 检查当前要添加的 token 是否已经存在 - _, err = GetToken(name) + _, err = GetToken(data.Name) if err == nil { - c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + name + " already exists"}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + data.Name + " already exists"}) return } - err = PutToken(types.Token{Name: name, MaxCalls: n, RemainingCalls: n}) + err = PutToken(types.Token{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls}) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) return @@ -130,7 +119,7 @@ func (s *Server) AddToken(c *gin.Context) { } func (s *Server) SetToken(c *gin.Context) { - var data map[string]string + var data types.Token err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { logger.Errorf("Error decode json data: %s", err.Error()) @@ -138,44 +127,35 @@ func (s *Server) SetToken(c *gin.Context) { return } + logger.Info(data) + // 参数处理 - var name = data["name"] - var maxCalls = data["max_calls"] - if name == "" || maxCalls == "" { + if data.Name == "" || data.MaxCalls < 0 { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) return } - token, err := GetToken(name) + token, err := GetToken(data.Name) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token not found"}) return } - n, err := strconv.Atoi(maxCalls) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{ - Code: types.InvalidParams, - Message: "enable_auth must be a int parameter", - }) - return - } + token.RemainingCalls += data.MaxCalls - token.MaxCalls + token.MaxCalls = data.MaxCalls - token.RemainingCalls += n - token.MaxCalls - token.MaxCalls = n - - err = PutToken(token) + err = PutToken(*token) if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) return } - c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()}) + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: token}) } // RemoveToken 删除 Token func (s *Server) RemoveToken(c *gin.Context) { - var data map[string]string + var data types.Token err := json.NewDecoder(c.Request.Body).Decode(&data) if err != nil { logger.Errorf("Error decode json data: %s", err.Error()) @@ -183,12 +163,10 @@ func (s *Server) RemoveToken(c *gin.Context) { return } - if token, ok := data["token"]; ok { - err = RemoveToken(token) - if err != nil { - c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) - return - } + err = RemoveToken(data.Name) + if err != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) + return } c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()}) @@ -250,7 +228,7 @@ func (s *Server) ListApiKeys(c *gin.Context) { c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys}) } -func (s *Server) GetChatRoles(c *gin.Context) { +func (s *Server) GetChatRoleList(c *gin.Context) { var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"} var res = make([]interface{}, 0) var roles = GetChatRoles() diff --git a/server/db.go b/server/db.go index 6e8aca84..80f6a281 100644 --- a/server/db.go +++ b/server/db.go @@ -43,14 +43,20 @@ func PutToken(token types.Token) error { return db.Put(key, token) } -func GetToken(name string) (types.Token, error) { +func GetToken(name string) (*types.Token, error) { key := TokenPrefix + name - token, err := db.Get(key) + bytes, err := db.Get(key) if err != nil { - return types.Token{}, err + return nil, err } - return token.(types.Token), nil + var token types.Token + err = json.Unmarshal(bytes, &token) + if err != nil { + return nil, err + } + + return &token, nil } func RemoveToken(token string) error { @@ -79,6 +85,22 @@ func PutChatRole(role types.ChatRole) error { return db.Put(key, role) } +func GetChatRole(key string) (*types.ChatRole, error) { + key = ChatHistoryPrefix + key + bytes, err := db.Get(key) + if err != nil { + return nil, err + } + + var role types.ChatRole + err = json.Unmarshal(bytes, &role) + if err != nil { + return nil, err + } + + return &role, nil +} + // GetChatHistory 获取聊天历史记录 // chat/history/{token}/{role} func GetChatHistory() []types.Message { diff --git a/server/server.go b/server/server.go index 6dbbd7af..5c34d665 100644 --- a/server/server.go +++ b/server/server.go @@ -82,7 +82,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) { engine.POST("/api/login", s.LoginHandle) engine.Any("/api/chat", s.ChatHandle) engine.POST("/api/config/set", s.ConfigSetHandle) - engine.GET("/api/config/chat-roles/get", s.GetChatRoles) + engine.GET("/api/config/chat-roles/get", s.GetChatRoleList) engine.POST("api/config/token/add", s.AddToken) engine.POST("api/config/token/set", s.SetToken) engine.POST("api/config/token/remove", s.RemoveToken) @@ -174,8 +174,8 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc { } if strings.HasPrefix(c.Request.URL.Path, "/api/config") { - accessKey := c.Query("access_key") - if accessKey != "RockYang" { + accessKey := c.GetHeader("ACCESS_KEY") + if accessKey != s.Config.AccessKey { c.Abort() c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: "No Permissions"}) } else { diff --git a/types/config.go b/types/config.go index 1dbc81a8..3e9174b3 100644 --- a/types/config.go +++ b/types/config.go @@ -8,8 +8,9 @@ type Config struct { Listen string Session Session ProxyURL []string + EnableAuth bool // 是否开启鉴权 + AccessKey string // 管理员访问 AccessKey, 通过传入这个参数可以访问系统管理 API Chat Chat - EnableAuth bool // 是否开启鉴权 } type Token struct { diff --git a/utils/config.go b/utils/config.go index 0495d23a..22048329 100644 --- a/utils/config.go +++ b/utils/config.go @@ -11,8 +11,10 @@ import ( func NewDefaultConfig() *types.Config { return &types.Config{ - Listen: "0.0.0.0:5678", - ProxyURL: make([]string, 0), + Listen: "0.0.0.0:5678", + ProxyURL: make([]string, 0), + EnableAuth: true, + AccessKey: "yangjian102621@gmail.com", Session: types.Session{ SecretKey: RandString(64), @@ -32,7 +34,6 @@ func NewDefaultConfig() *types.Config { Temperature: 0.9, EnableContext: true, }, - EnableAuth: true, } } diff --git a/utils/leveldb.go b/utils/leveldb.go index 83d0a922..ae839dc4 100644 --- a/utils/leveldb.go +++ b/utils/leveldb.go @@ -28,19 +28,13 @@ func (db *LevelDB) Put(key string, value interface{}) error { return db.driver.Put([]byte(key), bytes, nil) } -func (db *LevelDB) Get(key string) (interface{}, error) { +func (db *LevelDB) Get(key string) ([]byte, error) { bytes, err := db.driver.Get([]byte(key), nil) if err != nil { return nil, err } - var value interface{} - err = json.Unmarshal(bytes, &value) - if err != nil { - return nil, err - } - - return value, nil + return bytes, nil } func (db *LevelDB) Search(prefix string) []string {