chore: reorganize billing related package

This commit is contained in:
JustSong
2024-04-06 01:26:48 +08:00
parent cd2707692f
commit 8f4d78e24d
13 changed files with 77 additions and 76 deletions

View File

@@ -13,6 +13,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype"
relaymodel "github.com/songquanpeng/one-api/relay/model"
@@ -49,8 +50,8 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}
modelRatio := billing.GetModelRatio(audioModel)
groupRatio := billing.GetGroupRatio(group)
modelRatio := billingratio.GetModelRatio(audioModel)
groupRatio := billingratio.GetGroupRatio(group)
ratio := modelRatio * groupRatio
var quota int64
var preConsumedQuota int64
@@ -218,7 +219,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
succeed = true
quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) {
go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context())
for k, v := range resp.Header {

View File

@@ -9,7 +9,7 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype"
relaymodel "github.com/songquanpeng/one-api/relay/model"
@@ -60,12 +60,12 @@ func isValidImageSize(model string, size string) bool {
if model == "cogview-3" {
return true
}
_, ok := billing.ImageSizeRatios[model][size]
_, ok := billingratio.ImageSizeRatios[model][size]
return ok
}
func getImageSizeRatio(model string, size string) float64 {
ratio, ok := billing.ImageSizeRatios[model][size]
ratio, ok := billingratio.ImageSizeRatios[model][size]
if !ok {
return 1
}
@@ -82,7 +82,7 @@ func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.Rela
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
if len(imageRequest.Prompt) > billing.ImagePromptLengthLimitations[imageRequest.Model] {
if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
@@ -165,7 +165,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R
return
}
var quota int64
completionRatio := billing.GetCompletionRatio(textRequest.Model)
completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))

View File

@@ -9,7 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper"
@@ -20,11 +20,11 @@ import (
)
func isWithinRange(element string, value int) bool {
if _, ok := billing.ImageGenerationAmounts[element]; !ok {
if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
return false
}
min := billing.ImageGenerationAmounts[element][0]
max := billing.ImageGenerationAmounts[element][1]
min := billingratio.ImageGenerationAmounts[element][0]
max := billingratio.ImageGenerationAmounts[element][1]
return value >= min && value <= max
}
@@ -87,8 +87,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = bytes.NewBuffer(jsonStr)
}
modelRatio := billing.GetModelRatio(imageRequest.Model)
groupRatio := billing.GetGroupRatio(meta.Group)
modelRatio := billingratio.GetModelRatio(imageRequest.Model)
groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)

View File

@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper"
@@ -35,8 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := billing.GetModelRatio(textRequest.Model)
groupRatio := billing.GetGroupRatio(meta.Group)
modelRatio := billingratio.GetModelRatio(textRequest.Model)
groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
// pre-consume quota
promptTokens := getPromptTokens(textRequest, meta.Mode)
@@ -87,7 +88,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
}
errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
if errorHappened {
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return util.RelayErrorHandler(resp)
}
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
@@ -96,7 +97,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return respErr
}
// post-consume quota