diff --git a/controller/relay.go b/controller/relay.go
index 5da9c19f..99225a9d 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -128,6 +128,13 @@ func relayHelper(c *gin.Context) error {
model_ = strings.TrimSuffix(model_, "-0314")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
}
+ preConsumedQuota := 500 // TODO: make this configurable, take ratio into account
+ if consumeQuota {
+ err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+ if err != nil {
+ return err
+ }
+ }
req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
if err != nil {
return err
@@ -179,7 +186,8 @@ func relayHelper(c *gin.Context) error {
}
ratio := common.GetModelRatio(textRequest.Model)
quota = int(float64(quota) * ratio)
- err := model.DecreaseTokenQuota(tokenId, quota)
+ quotaDelta := quota - preConsumedQuota
+ err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("Error consuming token remain quota: " + err.Error())
}
diff --git a/middleware/auth.go b/middleware/auth.go
index 79077f97..bf09753a 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -111,7 +111,7 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
requestURL := c.Request.URL.String()
- consumeQuota := !token.UnlimitedQuota
+ consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
diff --git a/model/token.go b/model/token.go
index f24330d1..ada67255 100644
--- a/model/token.go
+++ b/model/token.go
@@ -130,7 +130,23 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete()
}
-func DecreaseTokenQuota(tokenId int, quota int) (err error) {
+// IncreaseTokenQuota adds quota to the remain_quota column of the token row matching id; a negative quota is rejected.
+func IncreaseTokenQuota(id int, quota int) (err error) {
+	if quota < 0 {
+		return errors.New("quota 不能为负数!")
+	}
+	return DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error
+}
+
+// DecreaseTokenQuota subtracts quota from the remain_quota column of the token row matching id; a negative quota is rejected.
+func DecreaseTokenQuota(id int, quota int) (err error) {
+	if quota < 0 {
+		return errors.New("quota 不能为负数!")
+	}
+	return DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
+}
+
+func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@@ -138,7 +154,7 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
if err != nil {
return err
}
- if token.RemainQuota < quota {
+ if !token.UnlimitedQuota && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
userQuota, err := GetUserQuota(token.UserId)
@@ -163,17 +179,42 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
if email != "" {
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
err = common.SendEmail(prompt, email,
- fmt.Sprintf("%s,剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota-quota, topUpLink, topUpLink))
+ fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
common.SysError("发送邮件失败:" + err.Error())
}
}
}()
}
- err = DB.Model(&Token{}).Where("id = ?", tokenId).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
- if err != nil {
- return err
+ if !token.UnlimitedQuota {
+ err = DecreaseTokenQuota(tokenId, quota)
+ if err != nil {
+ return err
+ }
}
err = DecreaseUserQuota(token.UserId, quota)
return err
}
+
+// PostConsumeTokenQuota applies a signed quota delta for the token's owner: a
+// positive quota is charged, a zero or negative one is refunded. The user's
+// quota is always adjusted; the token's only when it is not unlimited.
+func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
+	token, err := GetTokenById(tokenId)
+	if err != nil {
+		// Without the token record there is no user to settle against.
+		return err
+	}
+	if quota > 0 {
+		err = DecreaseUserQuota(token.UserId, quota)
+	} else {
+		err = IncreaseUserQuota(token.UserId, -quota)
+	}
+	if err != nil || token.UnlimitedQuota {
+		return err
+	}
+	if quota > 0 {
+		return DecreaseTokenQuota(tokenId, quota)
+	}
+	return IncreaseTokenQuota(tokenId, -quota)
+}