mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-17 05:33:42 +08:00
🔖 chore: Rename relay/util to relay/relay_util package and add utils package
This commit is contained in:
138
relay/relay_util/cache.go
Normal file
138
relay/relay_util/cache.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package relay_util
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/common/utils"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ChatCacheProps struct {
|
||||
UserId int `json:"user_id"`
|
||||
TokenId int `json:"token_id"`
|
||||
ChannelID int `json:"channel_id"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
ModelName string `json:"model_name"`
|
||||
Response string `json:"response"`
|
||||
|
||||
Hash string `json:"-"`
|
||||
Cache bool `json:"-"`
|
||||
Driver CacheDriver `json:"-"`
|
||||
}
|
||||
|
||||
type CacheDriver interface {
|
||||
Get(hash string, userId int) *ChatCacheProps
|
||||
Set(hash string, props *ChatCacheProps, expire int64) error
|
||||
}
|
||||
|
||||
func GetDebugList(userId int) ([]*ChatCacheProps, error) {
|
||||
caches, err := model.GetChatCacheListByUserId(userId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var props []*ChatCacheProps
|
||||
for _, cache := range caches {
|
||||
prop, err := utils.UnmarshalString[ChatCacheProps](cache.Data)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
props = append(props, &prop)
|
||||
}
|
||||
|
||||
return props, nil
|
||||
}
|
||||
|
||||
func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps {
|
||||
props := &ChatCacheProps{
|
||||
Cache: false,
|
||||
}
|
||||
|
||||
if !allow {
|
||||
return props
|
||||
}
|
||||
|
||||
if common.ChatCacheEnabled && c.GetBool("chat_cache") {
|
||||
props.Cache = true
|
||||
}
|
||||
|
||||
if common.RedisEnabled {
|
||||
props.Driver = &ChatCacheRedis{}
|
||||
} else {
|
||||
props.Driver = &ChatCacheDB{}
|
||||
}
|
||||
|
||||
props.UserId = c.GetInt("id")
|
||||
props.TokenId = c.GetInt("token_id")
|
||||
|
||||
return props
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) SetHash(request any) {
|
||||
if !p.needCache() || request == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.hash(utils.Marshal(request))
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) SetResponse(response any) {
|
||||
if !p.needCache() || response == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if str, ok := response.(string); ok {
|
||||
p.Response += str
|
||||
return
|
||||
}
|
||||
|
||||
responseStr := utils.Marshal(response)
|
||||
if responseStr == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p.Response = responseStr
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) NoCache() {
|
||||
p.Cache = false
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens int, modelName string) error {
|
||||
if !p.needCache() || p.Response == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.ChannelID = channelId
|
||||
p.PromptTokens = promptTokens
|
||||
p.CompletionTokens = completionTokens
|
||||
p.ModelName = modelName
|
||||
|
||||
return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute))
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) GetCache() *ChatCacheProps {
|
||||
if !p.needCache() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return p.Driver.Get(p.getHash(), p.UserId)
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) needCache() bool {
|
||||
return common.ChatCacheEnabled && p.Cache
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) getHash() string {
|
||||
return p.Hash
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) hash(request string) {
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d-%d-%s", p.UserId, p.TokenId, request)))
|
||||
p.Hash = hex.EncodeToString(hash[:])
|
||||
}
|
||||
47
relay/relay_util/cache_db.go
Normal file
47
relay/relay_util/cache_db.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package relay_util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common/utils"
|
||||
"one-api/model"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChatCacheDB is the stateless chat cache backend that stores entries
// through the model package.
type ChatCacheDB struct{}
|
||||
|
||||
func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps {
|
||||
cache, _ := model.GetChatCache(hash, userId)
|
||||
if cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
props, err := utils.UnmarshalString[ChatCacheProps](cache.Data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &props
|
||||
}
|
||||
|
||||
func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) error {
|
||||
return SetCacheDB(hash, props, expire)
|
||||
}
|
||||
|
||||
func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error {
|
||||
data := utils.Marshal(props)
|
||||
if data == "" {
|
||||
return errors.New("marshal error")
|
||||
}
|
||||
|
||||
expire = expire * 60
|
||||
expire += time.Now().Unix()
|
||||
|
||||
cache := &model.ChatCache{
|
||||
Hash: hash,
|
||||
UserId: props.UserId,
|
||||
Data: data,
|
||||
Expiration: expire,
|
||||
}
|
||||
|
||||
return cache.Insert()
|
||||
}
|
||||
45
relay/relay_util/cache_redis.go
Normal file
45
relay/relay_util/cache_redis.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package relay_util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/common/utils"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChatCacheRedis is the stateless chat cache backend that stores entries
// in Redis.
type ChatCacheRedis struct{}

// chatCacheKey is the prefix of every Redis key written by ChatCacheRedis.
var chatCacheKey = "chat_cache"
|
||||
|
||||
func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps {
|
||||
cache, err := common.RedisGet(r.getKey(hash, userId))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
props, err := utils.UnmarshalString[ChatCacheProps](cache)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &props
|
||||
}
|
||||
|
||||
func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) error {
|
||||
|
||||
if !props.Cache {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := utils.Marshal(&props)
|
||||
if data == "" {
|
||||
return errors.New("marshal error")
|
||||
}
|
||||
|
||||
return common.RedisSet(r.getKey(hash, props.UserId), data, time.Duration(expire)*time.Minute)
|
||||
}
|
||||
|
||||
func (r *ChatCacheRedis) getKey(hash string, userId int) string {
|
||||
return fmt.Sprintf("%s:%d:%s", chatCacheKey, userId, hash)
|
||||
}
|
||||
403
relay/relay_util/pricing.go
Normal file
403
relay/relay_util/pricing.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package relay_util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/common/utils"
|
||||
"one-api/model"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// PricingInstance is the Pricing instance
|
||||
var PricingInstance *Pricing
|
||||
|
||||
// Pricing is a struct that contains the pricing data
|
||||
type Pricing struct {
|
||||
sync.RWMutex
|
||||
Prices map[string]*model.Price `json:"models"`
|
||||
Match []string `json:"-"`
|
||||
}
|
||||
|
||||
type BatchPrices struct {
|
||||
Models []string `json:"models" binding:"required"`
|
||||
Price model.Price `json:"price" binding:"required"`
|
||||
}
|
||||
|
||||
// NewPricing creates a new Pricing instance
|
||||
func NewPricing() {
|
||||
common.SysLog("Initializing Pricing")
|
||||
|
||||
PricingInstance = &Pricing{
|
||||
Prices: make(map[string]*model.Price),
|
||||
Match: make([]string, 0),
|
||||
}
|
||||
|
||||
err := PricingInstance.Init()
|
||||
|
||||
if err != nil {
|
||||
common.SysError("Failed to initialize Pricing:" + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化时,需要检测是否有更新
|
||||
if viper.GetBool("auto_price_updates") || len(PricingInstance.Prices) == 0 {
|
||||
common.SysLog("Checking for pricing updates")
|
||||
prices := model.GetDefaultPrice()
|
||||
PricingInstance.SyncPricing(prices, false)
|
||||
common.SysLog("Pricing initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// initializes the Pricing instance
|
||||
func (p *Pricing) Init() error {
|
||||
prices, err := model.GetAllPrices()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(prices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
newPrices := make(map[string]*model.Price)
|
||||
newMatch := make(map[string]bool)
|
||||
|
||||
for _, price := range prices {
|
||||
newPrices[price.Model] = price
|
||||
if strings.HasSuffix(price.Model, "*") {
|
||||
if _, ok := newMatch[price.Model]; !ok {
|
||||
newMatch[price.Model] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var newMatchList []string
|
||||
for match := range newMatch {
|
||||
newMatchList = append(newMatchList, match)
|
||||
}
|
||||
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
p.Prices = newPrices
|
||||
p.Match = newMatchList
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPrice returns the price of a model
|
||||
func (p *Pricing) GetPrice(modelName string) *model.Price {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
if price, ok := p.Prices[modelName]; ok {
|
||||
return price
|
||||
}
|
||||
|
||||
matchModel := utils.GetModelsWithMatch(&p.Match, modelName)
|
||||
if price, ok := p.Prices[matchModel]; ok {
|
||||
return price
|
||||
}
|
||||
|
||||
return &model.Price{
|
||||
Type: model.TokensPriceType,
|
||||
ChannelType: common.ChannelTypeUnknown,
|
||||
Input: model.DefaultPrice,
|
||||
Output: model.DefaultPrice,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pricing) GetAllPrices() map[string]*model.Price {
|
||||
return p.Prices
|
||||
}
|
||||
|
||||
func (p *Pricing) GetAllPricesList() []*model.Price {
|
||||
var prices []*model.Price
|
||||
for _, price := range p.Prices {
|
||||
prices = append(prices, price)
|
||||
}
|
||||
|
||||
return prices
|
||||
}
|
||||
|
||||
func (p *Pricing) updateRawPrice(modelName string, price *model.Price) error {
|
||||
if _, ok := p.Prices[modelName]; !ok {
|
||||
return errors.New("model not found")
|
||||
}
|
||||
|
||||
if _, ok := p.Prices[price.Model]; modelName != price.Model && ok {
|
||||
return errors.New("model names cannot be duplicated")
|
||||
}
|
||||
|
||||
if err := p.deleteRawPrice(modelName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return price.Insert()
|
||||
}
|
||||
|
||||
// UpdatePrice updates the price of a model
|
||||
func (p *Pricing) UpdatePrice(modelName string, price *model.Price) error {
|
||||
|
||||
if err := p.updateRawPrice(modelName, price); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := p.Init()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Pricing) addRawPrice(price *model.Price) error {
|
||||
if _, ok := p.Prices[price.Model]; ok {
|
||||
return errors.New("model already exists")
|
||||
}
|
||||
|
||||
return price.Insert()
|
||||
}
|
||||
|
||||
// AddPrice adds a new price to the Pricing instance
|
||||
func (p *Pricing) AddPrice(price *model.Price) error {
|
||||
if err := p.addRawPrice(price); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := p.Init()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Pricing) deleteRawPrice(modelName string) error {
|
||||
item, ok := p.Prices[modelName]
|
||||
if !ok {
|
||||
return errors.New("model not found")
|
||||
}
|
||||
|
||||
return item.Delete()
|
||||
}
|
||||
|
||||
// DeletePrice deletes a price from the Pricing instance
|
||||
func (p *Pricing) DeletePrice(modelName string) error {
|
||||
if err := p.deleteRawPrice(modelName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := p.Init()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SyncPricing syncs the pricing data
|
||||
func (p *Pricing) SyncPricing(pricing []*model.Price, overwrite bool) error {
|
||||
var err error
|
||||
if overwrite {
|
||||
err = p.SyncPriceWithOverwrite(pricing)
|
||||
} else {
|
||||
err = p.SyncPriceWithoutOverwrite(pricing)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SyncPriceWithOverwrite syncs the pricing data with overwrite
|
||||
func (p *Pricing) SyncPriceWithOverwrite(pricing []*model.Price) error {
|
||||
tx := model.DB.Begin()
|
||||
|
||||
err := model.DeleteAllPrices(tx)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
err = model.InsertPrices(tx, pricing)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return p.Init()
|
||||
}
|
||||
|
||||
// SyncPriceWithoutOverwrite syncs the pricing data without overwrite
|
||||
func (p *Pricing) SyncPriceWithoutOverwrite(pricing []*model.Price) error {
|
||||
var newPrices []*model.Price
|
||||
|
||||
for _, price := range pricing {
|
||||
if _, ok := p.Prices[price.Model]; !ok {
|
||||
newPrices = append(newPrices, price)
|
||||
}
|
||||
}
|
||||
|
||||
if len(newPrices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx := model.DB.Begin()
|
||||
err := model.InsertPrices(tx, newPrices)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return p.Init()
|
||||
}
|
||||
|
||||
// BatchDeletePrices deletes the prices of multiple models
|
||||
func (p *Pricing) BatchDeletePrices(models []string) error {
|
||||
tx := model.DB.Begin()
|
||||
|
||||
err := model.DeletePrices(tx, models)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
for _, model := range models {
|
||||
delete(p.Prices, model)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []string) error {
|
||||
// 查找需要删除的model
|
||||
var deletePrices []string
|
||||
var addPrices []*model.Price
|
||||
var updatePrices []string
|
||||
|
||||
for _, model := range originalModels {
|
||||
if !utils.Contains(model, batchPrices.Models) {
|
||||
deletePrices = append(deletePrices, model)
|
||||
} else {
|
||||
updatePrices = append(updatePrices, model)
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range batchPrices.Models {
|
||||
if !utils.Contains(model, originalModels) {
|
||||
addPrice := batchPrices.Price
|
||||
addPrice.Model = model
|
||||
addPrices = append(addPrices, &addPrice)
|
||||
}
|
||||
}
|
||||
|
||||
tx := model.DB.Begin()
|
||||
if len(addPrices) > 0 {
|
||||
err := model.InsertPrices(tx, addPrices)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(updatePrices) > 0 {
|
||||
err := model.UpdatePrices(tx, updatePrices, &batchPrices.Price)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(deletePrices) > 0 {
|
||||
err := model.DeletePrices(tx, deletePrices)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
return p.Init()
|
||||
}
|
||||
|
||||
func GetPricesList(pricingType string) []*model.Price {
|
||||
var prices []*model.Price
|
||||
|
||||
switch pricingType {
|
||||
case "default":
|
||||
prices = model.GetDefaultPrice()
|
||||
case "db":
|
||||
prices = PricingInstance.GetAllPricesList()
|
||||
case "old":
|
||||
prices = GetOldPricesList()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
sort.Slice(prices, func(i, j int) bool {
|
||||
if prices[i].ChannelType == prices[j].ChannelType {
|
||||
return prices[i].Model < prices[j].Model
|
||||
}
|
||||
return prices[i].ChannelType < prices[j].ChannelType
|
||||
})
|
||||
|
||||
return prices
|
||||
}
|
||||
|
||||
func GetOldPricesList() []*model.Price {
|
||||
oldDataJson, err := model.GetOption("ModelRatio")
|
||||
if err != nil || oldDataJson.Value == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
oldData := make(map[string][]float64)
|
||||
err = json.Unmarshal([]byte(oldDataJson.Value), &oldData)
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var prices []*model.Price
|
||||
for modelName, oldPrice := range oldData {
|
||||
price := PricingInstance.GetPrice(modelName)
|
||||
prices = append(prices, &model.Price{
|
||||
Model: modelName,
|
||||
Type: model.TokensPriceType,
|
||||
ChannelType: price.ChannelType,
|
||||
Input: oldPrice[0],
|
||||
Output: oldPrice[1],
|
||||
})
|
||||
}
|
||||
|
||||
return prices
|
||||
}
|
||||
|
||||
// func ConvertBatchPrices(prices []*model.Price) []*BatchPrices {
|
||||
// batchPricesMap := make(map[string]*BatchPrices)
|
||||
// for _, price := range prices {
|
||||
// key := fmt.Sprintf("%s-%d-%g-%g", price.Type, price.ChannelType, price.Input, price.Output)
|
||||
// batchPrice, exists := batchPricesMap[key]
|
||||
// if exists {
|
||||
// batchPrice.Models = append(batchPrice.Models, price.Model)
|
||||
// } else {
|
||||
// batchPricesMap[key] = &BatchPrices{
|
||||
// Models: []string{price.Model},
|
||||
// Price: *price,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// var batchPrices []*BatchPrices
|
||||
// for _, batchPrice := range batchPricesMap {
|
||||
// batchPrices = append(batchPrices, batchPrice)
|
||||
// }
|
||||
|
||||
// return batchPrices
|
||||
// }
|
||||
176
relay/relay_util/quota.go
Normal file
176
relay/relay_util/quota.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package relay_util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Quota struct {
|
||||
modelName string
|
||||
promptTokens int
|
||||
price model.Price
|
||||
groupRatio float64
|
||||
inputRatio float64
|
||||
preConsumedQuota int
|
||||
userId int
|
||||
channelId int
|
||||
tokenId int
|
||||
HandelStatus bool
|
||||
}
|
||||
|
||||
func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *types.OpenAIErrorWithStatusCode) {
|
||||
quota := &Quota{
|
||||
modelName: modelName,
|
||||
promptTokens: promptTokens,
|
||||
userId: c.GetInt("id"),
|
||||
channelId: c.GetInt("channel_id"),
|
||||
tokenId: c.GetInt("token_id"),
|
||||
HandelStatus: false,
|
||||
}
|
||||
|
||||
quota.price = *PricingInstance.GetPrice(quota.modelName)
|
||||
quota.groupRatio = common.GetGroupRatio(c.GetString("group"))
|
||||
quota.inputRatio = quota.price.GetInput() * quota.groupRatio
|
||||
|
||||
if quota.price.Type == model.TimesPriceType {
|
||||
quota.preConsumedQuota = int(1000 * quota.inputRatio)
|
||||
} else {
|
||||
quota.preConsumedQuota = int(float64(quota.promptTokens+common.PreConsumedQuota) * quota.inputRatio)
|
||||
}
|
||||
|
||||
errWithCode := quota.preQuotaConsumption()
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return quota, nil
|
||||
}
|
||||
|
||||
func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
|
||||
userQuota, err := model.CacheGetUserQuota(q.userId)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if userQuota < q.preConsumedQuota {
|
||||
return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
|
||||
err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if userQuota > 100*q.preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
q.preConsumedQuota = 0
|
||||
// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||
}
|
||||
|
||||
if q.preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
q.HandelStatus = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
|
||||
quota := 0
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
|
||||
if q.price.Type == model.TimesPriceType {
|
||||
quota = int(1000 * q.inputRatio)
|
||||
} else {
|
||||
completionRatio := q.price.GetOutput() * q.groupRatio
|
||||
quota = int(math.Ceil(((float64(promptTokens) * q.inputRatio) + (float64(completionTokens) * completionRatio))))
|
||||
}
|
||||
|
||||
if q.inputRatio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
}
|
||||
quotaDelta := quota - q.preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
return errors.New("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(q.userId)
|
||||
if err != nil {
|
||||
return errors.New("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
|
||||
requestTime := 0
|
||||
requestStartTimeValue := ctx.Value("requestStartTime")
|
||||
if requestStartTimeValue != nil {
|
||||
requestStartTime, ok := requestStartTimeValue.(time.Time)
|
||||
if ok {
|
||||
requestTime = int(time.Since(requestStartTime).Milliseconds())
|
||||
}
|
||||
}
|
||||
var modelRatioStr string
|
||||
if q.price.Type == model.TimesPriceType {
|
||||
modelRatioStr = fmt.Sprintf("$%s/次", q.price.FetchInputCurrencyPrice(model.DollarRate))
|
||||
} else {
|
||||
// 如果输入费率和输出费率一样,则只显示一个费率
|
||||
if q.price.GetInput() == q.price.GetOutput() {
|
||||
modelRatioStr = fmt.Sprintf("$%s/1k", q.price.FetchInputCurrencyPrice(model.DollarRate))
|
||||
} else {
|
||||
modelRatioStr = fmt.Sprintf("$%s/1k (输入) | $%s/1k (输出)", q.price.FetchInputCurrencyPrice(model.DollarRate), q.price.FetchOutputCurrencyPrice(model.DollarRate))
|
||||
}
|
||||
}
|
||||
|
||||
logContent := fmt.Sprintf("模型费率 %s,分组倍率 %.2f", modelRatioStr, q.groupRatio)
|
||||
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
|
||||
model.UpdateChannelUsedQuota(q.channelId, quota)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Quota) Undo(c *gin.Context) {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if q.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -q.preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Quota) Consume(c *gin.Context, usage *types.Usage) {
|
||||
tokenName := c.GetString("token_name")
|
||||
// 如果没有报错,则消费配额
|
||||
go func(ctx context.Context) {
|
||||
err := q.completedQuotaConsumption(usage, tokenName, ctx)
|
||||
if err != nil {
|
||||
common.LogError(ctx, err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
|
||||
func (q *Quota) GetInputRatio() float64 {
|
||||
return q.inputRatio
|
||||
}
|
||||
35
relay/relay_util/type.go
Normal file
35
relay/relay_util/type.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package relay_util
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
var UnknownOwnedBy = "未知"
|
||||
var ModelOwnedBy map[int]string
|
||||
|
||||
func init() {
|
||||
ModelOwnedBy = map[int]string{
|
||||
common.ChannelTypeOpenAI: "OpenAI",
|
||||
common.ChannelTypeAnthropic: "Anthropic",
|
||||
common.ChannelTypeBaidu: "Baidu",
|
||||
common.ChannelTypePaLM: "Google PaLM",
|
||||
common.ChannelTypeGemini: "Google Gemini",
|
||||
common.ChannelTypeZhipu: "Zhipu",
|
||||
common.ChannelTypeAli: "Ali",
|
||||
common.ChannelTypeXunfei: "Xunfei",
|
||||
common.ChannelType360: "360",
|
||||
common.ChannelTypeTencent: "Tencent",
|
||||
common.ChannelTypeBaichuan: "Baichuan",
|
||||
common.ChannelTypeMiniMax: "MiniMax",
|
||||
common.ChannelTypeDeepseek: "Deepseek",
|
||||
common.ChannelTypeMoonshot: "Moonshot",
|
||||
common.ChannelTypeMistral: "Mistral",
|
||||
common.ChannelTypeGroq: "Groq",
|
||||
common.ChannelTypeLingyi: "Lingyiwanwu",
|
||||
common.ChannelTypeMidjourney: "Midjourney",
|
||||
common.ChannelTypeCloudflareAI: "Cloudflare AI",
|
||||
common.ChannelTypeCohere: "Cohere",
|
||||
common.ChannelTypeStabilityAI: "Stability AI",
|
||||
common.ChannelTypeCoze: "Coze",
|
||||
common.ChannelTypeOllama: "Ollama",
|
||||
common.ChannelTypeHunyuan: "Hunyuan",
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user