mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
完成 Token 点卡功能
This commit is contained in:
parent
5f702d92dc
commit
d85e91a8da
@ -68,7 +68,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
|
|||||||
}
|
}
|
||||||
|
|
||||||
if token.MaxCalls > 0 && token.RemainingCalls <= 0 {
|
if token.MaxCalls > 0 && token.RemainingCalls <= 0 {
|
||||||
replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用!")
|
replyError(ws, "当前 TOKEN 点数已经用尽,请充值后再使用或者联系管理员!")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var r = types.ApiRequest{
|
var r = types.ApiRequest{
|
||||||
@ -194,6 +194,7 @@ func (s *Server) sendMessage(session types.ChatSession, role types.ChatRole, tex
|
|||||||
// 当前 Token 调用次数减 1
|
// 当前 Token 调用次数减 1
|
||||||
if token.MaxCalls > 0 {
|
if token.MaxCalls > 0 {
|
||||||
token.RemainingCalls -= 1
|
token.RemainingCalls -= 1
|
||||||
|
_ = PutToken(*token)
|
||||||
}
|
}
|
||||||
// 追加历史消息
|
// 追加历史消息
|
||||||
context = append(context, types.Message{
|
context = append(context, types.Message{
|
||||||
|
@ -88,7 +88,7 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
|
|||||||
|
|
||||||
// AddToken 添加 Token
|
// AddToken 添加 Token
|
||||||
func (s *Server) AddToken(c *gin.Context) {
|
func (s *Server) AddToken(c *gin.Context) {
|
||||||
var data map[string]string
|
var data types.Token
|
||||||
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error decode json data: %s", err.Error())
|
logger.Errorf("Error decode json data: %s", err.Error())
|
||||||
@ -97,30 +97,19 @@ func (s *Server) AddToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 参数处理
|
// 参数处理
|
||||||
var name = data["name"]
|
if data.Name == "" || data.MaxCalls < 0 {
|
||||||
var maxCalls = data["max_calls"]
|
|
||||||
if name == "" || maxCalls == "" {
|
|
||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := strconv.Atoi(maxCalls)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, types.BizVo{
|
|
||||||
Code: types.InvalidParams,
|
|
||||||
Message: "enable_auth must be a int parameter",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查当前要添加的 token 是否已经存在
|
// 检查当前要添加的 token 是否已经存在
|
||||||
_, err = GetToken(name)
|
_, err = GetToken(data.Name)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + name + " already exists"})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token " + data.Name + " already exists"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = PutToken(types.Token{Name: name, MaxCalls: n, RemainingCalls: n})
|
err = PutToken(types.Token{Name: data.Name, MaxCalls: data.MaxCalls, RemainingCalls: data.MaxCalls})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
|
||||||
return
|
return
|
||||||
@ -130,7 +119,7 @@ func (s *Server) AddToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) SetToken(c *gin.Context) {
|
func (s *Server) SetToken(c *gin.Context) {
|
||||||
var data map[string]string
|
var data types.Token
|
||||||
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error decode json data: %s", err.Error())
|
logger.Errorf("Error decode json data: %s", err.Error())
|
||||||
@ -138,44 +127,35 @@ func (s *Server) SetToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Info(data)
|
||||||
|
|
||||||
// 参数处理
|
// 参数处理
|
||||||
var name = data["name"]
|
if data.Name == "" || data.MaxCalls < 0 {
|
||||||
var maxCalls = data["max_calls"]
|
|
||||||
if name == "" || maxCalls == "" {
|
|
||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Invalid args"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := GetToken(name)
|
token, err := GetToken(data.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token not found"})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Token not found"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := strconv.Atoi(maxCalls)
|
token.RemainingCalls += data.MaxCalls - token.MaxCalls
|
||||||
if err != nil {
|
token.MaxCalls = data.MaxCalls
|
||||||
c.JSON(http.StatusOK, types.BizVo{
|
|
||||||
Code: types.InvalidParams,
|
|
||||||
Message: "enable_auth must be a int parameter",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
token.RemainingCalls += n - token.MaxCalls
|
err = PutToken(*token)
|
||||||
token.MaxCalls = n
|
|
||||||
|
|
||||||
err = PutToken(token)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: token})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveToken 删除 Token
|
// RemoveToken 删除 Token
|
||||||
func (s *Server) RemoveToken(c *gin.Context) {
|
func (s *Server) RemoveToken(c *gin.Context) {
|
||||||
var data map[string]string
|
var data types.Token
|
||||||
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
err := json.NewDecoder(c.Request.Body).Decode(&data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error decode json data: %s", err.Error())
|
logger.Errorf("Error decode json data: %s", err.Error())
|
||||||
@ -183,13 +163,11 @@ func (s *Server) RemoveToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if token, ok := data["token"]; ok {
|
err = RemoveToken(data.Name)
|
||||||
err = RemoveToken(token)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Failed to save configs"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: GetTokens()})
|
||||||
}
|
}
|
||||||
@ -250,7 +228,7 @@ func (s *Server) ListApiKeys(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) GetChatRoles(c *gin.Context) {
|
func (s *Server) GetChatRoleList(c *gin.Context) {
|
||||||
var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
|
var rolesOrder = []string{"gpt", "programmer", "teacher", "artist", "philosopher", "lu-xun", "english_trainer", "seller"}
|
||||||
var res = make([]interface{}, 0)
|
var res = make([]interface{}, 0)
|
||||||
var roles = GetChatRoles()
|
var roles = GetChatRoles()
|
||||||
|
30
server/db.go
30
server/db.go
@ -43,14 +43,20 @@ func PutToken(token types.Token) error {
|
|||||||
return db.Put(key, token)
|
return db.Put(key, token)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetToken(name string) (types.Token, error) {
|
func GetToken(name string) (*types.Token, error) {
|
||||||
key := TokenPrefix + name
|
key := TokenPrefix + name
|
||||||
token, err := db.Get(key)
|
bytes, err := db.Get(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.Token{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return token.(types.Token), nil
|
var token types.Token
|
||||||
|
err = json.Unmarshal(bytes, &token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RemoveToken(token string) error {
|
func RemoveToken(token string) error {
|
||||||
@ -79,6 +85,22 @@ func PutChatRole(role types.ChatRole) error {
|
|||||||
return db.Put(key, role)
|
return db.Put(key, role)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetChatRole(key string) (*types.ChatRole, error) {
|
||||||
|
key = ChatHistoryPrefix + key
|
||||||
|
bytes, err := db.Get(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var role types.ChatRole
|
||||||
|
err = json.Unmarshal(bytes, &role)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &role, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetChatHistory 获取聊天历史记录
|
// GetChatHistory 获取聊天历史记录
|
||||||
// chat/history/{token}/{role}
|
// chat/history/{token}/{role}
|
||||||
func GetChatHistory() []types.Message {
|
func GetChatHistory() []types.Message {
|
||||||
|
@ -82,7 +82,7 @@ func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
|
|||||||
engine.POST("/api/login", s.LoginHandle)
|
engine.POST("/api/login", s.LoginHandle)
|
||||||
engine.Any("/api/chat", s.ChatHandle)
|
engine.Any("/api/chat", s.ChatHandle)
|
||||||
engine.POST("/api/config/set", s.ConfigSetHandle)
|
engine.POST("/api/config/set", s.ConfigSetHandle)
|
||||||
engine.GET("/api/config/chat-roles/get", s.GetChatRoles)
|
engine.GET("/api/config/chat-roles/get", s.GetChatRoleList)
|
||||||
engine.POST("api/config/token/add", s.AddToken)
|
engine.POST("api/config/token/add", s.AddToken)
|
||||||
engine.POST("api/config/token/set", s.SetToken)
|
engine.POST("api/config/token/set", s.SetToken)
|
||||||
engine.POST("api/config/token/remove", s.RemoveToken)
|
engine.POST("api/config/token/remove", s.RemoveToken)
|
||||||
@ -174,8 +174,8 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/api/config") {
|
if strings.HasPrefix(c.Request.URL.Path, "/api/config") {
|
||||||
accessKey := c.Query("access_key")
|
accessKey := c.GetHeader("ACCESS_KEY")
|
||||||
if accessKey != "RockYang" {
|
if accessKey != s.Config.AccessKey {
|
||||||
c.Abort()
|
c.Abort()
|
||||||
c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: "No Permissions"})
|
c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: "No Permissions"})
|
||||||
} else {
|
} else {
|
||||||
|
@ -8,8 +8,9 @@ type Config struct {
|
|||||||
Listen string
|
Listen string
|
||||||
Session Session
|
Session Session
|
||||||
ProxyURL []string
|
ProxyURL []string
|
||||||
Chat Chat
|
|
||||||
EnableAuth bool // 是否开启鉴权
|
EnableAuth bool // 是否开启鉴权
|
||||||
|
AccessKey string // 管理员访问 AccessKey, 通过传入这个参数可以访问系统管理 API
|
||||||
|
Chat Chat
|
||||||
}
|
}
|
||||||
|
|
||||||
type Token struct {
|
type Token struct {
|
||||||
|
@ -13,6 +13,8 @@ func NewDefaultConfig() *types.Config {
|
|||||||
return &types.Config{
|
return &types.Config{
|
||||||
Listen: "0.0.0.0:5678",
|
Listen: "0.0.0.0:5678",
|
||||||
ProxyURL: make([]string, 0),
|
ProxyURL: make([]string, 0),
|
||||||
|
EnableAuth: true,
|
||||||
|
AccessKey: "yangjian102621@gmail.com",
|
||||||
|
|
||||||
Session: types.Session{
|
Session: types.Session{
|
||||||
SecretKey: RandString(64),
|
SecretKey: RandString(64),
|
||||||
@ -32,7 +34,6 @@ func NewDefaultConfig() *types.Config {
|
|||||||
Temperature: 0.9,
|
Temperature: 0.9,
|
||||||
EnableContext: true,
|
EnableContext: true,
|
||||||
},
|
},
|
||||||
EnableAuth: true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,19 +28,13 @@ func (db *LevelDB) Put(key string, value interface{}) error {
|
|||||||
return db.driver.Put([]byte(key), bytes, nil)
|
return db.driver.Put([]byte(key), bytes, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *LevelDB) Get(key string) (interface{}, error) {
|
func (db *LevelDB) Get(key string) ([]byte, error) {
|
||||||
bytes, err := db.driver.Get([]byte(key), nil)
|
bytes, err := db.driver.Get([]byte(key), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var value interface{}
|
return bytes, nil
|
||||||
err = json.Unmarshal(bytes, &value)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return value, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *LevelDB) Search(prefix string) []string {
|
func (db *LevelDB) Search(prefix string) []string {
|
||||||
|
Loading…
Reference in New Issue
Block a user