♻️ refactor: Refactor price module (#123) (#109) (#128)

This commit is contained in:
Buer
2024-03-28 16:53:34 +08:00
committed by GitHub
parent 646cb74154
commit a58e538c26
32 changed files with 2361 additions and 663 deletions

402
relay/util/pricing.go Normal file
View File

@@ -0,0 +1,402 @@
package util
import (
"encoding/json"
"errors"
"one-api/common"
"one-api/model"
"sort"
"strings"
"sync"
"github.com/spf13/viper"
)
// PricingInstance is the Pricing instance
var PricingInstance *Pricing
// Pricing is a struct that contains the pricing data
type Pricing struct {
sync.RWMutex
Prices map[string]*model.Price `json:"models"`
Match []string `json:"-"`
}
type BatchPrices struct {
Models []string `json:"models" binding:"required"`
Price model.Price `json:"price" binding:"required"`
}
// NewPricing creates a new Pricing instance
func NewPricing() {
common.SysLog("Initializing Pricing")
PricingInstance = &Pricing{
Prices: make(map[string]*model.Price),
Match: make([]string, 0),
}
err := PricingInstance.Init()
if err != nil {
common.SysError("Failed to initialize Pricing:" + err.Error())
return
}
// 初始化时,需要检测是否有更新
if viper.GetBool("auto_price_updates") {
common.SysLog("Checking for pricing updates")
prices := model.GetDefaultPrice()
PricingInstance.SyncPricing(prices, false)
common.SysLog("Pricing initialized")
}
}
// initializes the Pricing instance
func (p *Pricing) Init() error {
prices, err := model.GetAllPrices()
if err != nil {
return err
}
if len(prices) == 0 {
return nil
}
newPrices := make(map[string]*model.Price)
newMatch := make(map[string]bool)
for _, price := range prices {
newPrices[price.Model] = price
if strings.HasSuffix(price.Model, "*") {
if _, ok := newMatch[price.Model]; !ok {
newMatch[price.Model] = true
}
}
}
var newMatchList []string
for match := range newMatch {
newMatchList = append(newMatchList, match)
}
p.Lock()
defer p.Unlock()
p.Prices = newPrices
p.Match = newMatchList
return nil
}
// GetPrice returns the price of a model
func (p *Pricing) GetPrice(modelName string) *model.Price {
p.RLock()
defer p.RUnlock()
if price, ok := p.Prices[modelName]; ok {
return price
}
matchModel := common.GetModelsWithMatch(&p.Match, modelName)
if price, ok := p.Prices[matchModel]; ok {
return price
}
return &model.Price{
Type: model.TokensPriceType,
ChannelType: common.ChannelTypeUnknown,
Input: model.DefaultPrice,
Output: model.DefaultPrice,
}
}
func (p *Pricing) GetAllPrices() map[string]*model.Price {
return p.Prices
}
func (p *Pricing) GetAllPricesList() []*model.Price {
var prices []*model.Price
for _, price := range p.Prices {
prices = append(prices, price)
}
return prices
}
func (p *Pricing) updateRawPrice(modelName string, price *model.Price) error {
if _, ok := p.Prices[modelName]; !ok {
return errors.New("model not found")
}
if _, ok := p.Prices[price.Model]; modelName != price.Model && ok {
return errors.New("model names cannot be duplicated")
}
if err := p.deleteRawPrice(modelName); err != nil {
return err
}
return price.Insert()
}
// UpdatePrice updates the price of a model
func (p *Pricing) UpdatePrice(modelName string, price *model.Price) error {
if err := p.updateRawPrice(modelName, price); err != nil {
return err
}
err := p.Init()
return err
}
func (p *Pricing) addRawPrice(price *model.Price) error {
if _, ok := p.Prices[price.Model]; ok {
return errors.New("model already exists")
}
return price.Insert()
}
// AddPrice adds a new price to the Pricing instance
func (p *Pricing) AddPrice(price *model.Price) error {
if err := p.addRawPrice(price); err != nil {
return err
}
err := p.Init()
return err
}
func (p *Pricing) deleteRawPrice(modelName string) error {
item, ok := p.Prices[modelName]
if !ok {
return errors.New("model not found")
}
return item.Delete()
}
// DeletePrice deletes a price from the Pricing instance
func (p *Pricing) DeletePrice(modelName string) error {
if err := p.deleteRawPrice(modelName); err != nil {
return err
}
err := p.Init()
return err
}
// SyncPricing syncs the pricing data
func (p *Pricing) SyncPricing(pricing []*model.Price, overwrite bool) error {
var err error
if overwrite {
err = p.SyncPriceWithOverwrite(pricing)
} else {
err = p.SyncPriceWithoutOverwrite(pricing)
}
return err
}
// SyncPriceWithOverwrite syncs the pricing data with overwrite
func (p *Pricing) SyncPriceWithOverwrite(pricing []*model.Price) error {
tx := model.DB.Begin()
err := model.DeleteAllPrices(tx)
if err != nil {
tx.Rollback()
return err
}
err = model.InsertPrices(tx, pricing)
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
return p.Init()
}
// SyncPriceWithoutOverwrite syncs the pricing data without overwrite
func (p *Pricing) SyncPriceWithoutOverwrite(pricing []*model.Price) error {
var newPrices []*model.Price
for _, price := range pricing {
if _, ok := p.Prices[price.Model]; !ok {
newPrices = append(newPrices, price)
}
}
if len(newPrices) == 0 {
return nil
}
tx := model.DB.Begin()
err := model.InsertPrices(tx, newPrices)
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
return p.Init()
}
// BatchDeletePrices deletes the prices of multiple models
func (p *Pricing) BatchDeletePrices(models []string) error {
tx := model.DB.Begin()
err := model.DeletePrices(tx, models)
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
p.Lock()
defer p.Unlock()
for _, model := range models {
delete(p.Prices, model)
}
return nil
}
func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []string) error {
// 查找需要删除的model
var deletePrices []string
var addPrices []*model.Price
var updatePrices []string
for _, model := range originalModels {
if !common.Contains(model, batchPrices.Models) {
deletePrices = append(deletePrices, model)
} else {
updatePrices = append(updatePrices, model)
}
}
for _, model := range batchPrices.Models {
if !common.Contains(model, originalModels) {
addPrice := batchPrices.Price
addPrice.Model = model
addPrices = append(addPrices, &addPrice)
}
}
tx := model.DB.Begin()
if len(addPrices) > 0 {
err := model.InsertPrices(tx, addPrices)
if err != nil {
tx.Rollback()
return err
}
}
if len(updatePrices) > 0 {
err := model.UpdatePrices(tx, updatePrices, &batchPrices.Price)
if err != nil {
tx.Rollback()
return err
}
}
if len(deletePrices) > 0 {
err := model.DeletePrices(tx, deletePrices)
if err != nil {
tx.Rollback()
return err
}
}
tx.Commit()
return p.Init()
}
func GetPricesList(pricingType string) []*model.Price {
var prices []*model.Price
switch pricingType {
case "default":
prices = model.GetDefaultPrice()
case "db":
prices = PricingInstance.GetAllPricesList()
case "old":
prices = GetOldPricesList()
default:
return nil
}
sort.Slice(prices, func(i, j int) bool {
if prices[i].ChannelType == prices[j].ChannelType {
return prices[i].Model < prices[j].Model
}
return prices[i].ChannelType < prices[j].ChannelType
})
return prices
}
func GetOldPricesList() []*model.Price {
oldDataJson, err := model.GetOption("ModelRatio")
if err != nil || oldDataJson.Value == "" {
return nil
}
oldData := make(map[string][]float64)
err = json.Unmarshal([]byte(oldDataJson.Value), &oldData)
if err != nil {
return nil
}
var prices []*model.Price
for modelName, oldPrice := range oldData {
price := PricingInstance.GetPrice(modelName)
prices = append(prices, &model.Price{
Model: modelName,
Type: model.TokensPriceType,
ChannelType: price.ChannelType,
Input: oldPrice[0],
Output: oldPrice[1],
})
}
return prices
}
// func ConvertBatchPrices(prices []*model.Price) []*BatchPrices {
// batchPricesMap := make(map[string]*BatchPrices)
// for _, price := range prices {
// key := fmt.Sprintf("%s-%d-%g-%g", price.Type, price.ChannelType, price.Input, price.Output)
// batchPrice, exists := batchPricesMap[key]
// if exists {
// batchPrice.Models = append(batchPrice.Models, price.Model)
// } else {
// batchPricesMap[key] = &BatchPrices{
// Models: []string{price.Model},
// Price: *price,
// }
// }
// }
// var batchPrices []*BatchPrices
// for _, batchPrice := range batchPricesMap {
// batchPrices = append(batchPrices, batchPrice)
// }
// return batchPrices
// }

View File

@@ -15,17 +15,16 @@ import (
)
type Quota struct {
modelName string
promptTokens int
preConsumedTokens int
modelRatio []float64
groupRatio float64
ratio float64
preConsumedQuota int
userId int
channelId int
tokenId int
HandelStatus bool
modelName string
promptTokens int
price model.Price
groupRatio float64
inputRatio float64
preConsumedQuota int
userId int
channelId int
tokenId int
HandelStatus bool
}
func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *types.OpenAIErrorWithStatusCode) {
@@ -37,7 +36,16 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type
tokenId: c.GetInt("token_id"),
HandelStatus: false,
}
quota.init(c.GetString("group"))
quota.price = *PricingInstance.GetPrice(quota.modelName)
quota.groupRatio = common.GetGroupRatio(c.GetString("group"))
quota.inputRatio = quota.price.GetInput() * quota.groupRatio
if quota.price.Type == model.TimesPriceType {
quota.preConsumedQuota = int(1000 * quota.inputRatio)
} else {
quota.preConsumedQuota = int(float64(quota.promptTokens+common.PreConsumedQuota) * quota.inputRatio)
}
errWithCode := quota.preQuotaConsumption()
if errWithCode != nil {
@@ -47,21 +55,6 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type
return quota, nil
}
func (q *Quota) init(groupName string) {
modelRatio := common.GetModelRatio(q.modelName)
groupRatio := common.GetGroupRatio(groupName)
preConsumedTokens := common.PreConsumedQuota
ratio := modelRatio[0] * groupRatio
preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
q.preConsumedTokens = preConsumedTokens
q.modelRatio = modelRatio
q.groupRatio = groupRatio
q.ratio = ratio
q.preConsumedQuota = preConsumedQuota
}
func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
userQuota, err := model.CacheGetUserQuota(q.userId)
if err != nil {
@@ -97,11 +90,17 @@ func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
quota := 0
completionRatio := q.modelRatio[1] * q.groupRatio
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int(math.Ceil(((float64(promptTokens) * q.ratio) + (float64(completionTokens) * completionRatio))))
if q.ratio != 0 && quota <= 0 {
if q.price.Type == model.TimesPriceType {
quota = int(1000 * q.inputRatio)
} else {
completionRatio := q.price.GetOutput() * q.groupRatio
quota = int(math.Ceil(((float64(promptTokens) * q.inputRatio) + (float64(completionTokens) * completionRatio))))
}
if q.inputRatio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
@@ -129,13 +128,18 @@ func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string,
}
}
var modelRatioStr string
if q.modelRatio[0] == q.modelRatio[1] {
modelRatioStr = fmt.Sprintf("%.2f", q.modelRatio[0])
if q.price.Type == model.TimesPriceType {
modelRatioStr = fmt.Sprintf("$%g/次", q.price.FetchInputCurrencyPrice(model.DollarRate))
} else {
modelRatioStr = fmt.Sprintf("%.2f (输入)/%.2f (输出)", q.modelRatio[0], q.modelRatio[1])
// 如果输入费率和输出费率一样,则只显示一个费率
if q.price.GetInput() == q.price.GetOutput() {
modelRatioStr = fmt.Sprintf("$%g/1k", q.price.FetchInputCurrencyPrice(model.DollarRate))
} else {
modelRatioStr = fmt.Sprintf("$%g/1k (输入) | $%g/1k (输出)", q.price.FetchInputCurrencyPrice(model.DollarRate), q.price.FetchOutputCurrencyPrice(model.DollarRate))
}
}
logContent := fmt.Sprintf("模型率 %s分组倍率 %.2f", modelRatioStr, q.groupRatio)
logContent := fmt.Sprintf("模型率 %s分组倍率 %.2f", modelRatioStr, q.groupRatio)
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
model.UpdateChannelUsedQuota(q.channelId, quota)

28
relay/util/type.go Normal file
View File

@@ -0,0 +1,28 @@
package util
import "one-api/common"
var UnknownOwnedBy = "未知"
var ModelOwnedBy map[int]string
func init() {
ModelOwnedBy = map[int]string{
common.ChannelTypeOpenAI: "OpenAI",
common.ChannelTypeAnthropic: "Anthropic",
common.ChannelTypeBaidu: "Baidu",
common.ChannelTypePaLM: "Google PaLM",
common.ChannelTypeGemini: "Google Gemini",
common.ChannelTypeZhipu: "Zhipu",
common.ChannelTypeAli: "Ali",
common.ChannelTypeXunfei: "Xunfei",
common.ChannelType360: "360",
common.ChannelTypeTencent: "Tencent",
common.ChannelTypeBaichuan: "Baichuan",
common.ChannelTypeMiniMax: "MiniMax",
common.ChannelTypeDeepseek: "Deepseek",
common.ChannelTypeMoonshot: "Moonshot",
common.ChannelTypeMistral: "Mistral",
common.ChannelTypeGroq: "Groq",
common.ChannelTypeLingyi: "Lingyiwanwu",
}
}