mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-16 21:23:44 +08:00
402
relay/util/pricing.go
Normal file
402
relay/util/pricing.go
Normal file
@@ -0,0 +1,402 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// PricingInstance is the Pricing instance
|
||||
var PricingInstance *Pricing
|
||||
|
||||
// Pricing is a struct that contains the pricing data
|
||||
type Pricing struct {
|
||||
sync.RWMutex
|
||||
Prices map[string]*model.Price `json:"models"`
|
||||
Match []string `json:"-"`
|
||||
}
|
||||
|
||||
type BatchPrices struct {
|
||||
Models []string `json:"models" binding:"required"`
|
||||
Price model.Price `json:"price" binding:"required"`
|
||||
}
|
||||
|
||||
// NewPricing creates a new Pricing instance
|
||||
func NewPricing() {
|
||||
common.SysLog("Initializing Pricing")
|
||||
|
||||
PricingInstance = &Pricing{
|
||||
Prices: make(map[string]*model.Price),
|
||||
Match: make([]string, 0),
|
||||
}
|
||||
|
||||
err := PricingInstance.Init()
|
||||
|
||||
if err != nil {
|
||||
common.SysError("Failed to initialize Pricing:" + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化时,需要检测是否有更新
|
||||
if viper.GetBool("auto_price_updates") {
|
||||
common.SysLog("Checking for pricing updates")
|
||||
prices := model.GetDefaultPrice()
|
||||
PricingInstance.SyncPricing(prices, false)
|
||||
common.SysLog("Pricing initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// initializes the Pricing instance
|
||||
func (p *Pricing) Init() error {
|
||||
prices, err := model.GetAllPrices()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(prices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
newPrices := make(map[string]*model.Price)
|
||||
newMatch := make(map[string]bool)
|
||||
|
||||
for _, price := range prices {
|
||||
newPrices[price.Model] = price
|
||||
if strings.HasSuffix(price.Model, "*") {
|
||||
if _, ok := newMatch[price.Model]; !ok {
|
||||
newMatch[price.Model] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var newMatchList []string
|
||||
for match := range newMatch {
|
||||
newMatchList = append(newMatchList, match)
|
||||
}
|
||||
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
p.Prices = newPrices
|
||||
p.Match = newMatchList
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPrice returns the price of a model
|
||||
func (p *Pricing) GetPrice(modelName string) *model.Price {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
if price, ok := p.Prices[modelName]; ok {
|
||||
return price
|
||||
}
|
||||
|
||||
matchModel := common.GetModelsWithMatch(&p.Match, modelName)
|
||||
if price, ok := p.Prices[matchModel]; ok {
|
||||
return price
|
||||
}
|
||||
|
||||
return &model.Price{
|
||||
Type: model.TokensPriceType,
|
||||
ChannelType: common.ChannelTypeUnknown,
|
||||
Input: model.DefaultPrice,
|
||||
Output: model.DefaultPrice,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pricing) GetAllPrices() map[string]*model.Price {
|
||||
return p.Prices
|
||||
}
|
||||
|
||||
func (p *Pricing) GetAllPricesList() []*model.Price {
|
||||
var prices []*model.Price
|
||||
for _, price := range p.Prices {
|
||||
prices = append(prices, price)
|
||||
}
|
||||
|
||||
return prices
|
||||
}
|
||||
|
||||
func (p *Pricing) updateRawPrice(modelName string, price *model.Price) error {
|
||||
if _, ok := p.Prices[modelName]; !ok {
|
||||
return errors.New("model not found")
|
||||
}
|
||||
|
||||
if _, ok := p.Prices[price.Model]; modelName != price.Model && ok {
|
||||
return errors.New("model names cannot be duplicated")
|
||||
}
|
||||
|
||||
if err := p.deleteRawPrice(modelName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return price.Insert()
|
||||
}
|
||||
|
||||
// UpdatePrice updates the price of a model
|
||||
func (p *Pricing) UpdatePrice(modelName string, price *model.Price) error {
|
||||
|
||||
if err := p.updateRawPrice(modelName, price); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := p.Init()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Pricing) addRawPrice(price *model.Price) error {
|
||||
if _, ok := p.Prices[price.Model]; ok {
|
||||
return errors.New("model already exists")
|
||||
}
|
||||
|
||||
return price.Insert()
|
||||
}
|
||||
|
||||
// AddPrice adds a new price to the Pricing instance
|
||||
func (p *Pricing) AddPrice(price *model.Price) error {
|
||||
if err := p.addRawPrice(price); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := p.Init()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Pricing) deleteRawPrice(modelName string) error {
|
||||
item, ok := p.Prices[modelName]
|
||||
if !ok {
|
||||
return errors.New("model not found")
|
||||
}
|
||||
|
||||
return item.Delete()
|
||||
}
|
||||
|
||||
// DeletePrice deletes a price from the Pricing instance
|
||||
func (p *Pricing) DeletePrice(modelName string) error {
|
||||
if err := p.deleteRawPrice(modelName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := p.Init()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SyncPricing syncs the pricing data
|
||||
func (p *Pricing) SyncPricing(pricing []*model.Price, overwrite bool) error {
|
||||
var err error
|
||||
if overwrite {
|
||||
err = p.SyncPriceWithOverwrite(pricing)
|
||||
} else {
|
||||
err = p.SyncPriceWithoutOverwrite(pricing)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SyncPriceWithOverwrite syncs the pricing data with overwrite
|
||||
func (p *Pricing) SyncPriceWithOverwrite(pricing []*model.Price) error {
|
||||
tx := model.DB.Begin()
|
||||
|
||||
err := model.DeleteAllPrices(tx)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
err = model.InsertPrices(tx, pricing)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return p.Init()
|
||||
}
|
||||
|
||||
// SyncPriceWithoutOverwrite syncs the pricing data without overwrite
|
||||
func (p *Pricing) SyncPriceWithoutOverwrite(pricing []*model.Price) error {
|
||||
var newPrices []*model.Price
|
||||
|
||||
for _, price := range pricing {
|
||||
if _, ok := p.Prices[price.Model]; !ok {
|
||||
newPrices = append(newPrices, price)
|
||||
}
|
||||
}
|
||||
|
||||
if len(newPrices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx := model.DB.Begin()
|
||||
err := model.InsertPrices(tx, newPrices)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return p.Init()
|
||||
}
|
||||
|
||||
// BatchDeletePrices deletes the prices of multiple models
|
||||
func (p *Pricing) BatchDeletePrices(models []string) error {
|
||||
tx := model.DB.Begin()
|
||||
|
||||
err := model.DeletePrices(tx, models)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
for _, model := range models {
|
||||
delete(p.Prices, model)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []string) error {
|
||||
// 查找需要删除的model
|
||||
var deletePrices []string
|
||||
var addPrices []*model.Price
|
||||
var updatePrices []string
|
||||
|
||||
for _, model := range originalModels {
|
||||
if !common.Contains(model, batchPrices.Models) {
|
||||
deletePrices = append(deletePrices, model)
|
||||
} else {
|
||||
updatePrices = append(updatePrices, model)
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range batchPrices.Models {
|
||||
if !common.Contains(model, originalModels) {
|
||||
addPrice := batchPrices.Price
|
||||
addPrice.Model = model
|
||||
addPrices = append(addPrices, &addPrice)
|
||||
}
|
||||
}
|
||||
|
||||
tx := model.DB.Begin()
|
||||
if len(addPrices) > 0 {
|
||||
err := model.InsertPrices(tx, addPrices)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(updatePrices) > 0 {
|
||||
err := model.UpdatePrices(tx, updatePrices, &batchPrices.Price)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(deletePrices) > 0 {
|
||||
err := model.DeletePrices(tx, deletePrices)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
return p.Init()
|
||||
}
|
||||
|
||||
func GetPricesList(pricingType string) []*model.Price {
|
||||
var prices []*model.Price
|
||||
|
||||
switch pricingType {
|
||||
case "default":
|
||||
prices = model.GetDefaultPrice()
|
||||
case "db":
|
||||
prices = PricingInstance.GetAllPricesList()
|
||||
case "old":
|
||||
prices = GetOldPricesList()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
sort.Slice(prices, func(i, j int) bool {
|
||||
if prices[i].ChannelType == prices[j].ChannelType {
|
||||
return prices[i].Model < prices[j].Model
|
||||
}
|
||||
return prices[i].ChannelType < prices[j].ChannelType
|
||||
})
|
||||
|
||||
return prices
|
||||
}
|
||||
|
||||
func GetOldPricesList() []*model.Price {
|
||||
oldDataJson, err := model.GetOption("ModelRatio")
|
||||
if err != nil || oldDataJson.Value == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
oldData := make(map[string][]float64)
|
||||
err = json.Unmarshal([]byte(oldDataJson.Value), &oldData)
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var prices []*model.Price
|
||||
for modelName, oldPrice := range oldData {
|
||||
price := PricingInstance.GetPrice(modelName)
|
||||
prices = append(prices, &model.Price{
|
||||
Model: modelName,
|
||||
Type: model.TokensPriceType,
|
||||
ChannelType: price.ChannelType,
|
||||
Input: oldPrice[0],
|
||||
Output: oldPrice[1],
|
||||
})
|
||||
}
|
||||
|
||||
return prices
|
||||
}
|
||||
|
||||
// func ConvertBatchPrices(prices []*model.Price) []*BatchPrices {
|
||||
// batchPricesMap := make(map[string]*BatchPrices)
|
||||
// for _, price := range prices {
|
||||
// key := fmt.Sprintf("%s-%d-%g-%g", price.Type, price.ChannelType, price.Input, price.Output)
|
||||
// batchPrice, exists := batchPricesMap[key]
|
||||
// if exists {
|
||||
// batchPrice.Models = append(batchPrice.Models, price.Model)
|
||||
// } else {
|
||||
// batchPricesMap[key] = &BatchPrices{
|
||||
// Models: []string{price.Model},
|
||||
// Price: *price,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// var batchPrices []*BatchPrices
|
||||
// for _, batchPrice := range batchPricesMap {
|
||||
// batchPrices = append(batchPrices, batchPrice)
|
||||
// }
|
||||
|
||||
// return batchPrices
|
||||
// }
|
||||
@@ -15,17 +15,16 @@ import (
|
||||
)
|
||||
|
||||
type Quota struct {
|
||||
modelName string
|
||||
promptTokens int
|
||||
preConsumedTokens int
|
||||
modelRatio []float64
|
||||
groupRatio float64
|
||||
ratio float64
|
||||
preConsumedQuota int
|
||||
userId int
|
||||
channelId int
|
||||
tokenId int
|
||||
HandelStatus bool
|
||||
modelName string
|
||||
promptTokens int
|
||||
price model.Price
|
||||
groupRatio float64
|
||||
inputRatio float64
|
||||
preConsumedQuota int
|
||||
userId int
|
||||
channelId int
|
||||
tokenId int
|
||||
HandelStatus bool
|
||||
}
|
||||
|
||||
func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *types.OpenAIErrorWithStatusCode) {
|
||||
@@ -37,7 +36,16 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type
|
||||
tokenId: c.GetInt("token_id"),
|
||||
HandelStatus: false,
|
||||
}
|
||||
quota.init(c.GetString("group"))
|
||||
|
||||
quota.price = *PricingInstance.GetPrice(quota.modelName)
|
||||
quota.groupRatio = common.GetGroupRatio(c.GetString("group"))
|
||||
quota.inputRatio = quota.price.GetInput() * quota.groupRatio
|
||||
|
||||
if quota.price.Type == model.TimesPriceType {
|
||||
quota.preConsumedQuota = int(1000 * quota.inputRatio)
|
||||
} else {
|
||||
quota.preConsumedQuota = int(float64(quota.promptTokens+common.PreConsumedQuota) * quota.inputRatio)
|
||||
}
|
||||
|
||||
errWithCode := quota.preQuotaConsumption()
|
||||
if errWithCode != nil {
|
||||
@@ -47,21 +55,6 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type
|
||||
return quota, nil
|
||||
}
|
||||
|
||||
func (q *Quota) init(groupName string) {
|
||||
modelRatio := common.GetModelRatio(q.modelName)
|
||||
groupRatio := common.GetGroupRatio(groupName)
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
ratio := modelRatio[0] * groupRatio
|
||||
preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
|
||||
|
||||
q.preConsumedTokens = preConsumedTokens
|
||||
q.modelRatio = modelRatio
|
||||
q.groupRatio = groupRatio
|
||||
q.ratio = ratio
|
||||
q.preConsumedQuota = preConsumedQuota
|
||||
|
||||
}
|
||||
|
||||
func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
|
||||
userQuota, err := model.CacheGetUserQuota(q.userId)
|
||||
if err != nil {
|
||||
@@ -97,11 +90,17 @@ func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
|
||||
|
||||
func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
|
||||
quota := 0
|
||||
completionRatio := q.modelRatio[1] * q.groupRatio
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
quota = int(math.Ceil(((float64(promptTokens) * q.ratio) + (float64(completionTokens) * completionRatio))))
|
||||
if q.ratio != 0 && quota <= 0 {
|
||||
|
||||
if q.price.Type == model.TimesPriceType {
|
||||
quota = int(1000 * q.inputRatio)
|
||||
} else {
|
||||
completionRatio := q.price.GetOutput() * q.groupRatio
|
||||
quota = int(math.Ceil(((float64(promptTokens) * q.inputRatio) + (float64(completionTokens) * completionRatio))))
|
||||
}
|
||||
|
||||
if q.inputRatio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
@@ -129,13 +128,18 @@ func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string,
|
||||
}
|
||||
}
|
||||
var modelRatioStr string
|
||||
if q.modelRatio[0] == q.modelRatio[1] {
|
||||
modelRatioStr = fmt.Sprintf("%.2f", q.modelRatio[0])
|
||||
if q.price.Type == model.TimesPriceType {
|
||||
modelRatioStr = fmt.Sprintf("$%g/次", q.price.FetchInputCurrencyPrice(model.DollarRate))
|
||||
} else {
|
||||
modelRatioStr = fmt.Sprintf("%.2f (输入)/%.2f (输出)", q.modelRatio[0], q.modelRatio[1])
|
||||
// 如果输入费率和输出费率一样,则只显示一个费率
|
||||
if q.price.GetInput() == q.price.GetOutput() {
|
||||
modelRatioStr = fmt.Sprintf("$%g/1k", q.price.FetchInputCurrencyPrice(model.DollarRate))
|
||||
} else {
|
||||
modelRatioStr = fmt.Sprintf("$%g/1k (输入) | $%g/1k (输出)", q.price.FetchInputCurrencyPrice(model.DollarRate), q.price.FetchOutputCurrencyPrice(model.DollarRate))
|
||||
}
|
||||
}
|
||||
|
||||
logContent := fmt.Sprintf("模型倍率 %s,分组倍率 %.2f", modelRatioStr, q.groupRatio)
|
||||
logContent := fmt.Sprintf("模型费率 %s,分组倍率 %.2f", modelRatioStr, q.groupRatio)
|
||||
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
|
||||
model.UpdateChannelUsedQuota(q.channelId, quota)
|
||||
|
||||
28
relay/util/type.go
Normal file
28
relay/util/type.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package util
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
var UnknownOwnedBy = "未知"
|
||||
var ModelOwnedBy map[int]string
|
||||
|
||||
func init() {
|
||||
ModelOwnedBy = map[int]string{
|
||||
common.ChannelTypeOpenAI: "OpenAI",
|
||||
common.ChannelTypeAnthropic: "Anthropic",
|
||||
common.ChannelTypeBaidu: "Baidu",
|
||||
common.ChannelTypePaLM: "Google PaLM",
|
||||
common.ChannelTypeGemini: "Google Gemini",
|
||||
common.ChannelTypeZhipu: "Zhipu",
|
||||
common.ChannelTypeAli: "Ali",
|
||||
common.ChannelTypeXunfei: "Xunfei",
|
||||
common.ChannelType360: "360",
|
||||
common.ChannelTypeTencent: "Tencent",
|
||||
common.ChannelTypeBaichuan: "Baichuan",
|
||||
common.ChannelTypeMiniMax: "MiniMax",
|
||||
common.ChannelTypeDeepseek: "Deepseek",
|
||||
common.ChannelTypeMoonshot: "Moonshot",
|
||||
common.ChannelTypeMistral: "Mistral",
|
||||
common.ChannelTypeGroq: "Groq",
|
||||
common.ChannelTypeLingyi: "Lingyiwanwu",
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user