From 1244963e8126a736586d556138f4c85d133f526b Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Mon, 8 Jan 2024 16:23:54 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E5=8F=AF=E8=AE=BE=E7=BD=AE=E4=BB=A4?=
=?UTF-8?q?=E7=89=8C=E8=83=BD=E8=B0=83=E7=94=A8=E7=9A=84=E6=A8=A1=E5=9E=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
controller/token.go | 2 +
middleware/auth.go | 6 ++
middleware/distributor.go | 21 +++++
model/token.go | 56 +++++++++---
web/src/components/TokensTable.js | 15 +++-
web/src/pages/Detail/index.js | 2 +-
web/src/pages/Token/EditToken.js | 142 ++++++++++++++++++++++--------
7 files changed, 192 insertions(+), 52 deletions(-)
diff --git a/controller/token.go b/controller/token.go
index 0f37fd0..157cb2f 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -217,6 +217,8 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota
+ cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
+ cleanToken.ModelLimits = token.ModelLimits
}
err = cleanToken.Update()
if err != nil {
diff --git a/middleware/auth.go b/middleware/auth.go
index c0ff074..e12b81b 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -115,6 +115,12 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
+ if token.ModelLimitsEnabled {
+ c.Set("token_model_limit_enabled", true)
+ c.Set("token_model_limit", token.GetModelLimitsMap())
+ } else {
+ c.Set("token_model_limit_enabled", false)
+ }
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 6d64e30..a70ed41 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -77,6 +77,27 @@ func Distribute() func(c *gin.Context) {
}
}
}
+ // check token model mapping
+ modelLimitEnable := c.GetBool("token_model_limit_enabled")
+ if modelLimitEnable {
+ s, ok := c.Get("token_model_limit")
+ var tokenModelLimit map[string]bool
+ if ok {
+ tokenModelLimit = s.(map[string]bool)
+ } else {
+ tokenModelLimit = map[string]bool{}
+ }
+ if tokenModelLimit != nil {
+ if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
+ abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
+ return
+ }
+ } else {
+ // token model limit is empty, all models are not allowed
+ abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
+ return
+ }
+ }
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
diff --git a/model/token.go b/model/token.go
index b5f9230..3c6f35d 100644
--- a/model/token.go
+++ b/model/token.go
@@ -10,17 +10,19 @@ import (
)
type Token struct {
- Id int `json:"id"`
- UserId int `json:"user_id"`
- Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
- Status int `json:"status" gorm:"default:1"`
- Name string `json:"name" gorm:"index" `
- CreatedTime int64 `json:"created_time" gorm:"bigint"`
- AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
- ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
- RemainQuota int `json:"remain_quota" gorm:"default:0"`
- UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
- UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
+ Id int `json:"id"`
+ UserId int `json:"user_id"`
+ Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
+ Status int `json:"status" gorm:"default:1"`
+ Name string `json:"name" gorm:"index" `
+ CreatedTime int64 `json:"created_time" gorm:"bigint"`
+ AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
+ ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
+ RemainQuota int `json:"remain_quota" gorm:"default:0"`
+ UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
+ ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
+ ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
+ UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
@@ -107,7 +109,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
- err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
+ err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error
return err
}
@@ -122,6 +124,36 @@ func (token *Token) Delete() error {
return err
}
+func (token *Token) IsModelLimitsEnabled() bool {
+ return token.ModelLimitsEnabled
+}
+
+func (token *Token) GetModelLimits() []string {
+ if token.ModelLimits == "" {
+ return []string{}
+ }
+ return strings.Split(token.ModelLimits, ",")
+}
+
+func (token *Token) GetModelLimitsMap() map[string]bool {
+ limits := token.GetModelLimits()
+ limitsMap := make(map[string]bool)
+ for _, limit := range limits {
+ limitsMap[limit] = true
+ }
+ return limitsMap
+}
+
+func DisableModelLimits(tokenId int) error {
+ token, err := GetTokenById(tokenId)
+ if err != nil {
+ return err
+ }
+ token.ModelLimitsEnabled = false
+ token.ModelLimits = ""
+ return token.Update()
+}
+
func DeleteTokenById(id int, userId int) (err error) {
// Why we need userId here? In case user want to delete other's token.
if id == 0 || userId == 0 {
diff --git a/web/src/components/TokensTable.js b/web/src/components/TokensTable.js
index aa8fd57..5b685c0 100644
--- a/web/src/components/TokensTable.js
+++ b/web/src/components/TokensTable.js
@@ -43,10 +43,14 @@ function renderTimestamp(timestamp) {
);
}
-function renderStatus(status) {
+function renderStatus(status, model_limits_enabled = false) {
switch (status) {
case 1:
- return