完成 Token 点卡功能

This commit is contained in:
RockYang 2023-03-28 10:17:36 +08:00
parent 5f702d92dc
commit d85e91a8da
7 changed files with 59 additions and 62 deletions

View File

@ -68,7 +68,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
} }
if token.MaxCalls > 0 && token.RemainingCalls <= 0 { if token.MaxCalls > 0 && token.RemainingCalls <= 0 {
replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用") replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用或者联系管理员")
return nil return nil
} }
var r = types.ApiRequest{ var r = types.ApiRequest{
@ -194,6 +194,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
// 当前 Token 调用次数减 1 // 当前 Token 调用次数减 1
if token.MaxCalls > 0 { if token.MaxCalls > 0 {
token.RemainingCalls -= 1 token.RemainingCalls -= 1
_ = PutToken(*token)
} }
// 追加历史消息 // 追加历史消息
context = append(context, types.Message{ context = append(context, types.Message{

View File

@ -88,7 +88,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
// AddToken 添加 Token // AddToken 添加 Token
func (s *Server) AddToken(c *gin.Context) { func (s *Server) AddToken(c *gin.Context) {
var data map[string]string var data types.Token
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
logger.Errorf("Error decode json data: %s", err.Error()) logger.Errorf("Error decode json data: %s", err.Error())
@ -97,30 +97,19 @@ func (s *Server) AddToken(c *gin.Context) {
} }
// 参数处理 // 参数处理
var name = data["name"] if data.Name == "" || data.MaxCalls < 0 {
var maxCalls = data["max_calls"]
if name == "" || maxCalls == "" {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
return return
} }
n, err := strconv.Atoi(maxCalls)
if err != nil {
c.JSON(http.StatusOK, types.BizVo{
Code: types.InvalidParams,
Message: "enable_auth must be a int parameter",
})
return
}
// 检查当前要添加的 token 是否已经存在 // 检查当前要添加的 token 是否已经存在
_, err = GetToken(name) _, err = GetToken(data.Name)
if err == nil { if err == nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + name + " already exists"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + data.Name + " already exists"})
return return
} }
err = PutToken(types.Token{Name: name, MaxCalls: n, RemainingCalls: n}) err = PutToken(types.Token{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls})
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
return return
@ -130,7 +119,7 @@ func (s *Server) AddToken(c *gin.Context) {
} }
func (s *Server) SetToken(c *gin.Context) { func (s *Server) SetToken(c *gin.Context) {
var data map[string]string var data types.Token
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
logger.Errorf("Error decode json data: %s", err.Error()) logger.Errorf("Error decode json data: %s", err.Error())
@ -138,44 +127,35 @@ func (s *Server) SetToken(c *gin.Context) {
return return
} }
logger.Info(data)
// 参数处理 // 参数处理
var name = data["name"] if data.Name == "" || data.MaxCalls < 0 {
var maxCalls = data["max_calls"]
if name == "" || maxCalls == "" {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
return return
} }
token, err := GetToken(name) token, err := GetToken(data.Name)
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token not found"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token not found"})
return return
} }
n, err := strconv.Atoi(maxCalls) token.RemainingCalls += data.MaxCalls - token.MaxCalls
if err != nil { token.MaxCalls = data.MaxCalls
c.JSON(http.StatusOK, types.BizVo{
Code: types.InvalidParams,
Message: "enable_auth must be a int parameter",
})
return
}
token.RemainingCalls += n - token.MaxCalls err = PutToken(*token)
token.MaxCalls = n
err = PutToken(token)
if err != nil { if err != nil {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
return return
} }
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: token})
} }
// RemoveToken 删除 Token // RemoveToken 删除 Token
func (s *Server) RemoveToken(c *gin.Context) { func (s *Server) RemoveToken(c *gin.Context) {
var data map[string]string var data types.Token
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
logger.Errorf("Error decode json data: %s", err.Error()) logger.Errorf("Error decode json data: %s", err.Error())
@ -183,12 +163,10 @@ func (s *Server) RemoveToken(c *gin.Context) {
return return
} }
if token, ok := data["token"]; ok { err = RemoveToken(data.Name)
err = RemoveToken(token) if err != nil {
if err != nil { c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"}) return
return
}
} }
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
@ -250,7 +228,7 @@ func (s *Server) ListApiKeys(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
} }
func (s *Server) GetChatRoles(c *gin.Context) { func (s *Server) GetChatRoleList(c *gin.Context) {
var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"} var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
var res = make([]interface{}, 0) var res = make([]interface{}, 0)
var roles = GetChatRoles() var roles = GetChatRoles()

View File

@ -43,14 +43,20 @@ func PutToken(token types.Token) error {
return db.Put(key, token) return db.Put(key, token)
} }
func GetToken(name string) (types.Token, error) { func GetToken(name string) (*types.Token, error) {
key := TokenPrefix + name key := TokenPrefix + name
token, err := db.Get(key) bytes, err := db.Get(key)
if err != nil { if err != nil {
return types.Token{}, err return nil, err
} }
return token.(types.Token), nil var token types.Token
err = json.Unmarshal(bytes, &token)
if err != nil {
return nil, err
}
return &token, nil
} }
func RemoveToken(token string) error { func RemoveToken(token string) error {
@ -79,6 +85,22 @@ func PutChatRole(role types.ChatRole) error {
return db.Put(key, role) return db.Put(key, role)
} }
func GetChatRole(key string) (*types.ChatRole, error) {
key = ChatHistoryPrefix + key
bytes, err := db.Get(key)
if err != nil {
return nil, err
}
var role types.ChatRole
err = json.Unmarshal(bytes, &role)
if err != nil {
return nil, err
}
return &role, nil
}
// GetChatHistory 获取聊天历史记录 // GetChatHistory 获取聊天历史记录
// chat/history/{token}/{role} // chat/history/{token}/{role}
func GetChatHistory() []types.Message { func GetChatHistory() []types.Message {

View File

@ -82,7 +82,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
engine.POST("/api/login", s.LoginHandle) engine.POST("/api/login", s.LoginHandle)
engine.Any("/api/chat", s.ChatHandle) engine.Any("/api/chat", s.ChatHandle)
engine.POST("/api/config/set", s.ConfigSetHandle) engine.POST("/api/config/set", s.ConfigSetHandle)
engine.GET("/api/config/chat-roles/get", s.GetChatRoles) engine.GET("/api/config/chat-roles/get", s.GetChatRoleList)
engine.POST("api/config/token/add", s.AddToken) engine.POST("api/config/token/add", s.AddToken)
engine.POST("api/config/token/set", s.SetToken) engine.POST("api/config/token/set", s.SetToken)
engine.POST("api/config/token/remove", s.RemoveToken) engine.POST("api/config/token/remove", s.RemoveToken)
@ -174,8 +174,8 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
} }
if strings.HasPrefix(c.Request.URL.Path, "/api/config") { if strings.HasPrefix(c.Request.URL.Path, "/api/config") {
accessKey := c.Query("access_key") accessKey := c.GetHeader("ACCESS_KEY")
if accessKey != "RockYang" { if accessKey != s.Config.AccessKey {
c.Abort() c.Abort()
c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: "No Permissions"}) c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: "No Permissions"})
} else { } else {

View File

@ -8,8 +8,9 @@ type Config struct {
Listen string Listen string
Session Session Session Session
ProxyURL []string ProxyURL []string
EnableAuth bool // 是否开启鉴权
AccessKey string // 管理员访问 AccessKey, 通过传入这个参数可以访问系统管理 API
Chat Chat Chat Chat
EnableAuth bool // 是否开启鉴权
} }
type Token struct { type Token struct {

View File

@ -11,8 +11,10 @@ import (
func NewDefaultConfig() *types.Config { func NewDefaultConfig() *types.Config {
return &types.Config{ return &types.Config{
Listen: "0.0.0.0:5678", Listen: "0.0.0.0:5678",
ProxyURL: make([]string, 0), ProxyURL: make([]string, 0),
EnableAuth: true,
AccessKey: "yangjian102621@gmail.com",
Session: types.Session{ Session: types.Session{
SecretKey: RandString(64), SecretKey: RandString(64),
@ -32,7 +34,6 @@ func NewDefaultConfig() *types.Config {
Temperature: 0.9, Temperature: 0.9,
EnableContext: true, EnableContext: true,
}, },
EnableAuth: true,
} }
} }

View File

@ -28,19 +28,13 @@ func (db *LevelDB) Put(key string, value interface{}) error {
return db.driver.Put([]byte(key), bytes, nil) return db.driver.Put([]byte(key), bytes, nil)
} }
func (db *LevelDB) Get(key string) (interface{}, error) { func (db *LevelDB) Get(key string) ([]byte, error) {
bytes, err := db.driver.Get([]byte(key), nil) bytes, err := db.driver.Get([]byte(key), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var value interface{} return bytes, nil
err = json.Unmarshal(bytes, &value)
if err != nil {
return nil, err
}
return value, nil
} }
func (db *LevelDB) Search(prefix string) []string { func (db *LevelDB) Search(prefix string) []string {