🔖 chore: Rename relay/util to relay/relay_util package and add utils package

MartialBE
2024-05-29 00:36:54 +08:00
parent 853f2681f4
commit 79524108a3
61 changed files with 309 additions and 265 deletions
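For reference, a minimal sketch of a call site updated for this rename. The old import path one-api/relay/util and the controller package shown here are assumptions inferred from the commit title, not code taken from the diff:

package controller

import (
	// previously (assumed old path): "one-api/relay/util"
	"one-api/relay/relay_util"

	"github.com/gin-gonic/gin"
)

// newCacheProps shows a caller switched to the renamed package.
func newCacheProps(c *gin.Context) *relay_util.ChatCacheProps {
	return relay_util.NewChatCacheProps(c, true)
}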

138
relay/relay_util/cache.go Normal file

@@ -0,0 +1,138 @@
package relay_util
import (
"crypto/md5"
"encoding/hex"
"fmt"
"one-api/common"
"one-api/common/utils"
"one-api/model"
"github.com/gin-gonic/gin"
)
type ChatCacheProps struct {
UserId int `json:"user_id"`
TokenId int `json:"token_id"`
ChannelID int `json:"channel_id"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
ModelName string `json:"model_name"`
Response string `json:"response"`
Hash string `json:"-"`
Cache bool `json:"-"`
Driver CacheDriver `json:"-"`
}
type CacheDriver interface {
Get(hash string, userId int) *ChatCacheProps
Set(hash string, props *ChatCacheProps, expire int64) error
}
func GetDebugList(userId int) ([]*ChatCacheProps, error) {
caches, err := model.GetChatCacheListByUserId(userId)
if err != nil {
return nil, err
}
var props []*ChatCacheProps
for _, cache := range caches {
prop, err := utils.UnmarshalString[ChatCacheProps](cache.Data)
if err != nil {
continue
}
props = append(props, &prop)
}
return props, nil
}
func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps {
props := &ChatCacheProps{
Cache: false,
}
if !allow {
return props
}
if common.ChatCacheEnabled && c.GetBool("chat_cache") {
props.Cache = true
}
if common.RedisEnabled {
props.Driver = &ChatCacheRedis{}
} else {
props.Driver = &ChatCacheDB{}
}
props.UserId = c.GetInt("id")
props.TokenId = c.GetInt("token_id")
return props
}
func (p *ChatCacheProps) SetHash(request any) {
if !p.needCache() || request == nil {
return
}
p.hash(utils.Marshal(request))
}
func (p *ChatCacheProps) SetResponse(response any) {
if !p.needCache() || response == nil {
return
}
if str, ok := response.(string); ok {
p.Response += str
return
}
responseStr := utils.Marshal(response)
if responseStr == "" {
return
}
p.Response = responseStr
}
func (p *ChatCacheProps) NoCache() {
p.Cache = false
}
func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens int, modelName string) error {
if !p.needCache() || p.Response == "" {
return nil
}
p.ChannelID = channelId
p.PromptTokens = promptTokens
p.CompletionTokens = completionTokens
p.ModelName = modelName
return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute))
}
func (p *ChatCacheProps) GetCache() *ChatCacheProps {
if !p.needCache() {
return nil
}
return p.Driver.Get(p.getHash(), p.UserId)
}
func (p *ChatCacheProps) needCache() bool {
return common.ChatCacheEnabled && p.Cache
}
func (p *ChatCacheProps) getHash() string {
return p.Hash
}
func (p *ChatCacheProps) hash(request string) {
hash := md5.Sum([]byte(fmt.Sprintf("%d-%d-%s", p.UserId, p.TokenId, request)))
p.Hash = hex.EncodeToString(hash[:])
}
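To make the flow above concrete, here is a hedged sketch of how a relay handler might drive the cache; the handler, the upstream call, and the literal token counts and model name are illustrative assumptions, not code from this commit:

package relay

import (
	"net/http"

	"one-api/relay/relay_util"

	"github.com/gin-gonic/gin"
)

// handleWithCache: hash the request, serve a cached response when present,
// otherwise store the upstream response once billing numbers are known.
func handleWithCache(c *gin.Context, request any) {
	props := relay_util.NewChatCacheProps(c, true)
	props.SetHash(request)

	if cached := props.GetCache(); cached != nil {
		// Response holds the marshaled upstream response.
		c.Data(http.StatusOK, "application/json", []byte(cached.Response))
		return
	}

	response := callUpstream(request) // hypothetical upstream call
	props.SetResponse(response)

	// Channel id, token counts and model name would come from the billing
	// step; the values here are placeholders.
	_ = props.StoreCache(c.GetInt("channel_id"), 100, 200, "gpt-3.5-turbo")
}

func callUpstream(request any) any { return request } // placeholder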


@@ -0,0 +1,47 @@
package relay_util
import (
"errors"
"one-api/common/utils"
"one-api/model"
"time"
)
type ChatCacheDB struct{}
func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps {
cache, _ := model.GetChatCache(hash, userId)
if cache == nil {
return nil
}
props, err := utils.UnmarshalString[ChatCacheProps](cache.Data)
if err != nil {
return nil
}
return &props
}
func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) error {
return SetCacheDB(hash, props, expire)
}
func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error {
data := utils.Marshal(props)
if data == "" {
return errors.New("marshal error")
}
expire = expire * 60
expire += time.Now().Unix()
cache := &model.ChatCache{
Hash: hash,
UserId: props.UserId,
Data: data,
Expiration: expire,
}
return cache.Insert()
}


@@ -0,0 +1,45 @@
package relay_util
import (
"errors"
"fmt"
"one-api/common"
"one-api/common/utils"
"time"
)
type ChatCacheRedis struct{}
var chatCacheKey = "chat_cache"
func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps {
cache, err := common.RedisGet(r.getKey(hash, userId))
if err != nil {
return nil
}
props, err := utils.UnmarshalString[ChatCacheProps](cache)
if err != nil {
return nil
}
return &props
}
func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) error {
if !props.Cache {
return nil
}
data := utils.Marshal(&props)
if data == "" {
return errors.New("marshal error")
}
return common.RedisSet(r.getKey(hash, props.UserId), data, time.Duration(expire)*time.Minute)
}
func (r *ChatCacheRedis) getKey(hash string, userId int) string {
return fmt.Sprintf("%s:%d:%s", chatCacheKey, userId, hash)
}

403
relay/relay_util/pricing.go Normal file

@@ -0,0 +1,403 @@
package relay_util
import (
"encoding/json"
"errors"
"one-api/common"
"one-api/common/utils"
"one-api/model"
"sort"
"strings"
"sync"
"github.com/spf13/viper"
)
// PricingInstance is the Pricing instance
var PricingInstance *Pricing
// Pricing is a struct that contains the pricing data
type Pricing struct {
sync.RWMutex
Prices map[string]*model.Price `json:"models"`
Match []string `json:"-"`
}
type BatchPrices struct {
Models []string `json:"models" binding:"required"`
Price model.Price `json:"price" binding:"required"`
}
// NewPricing creates a new Pricing instance
func NewPricing() {
common.SysLog("Initializing Pricing")
PricingInstance = &Pricing{
Prices: make(map[string]*model.Price),
Match: make([]string, 0),
}
err := PricingInstance.Init()
if err != nil {
common.SysError("Failed to initialize Pricing:" + err.Error())
return
}
// On initialization, check whether the prices need to be updated
if viper.GetBool("auto_price_updates") || len(PricingInstance.Prices) == 0 {
common.SysLog("Checking for pricing updates")
prices := model.GetDefaultPrice()
PricingInstance.SyncPricing(prices, false)
common.SysLog("Pricing initialized")
}
}
// Init initializes the Pricing instance
func (p *Pricing) Init() error {
prices, err := model.GetAllPrices()
if err != nil {
return err
}
if len(prices) == 0 {
return nil
}
newPrices := make(map[string]*model.Price)
newMatch := make(map[string]bool)
for _, price := range prices {
newPrices[price.Model] = price
if strings.HasSuffix(price.Model, "*") {
if _, ok := newMatch[price.Model]; !ok {
newMatch[price.Model] = true
}
}
}
var newMatchList []string
for match := range newMatch {
newMatchList = append(newMatchList, match)
}
p.Lock()
defer p.Unlock()
p.Prices = newPrices
p.Match = newMatchList
return nil
}
// GetPrice returns the price of a model
func (p *Pricing) GetPrice(modelName string) *model.Price {
p.RLock()
defer p.RUnlock()
if price, ok := p.Prices[modelName]; ok {
return price
}
matchModel := utils.GetModelsWithMatch(&p.Match, modelName)
if price, ok := p.Prices[matchModel]; ok {
return price
}
return &model.Price{
Type: model.TokensPriceType,
ChannelType: common.ChannelTypeUnknown,
Input: model.DefaultPrice,
Output: model.DefaultPrice,
}
}
func (p *Pricing) GetAllPrices() map[string]*model.Price {
return p.Prices
}
func (p *Pricing) GetAllPricesList() []*model.Price {
var prices []*model.Price
for _, price := range p.Prices {
prices = append(prices, price)
}
return prices
}
func (p *Pricing) updateRawPrice(modelName string, price *model.Price) error {
if _, ok := p.Prices[modelName]; !ok {
return errors.New("model not found")
}
if _, ok := p.Prices[price.Model]; modelName != price.Model && ok {
return errors.New("model names cannot be duplicated")
}
if err := p.deleteRawPrice(modelName); err != nil {
return err
}
return price.Insert()
}
// UpdatePrice updates the price of a model
func (p *Pricing) UpdatePrice(modelName string, price *model.Price) error {
if err := p.updateRawPrice(modelName, price); err != nil {
return err
}
err := p.Init()
return err
}
func (p *Pricing) addRawPrice(price *model.Price) error {
if _, ok := p.Prices[price.Model]; ok {
return errors.New("model already exists")
}
return price.Insert()
}
// AddPrice adds a new price to the Pricing instance
func (p *Pricing) AddPrice(price *model.Price) error {
if err := p.addRawPrice(price); err != nil {
return err
}
err := p.Init()
return err
}
func (p *Pricing) deleteRawPrice(modelName string) error {
item, ok := p.Prices[modelName]
if !ok {
return errors.New("model not found")
}
return item.Delete()
}
// DeletePrice deletes a price from the Pricing instance
func (p *Pricing) DeletePrice(modelName string) error {
if err := p.deleteRawPrice(modelName); err != nil {
return err
}
err := p.Init()
return err
}
// SyncPricing syncs the pricing data
func (p *Pricing) SyncPricing(pricing []*model.Price, overwrite bool) error {
var err error
if overwrite {
err = p.SyncPriceWithOverwrite(pricing)
} else {
err = p.SyncPriceWithoutOverwrite(pricing)
}
return err
}
// SyncPriceWithOverwrite syncs the pricing data with overwrite
func (p *Pricing) SyncPriceWithOverwrite(pricing []*model.Price) error {
tx := model.DB.Begin()
err := model.DeleteAllPrices(tx)
if err != nil {
tx.Rollback()
return err
}
err = model.InsertPrices(tx, pricing)
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
return p.Init()
}
// SyncPriceWithoutOverwrite syncs the pricing data without overwrite
func (p *Pricing) SyncPriceWithoutOverwrite(pricing []*model.Price) error {
var newPrices []*model.Price
for _, price := range pricing {
if _, ok := p.Prices[price.Model]; !ok {
newPrices = append(newPrices, price)
}
}
if len(newPrices) == 0 {
return nil
}
tx := model.DB.Begin()
err := model.InsertPrices(tx, newPrices)
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
return p.Init()
}
// BatchDeletePrices deletes the prices of multiple models
func (p *Pricing) BatchDeletePrices(models []string) error {
tx := model.DB.Begin()
err := model.DeletePrices(tx, models)
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
p.Lock()
defer p.Unlock()
for _, model := range models {
delete(p.Prices, model)
}
return nil
}
func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []string) error {
// Find the models that need to be deleted
var deletePrices []string
var addPrices []*model.Price
var updatePrices []string
for _, model := range originalModels {
if !utils.Contains(model, batchPrices.Models) {
deletePrices = append(deletePrices, model)
} else {
updatePrices = append(updatePrices, model)
}
}
for _, model := range batchPrices.Models {
if !utils.Contains(model, originalModels) {
addPrice := batchPrices.Price
addPrice.Model = model
addPrices = append(addPrices, &addPrice)
}
}
tx := model.DB.Begin()
if len(addPrices) > 0 {
err := model.InsertPrices(tx, addPrices)
if err != nil {
tx.Rollback()
return err
}
}
if len(updatePrices) > 0 {
err := model.UpdatePrices(tx, updatePrices, &batchPrices.Price)
if err != nil {
tx.Rollback()
return err
}
}
if len(deletePrices) > 0 {
err := model.DeletePrices(tx, deletePrices)
if err != nil {
tx.Rollback()
return err
}
}
tx.Commit()
return p.Init()
}
func GetPricesList(pricingType string) []*model.Price {
var prices []*model.Price
switch pricingType {
case "default":
prices = model.GetDefaultPrice()
case "db":
prices = PricingInstance.GetAllPricesList()
case "old":
prices = GetOldPricesList()
default:
return nil
}
sort.Slice(prices, func(i, j int) bool {
if prices[i].ChannelType == prices[j].ChannelType {
return prices[i].Model < prices[j].Model
}
return prices[i].ChannelType < prices[j].ChannelType
})
return prices
}
func GetOldPricesList() []*model.Price {
oldDataJson, err := model.GetOption("ModelRatio")
if err != nil || oldDataJson.Value == "" {
return nil
}
oldData := make(map[string][]float64)
err = json.Unmarshal([]byte(oldDataJson.Value), &oldData)
if err != nil {
return nil
}
var prices []*model.Price
for modelName, oldPrice := range oldData {
price := PricingInstance.GetPrice(modelName)
prices = append(prices, &model.Price{
Model: modelName,
Type: model.TokensPriceType,
ChannelType: price.ChannelType,
Input: oldPrice[0],
Output: oldPrice[1],
})
}
return prices
}
// func ConvertBatchPrices(prices []*model.Price) []*BatchPrices {
// batchPricesMap := make(map[string]*BatchPrices)
// for _, price := range prices {
// key := fmt.Sprintf("%s-%d-%g-%g", price.Type, price.ChannelType, price.Input, price.Output)
// batchPrice, exists := batchPricesMap[key]
// if exists {
// batchPrice.Models = append(batchPrice.Models, price.Model)
// } else {
// batchPricesMap[key] = &BatchPrices{
// Models: []string{price.Model},
// Price: *price,
// }
// }
// }
// var batchPrices []*BatchPrices
// for _, batchPrice := range batchPricesMap {
// batchPrices = append(batchPrices, batchPrice)
// }
// return batchPrices
// }
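A small sketch of how the pricing table is queried once it has been initialized; the model name is illustrative, and NewPricing is assumed to be called once at startup elsewhere in the program:

package main

import (
	"fmt"

	"one-api/relay/relay_util"
)

func main() {
	relay_util.NewPricing() // populates the package-level PricingInstance

	// Exact model names win; otherwise wildcard entries such as "gpt-4*"
	// are matched, and a default token price is returned as a last resort.
	price := relay_util.PricingInstance.GetPrice("gpt-4-turbo")
	fmt.Println(price.Input, price.Output)
}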

176
relay/relay_util/quota.go Normal file

@@ -0,0 +1,176 @@
package relay_util
import (
"context"
"errors"
"fmt"
"math"
"net/http"
"one-api/common"
"one-api/model"
"one-api/types"
"time"
"github.com/gin-gonic/gin"
)
type Quota struct {
modelName string
promptTokens int
price model.Price
groupRatio float64
inputRatio float64
preConsumedQuota int
userId int
channelId int
tokenId int
HandelStatus bool
}
func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *types.OpenAIErrorWithStatusCode) {
quota := &Quota{
modelName: modelName,
promptTokens: promptTokens,
userId: c.GetInt("id"),
channelId: c.GetInt("channel_id"),
tokenId: c.GetInt("token_id"),
HandelStatus: false,
}
quota.price = *PricingInstance.GetPrice(quota.modelName)
quota.groupRatio = common.GetGroupRatio(c.GetString("group"))
quota.inputRatio = quota.price.GetInput() * quota.groupRatio
if quota.price.Type == model.TimesPriceType {
quota.preConsumedQuota = int(1000 * quota.inputRatio)
} else {
quota.preConsumedQuota = int(float64(quota.promptTokens+common.PreConsumedQuota) * quota.inputRatio)
}
errWithCode := quota.preQuotaConsumption()
if errWithCode != nil {
return nil, errWithCode
}
return quota, nil
}
func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
userQuota, err := model.CacheGetUserQuota(q.userId)
if err != nil {
return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota < q.preConsumedQuota {
return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
if err != nil {
return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*q.preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
q.preConsumedQuota = 0
// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if q.preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
if err != nil {
return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
q.HandelStatus = true
}
return nil
}
func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
quota := 0
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
if q.price.Type == model.TimesPriceType {
quota = int(1000 * q.inputRatio)
} else {
completionRatio := q.price.GetOutput() * q.groupRatio
quota = int(math.Ceil(((float64(promptTokens) * q.inputRatio) + (float64(completionTokens) * completionRatio))))
}
if q.inputRatio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - q.preConsumedQuota
err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta)
if err != nil {
return errors.New("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(q.userId)
if err != nil {
return errors.New("error consuming token remain quota: " + err.Error())
}
requestTime := 0
requestStartTimeValue := ctx.Value("requestStartTime")
if requestStartTimeValue != nil {
requestStartTime, ok := requestStartTimeValue.(time.Time)
if ok {
requestTime = int(time.Since(requestStartTime).Milliseconds())
}
}
var modelRatioStr string
if q.price.Type == model.TimesPriceType {
modelRatioStr = fmt.Sprintf("$%s/次", q.price.FetchInputCurrencyPrice(model.DollarRate))
} else {
// If the input and output rates are the same, only show one rate
if q.price.GetInput() == q.price.GetOutput() {
modelRatioStr = fmt.Sprintf("$%s/1k", q.price.FetchInputCurrencyPrice(model.DollarRate))
} else {
modelRatioStr = fmt.Sprintf("$%s/1k (输入) | $%s/1k (输出)", q.price.FetchInputCurrencyPrice(model.DollarRate), q.price.FetchOutputCurrencyPrice(model.DollarRate))
}
}
logContent := fmt.Sprintf("模型费率 %s分组倍率 %.2f", modelRatioStr, q.groupRatio)
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
model.UpdateChannelUsedQuota(q.channelId, quota)
return nil
}
func (q *Quota) Undo(c *gin.Context) {
tokenId := c.GetInt("token_id")
if q.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -q.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
}
func (q *Quota) Consume(c *gin.Context, usage *types.Usage) {
tokenName := c.GetString("token_name")
// If no error occurred, consume the quota
go func(ctx context.Context) {
err := q.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
func (q *Quota) GetInputRatio() float64 {
return q.inputRatio
}
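A hedged sketch of the lifecycle implied by the Quota type: quota is pre-consumed in NewQuota, refunded with Undo if the upstream call fails, and settled with Consume on success. The surrounding handler, the upstream call, and the use of the error's StatusCode field are assumptions:

package relay

import (
	"one-api/relay/relay_util"
	"one-api/types"

	"github.com/gin-gonic/gin"
)

func relayWithQuota(c *gin.Context, modelName string, promptTokens int) {
	quota, errWithCode := relay_util.NewQuota(c, modelName, promptTokens) // pre-consumes quota
	if errWithCode != nil {
		c.JSON(errWithCode.StatusCode, errWithCode)
		return
	}

	usage, err := callProvider(c) // hypothetical upstream request
	if err != nil {
		quota.Undo(c) // return the pre-consumed quota
		return
	}

	quota.Consume(c, usage) // asynchronously records the final consumption
}

func callProvider(c *gin.Context) (*types.Usage, error) { // placeholder
	return &types.Usage{PromptTokens: 10, CompletionTokens: 20}, nil
}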

35
relay/relay_util/type.go Normal file

@@ -0,0 +1,35 @@
package relay_util
import "one-api/common"
var UnknownOwnedBy = "未知"
var ModelOwnedBy map[int]string
func init() {
ModelOwnedBy = map[int]string{
common.ChannelTypeOpenAI: "OpenAI",
common.ChannelTypeAnthropic: "Anthropic",
common.ChannelTypeBaidu: "Baidu",
common.ChannelTypePaLM: "Google PaLM",
common.ChannelTypeGemini: "Google Gemini",
common.ChannelTypeZhipu: "Zhipu",
common.ChannelTypeAli: "Ali",
common.ChannelTypeXunfei: "Xunfei",
common.ChannelType360: "360",
common.ChannelTypeTencent: "Tencent",
common.ChannelTypeBaichuan: "Baichuan",
common.ChannelTypeMiniMax: "MiniMax",
common.ChannelTypeDeepseek: "Deepseek",
common.ChannelTypeMoonshot: "Moonshot",
common.ChannelTypeMistral: "Mistral",
common.ChannelTypeGroq: "Groq",
common.ChannelTypeLingyi: "Lingyiwanwu",
common.ChannelTypeMidjourney: "Midjourney",
common.ChannelTypeCloudflareAI: "Cloudflare AI",
common.ChannelTypeCohere: "Cohere",
common.ChannelTypeStabilityAI: "Stability AI",
common.ChannelTypeCoze: "Coze",
common.ChannelTypeOllama: "Ollama",
common.ChannelTypeHunyuan: "Hunyuan",
}
}
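For illustration, a lookup helper using the map and fallback above; this helper is not part of the commit:

package relay_util

// OwnedByName returns the vendor display name for a channel type, falling
// back to UnknownOwnedBy ("未知", i.e. "Unknown") for unmapped channel types.
func OwnedByName(channelType int) string {
	if name, ok := ModelOwnedBy[channelType]; ok {
		return name
	}
	return UnknownOwnedBy
}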