mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-13 11:53:42 +08:00
refactor: add GetRatio to Adaptor
This commit is contained in:
@@ -3,6 +3,7 @@ package billing
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
)
|
||||
@@ -19,20 +20,22 @@ func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId
|
||||
}
|
||||
}
|
||||
|
||||
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string, promptTokens int, completionTokens int) {
|
||||
// quotaDelta is remaining quota to be consumed
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
||||
if quotaDelta != 0 {
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(ctx, userId)
|
||||
err := model.CacheUpdateUserQuota(ctx, userId)
|
||||
if err != nil {
|
||||
logger.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
// totalQuota is total quota consumed
|
||||
if totalQuota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, modelName, tokenName, totalQuota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
||||
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
||||
}
|
||||
|
||||
@@ -9,17 +9,35 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
USD2RMB = 7
|
||||
USD = 500 // $0.002 = 1 -> $1 = 500
|
||||
RMB = USD / USD2RMB
|
||||
USD2RMB = 7
|
||||
USD = 500 // $0.002 = 1 -> $1 = 500
|
||||
RMB = USD / USD2RMB // 1RMB = 1/7USD
|
||||
MILLI_USD = 1.0 / 1000 * USD
|
||||
MILLI_RMB = 1.0 / 1000 * RMB
|
||||
TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens
|
||||
)
|
||||
|
||||
type Ratio struct {
|
||||
Input float64 `json:"input,omitempty"` // input ratio
|
||||
Output float64 `json:"output,omitempty"` // output ratio
|
||||
LongThreshold int `json:"long_threshold,omitempty"` // for gemini like models, prompt longer than threshold will be charged as long input
|
||||
LongInput float64 `json:"long_input,omitempty"` // long input ratio
|
||||
LongOutput float64 `json:"long_output,omitempty"` // long output ratio
|
||||
}
|
||||
|
||||
var (
|
||||
FallbackRatio = Ratio{Input: 30, Output: 30}
|
||||
)
|
||||
|
||||
// Deprecated
|
||||
// TODO: remove this
|
||||
// ModelRatio
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
|
||||
// https://openai.com/pricing
|
||||
// 1 === $0.002 / 1K tokens
|
||||
// 1 === ¥0.014 / 1k tokens
|
||||
// 1 === $0.002 / 20 seconds (50 tokens per second)
|
||||
var ModelRatio = map[string]float64{
|
||||
// https://openai.com/pricing
|
||||
"gpt-4": 15,
|
||||
@@ -342,6 +360,7 @@ var CompletionRatio = map[string]float64{
|
||||
var (
|
||||
DefaultModelRatio map[string]float64
|
||||
DefaultCompletionRatio map[string]float64
|
||||
DefaultRatio = make(map[string]Ratio)
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -536,3 +555,30 @@ func GetCompletionRatio(name string, channelType int) float64 {
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func Ratio2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(DefaultRatio)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateRatioByJSONString(jsonStr string) error {
|
||||
DefaultRatio = make(map[string]Ratio)
|
||||
return json.Unmarshal([]byte(jsonStr), &DefaultRatio)
|
||||
}
|
||||
|
||||
func GetRatio(name string, channelType int) *Ratio {
|
||||
var result Ratio
|
||||
model := fmt.Sprintf("%s(%d)", name, channelType)
|
||||
if ratio, ok := DefaultRatio[model]; ok {
|
||||
result = ratio
|
||||
return &result
|
||||
}
|
||||
if ratio, ok := DefaultRatio[name]; ok {
|
||||
result = ratio
|
||||
return &result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user