diff --git a/service/quota.go b/service/quota.go index 695c073..0ca9145 100644 --- a/service/quota.go +++ b/service/quota.go @@ -9,6 +9,7 @@ import ( "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" + "strings" "time" ) @@ -20,6 +21,12 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag if err != nil { return err } + + token, err := model.CacheGetTokenByKey(strings.TrimLeft(relayInfo.ApiKey, "sk-")) + if err != nil { + return err + } + modelName := relayInfo.UpstreamModelName textInputTokens := usage.InputTokenDetails.TextTokens textOutTokens := usage.OutputTokenDetails.TextTokens @@ -46,6 +53,10 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) } + if token.RemainQuota < quota { + return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota)) + } + err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false) if err != nil { return err