From 8afdc56b11987c79f659c56d932979468e976706 Mon Sep 17 00:00:00 2001 From: JustSong Date: Tue, 16 May 2023 13:29:22 +0800 Subject: [PATCH] fix: fix quota not consuming --- controller/relay.go | 10 ++++++++- middleware/auth.go | 2 +- model/token.go | 53 ++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 5da9c19f..99225a9d 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -128,6 +128,13 @@ func relayHelper(c *gin.Context) error { model_ = strings.TrimSuffix(model_, "-0314") fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) } + preConsumedQuota := 500 // TODO: make this configurable, take ratio into account + if consumeQuota { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return err + } + } req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) if err != nil { return err @@ -179,7 +186,8 @@ func relayHelper(c *gin.Context) error { } ratio := common.GetModelRatio(textRequest.Model) quota = int(float64(quota) * ratio) - err := model.DecreaseTokenQuota(tokenId, quota) + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { common.SysError("Error consuming token remain quota: " + err.Error()) } diff --git a/middleware/auth.go b/middleware/auth.go index 79077f97..bf09753a 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -111,7 +111,7 @@ func TokenAuth() func(c *gin.Context) { c.Set("id", token.UserId) c.Set("token_id", token.Id) requestURL := c.Request.URL.String() - consumeQuota := !token.UnlimitedQuota + consumeQuota := true if strings.HasPrefix(requestURL, "/v1/models") { consumeQuota = false } diff --git a/model/token.go b/model/token.go index f24330d1..ada67255 100644 --- a/model/token.go +++ b/model/token.go @@ -130,7 +130,23 @@ func DeleteTokenById(id int, userId int) (err error) { return token.Delete() } -func DecreaseTokenQuota(tokenId int, quota int) (err error) { +func IncreaseTokenQuota(id int, quota int) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error + return err +} + +func DecreaseTokenQuota(id int, quota int) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error + return err +} + +func PreConsumeTokenQuota(tokenId int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -138,7 +154,7 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) { if err != nil { return err } - if token.RemainQuota < quota { + if !token.UnlimitedQuota && token.RemainQuota < quota { return errors.New("令牌额度不足") } userQuota, err := GetUserQuota(token.UserId) @@ -163,17 +179,42 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) { if email != "" { topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) err = common.SendEmail(prompt, email, - fmt.Sprintf("%s,剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota-quota, topUpLink, topUpLink)) + fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) if err != nil { common.SysError("发送邮件失败:" + err.Error()) } } }() } - err = DB.Model(&Token{}).Where("id = ?", tokenId).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error - if err != nil { - return err + if !token.UnlimitedQuota { + err = DecreaseTokenQuota(tokenId, quota) + if err != nil { + return err + } } err = DecreaseUserQuota(token.UserId, quota) return err } + +func PostConsumeTokenQuota(tokenId int, quota int) (err error) { + token, err := GetTokenById(tokenId) + if quota > 0 { + err = DecreaseUserQuota(token.UserId, quota) + } else { + err = IncreaseUserQuota(token.UserId, -quota) + } + if err != nil { + return err + } + if !token.UnlimitedQuota { + if quota > 0 { + err = DecreaseTokenQuota(tokenId, quota) + } else { + err = IncreaseTokenQuota(tokenId, -quota) + } + if err != nil { + return err + } + } + return nil +}