mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			256 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			256 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package model
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"encoding/json"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"github.com/songquanpeng/one-api/common"
 | 
						|
	"github.com/songquanpeng/one-api/common/config"
 | 
						|
	"github.com/songquanpeng/one-api/common/logger"
 | 
						|
	"github.com/songquanpeng/one-api/common/random"
 | 
						|
	"math/rand"
 | 
						|
	"sort"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
	"time"
 | 
						|
)
 | 
						|
 | 
						|
var (
 | 
						|
	TokenCacheSeconds         = config.SyncFrequency
 | 
						|
	UserId2GroupCacheSeconds  = config.SyncFrequency
 | 
						|
	UserId2QuotaCacheSeconds  = config.SyncFrequency
 | 
						|
	UserId2StatusCacheSeconds = config.SyncFrequency
 | 
						|
	GroupModelsCacheSeconds   = config.SyncFrequency
 | 
						|
)
 | 
						|
 | 
						|
func CacheGetTokenByKey(key string) (*Token, error) {
 | 
						|
	keyCol := "`key`"
 | 
						|
	if common.UsingPostgreSQL {
 | 
						|
		keyCol = `"key"`
 | 
						|
	}
 | 
						|
	var token Token
 | 
						|
	if !common.RedisEnabled {
 | 
						|
		err := DB.Where(keyCol+" = ?", key).First(&token).Error
 | 
						|
		return &token, err
 | 
						|
	}
 | 
						|
	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
 | 
						|
	if err != nil {
 | 
						|
		err := DB.Where(keyCol+" = ?", key).First(&token).Error
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		jsonBytes, err := json.Marshal(token)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
 | 
						|
		if err != nil {
 | 
						|
			logger.SysError("Redis set token error: " + err.Error())
 | 
						|
		}
 | 
						|
		return &token, nil
 | 
						|
	}
 | 
						|
	err = json.Unmarshal([]byte(tokenObjectString), &token)
 | 
						|
	return &token, err
 | 
						|
}
 | 
						|
 | 
						|
func CacheGetUserGroup(id int) (group string, err error) {
 | 
						|
	if !common.RedisEnabled {
 | 
						|
		return GetUserGroup(id)
 | 
						|
	}
 | 
						|
	group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
 | 
						|
	if err != nil {
 | 
						|
		group, err = GetUserGroup(id)
 | 
						|
		if err != nil {
 | 
						|
			return "", err
 | 
						|
		}
 | 
						|
		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
 | 
						|
		if err != nil {
 | 
						|
			logger.SysError("Redis set user group error: " + err.Error())
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return group, err
 | 
						|
}
 | 
						|
 | 
						|
func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) {
 | 
						|
	quota, err = GetUserQuota(id)
 | 
						|
	if err != nil {
 | 
						|
		return 0, err
 | 
						|
	}
 | 
						|
	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
 | 
						|
	if err != nil {
 | 
						|
		logger.Error(ctx, "Redis set user quota error: "+err.Error())
 | 
						|
	}
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) {
 | 
						|
	if !common.RedisEnabled {
 | 
						|
		return GetUserQuota(id)
 | 
						|
	}
 | 
						|
	quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
 | 
						|
	if err != nil {
 | 
						|
		return fetchAndUpdateUserQuota(ctx, id)
 | 
						|
	}
 | 
						|
	quota, err = strconv.ParseInt(quotaString, 10, 64)
 | 
						|
	if err != nil {
 | 
						|
		return 0, nil
 | 
						|
	}
 | 
						|
	if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db
 | 
						|
		logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id)
 | 
						|
		return fetchAndUpdateUserQuota(ctx, id)
 | 
						|
	}
 | 
						|
	return quota, nil
 | 
						|
}
 | 
						|
 | 
						|
func CacheUpdateUserQuota(ctx context.Context, id int) error {
 | 
						|
	if !common.RedisEnabled {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	quota, err := CacheGetUserQuota(ctx, id)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func CacheDecreaseUserQuota(id int, quota int64) error {
 | 
						|
	if !common.RedisEnabled {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func CacheIsUserEnabled(userId int) (bool, error) {
 | 
						|
	if !common.RedisEnabled {
 | 
						|
		return IsUserEnabled(userId)
 | 
						|
	}
 | 
						|
	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
 | 
						|
	if err == nil {
 | 
						|
		return enabled == "1", nil
 | 
						|
	}
 | 
						|
 | 
						|
	userEnabled, err := IsUserEnabled(userId)
 | 
						|
	if err != nil {
 | 
						|
		return false, err
 | 
						|
	}
 | 
						|
	enabled = "0"
 | 
						|
	if userEnabled {
 | 
						|
		enabled = "1"
 | 
						|
	}
 | 
						|
	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
 | 
						|
	if err != nil {
 | 
						|
		logger.SysError("Redis set user enabled error: " + err.Error())
 | 
						|
	}
 | 
						|
	return userEnabled, err
 | 
						|
}
 | 
						|
 | 
						|
func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
 | 
						|
	if !common.RedisEnabled {
 | 
						|
		return GetGroupModels(ctx, group)
 | 
						|
	}
 | 
						|
	modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
 | 
						|
	if err == nil {
 | 
						|
		return strings.Split(modelsStr, ","), nil
 | 
						|
	}
 | 
						|
	models, err := GetGroupModels(ctx, group)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
 | 
						|
	if err != nil {
 | 
						|
		logger.SysError("Redis set group models error: " + err.Error())
 | 
						|
	}
 | 
						|
	return models, nil
 | 
						|
}
 | 
						|
 | 
						|
var group2model2channels map[string]map[string][]*Channel
 | 
						|
var channelSyncLock sync.RWMutex
 | 
						|
 | 
						|
func InitChannelCache() {
 | 
						|
	newChannelId2channel := make(map[int]*Channel)
 | 
						|
	var channels []*Channel
 | 
						|
	DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
 | 
						|
	for _, channel := range channels {
 | 
						|
		newChannelId2channel[channel.Id] = channel
 | 
						|
	}
 | 
						|
	var abilities []*Ability
 | 
						|
	DB.Find(&abilities)
 | 
						|
	groups := make(map[string]bool)
 | 
						|
	for _, ability := range abilities {
 | 
						|
		groups[ability.Group] = true
 | 
						|
	}
 | 
						|
	newGroup2model2channels := make(map[string]map[string][]*Channel)
 | 
						|
	for group := range groups {
 | 
						|
		newGroup2model2channels[group] = make(map[string][]*Channel)
 | 
						|
	}
 | 
						|
	for _, channel := range channels {
 | 
						|
		groups := strings.Split(channel.Group, ",")
 | 
						|
		for _, group := range groups {
 | 
						|
			models := strings.Split(channel.Models, ",")
 | 
						|
			for _, model := range models {
 | 
						|
				if _, ok := newGroup2model2channels[group][model]; !ok {
 | 
						|
					newGroup2model2channels[group][model] = make([]*Channel, 0)
 | 
						|
				}
 | 
						|
				newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// sort by priority
 | 
						|
	for group, model2channels := range newGroup2model2channels {
 | 
						|
		for model, channels := range model2channels {
 | 
						|
			sort.Slice(channels, func(i, j int) bool {
 | 
						|
				return channels[i].GetPriority() > channels[j].GetPriority()
 | 
						|
			})
 | 
						|
			newGroup2model2channels[group][model] = channels
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	channelSyncLock.Lock()
 | 
						|
	group2model2channels = newGroup2model2channels
 | 
						|
	channelSyncLock.Unlock()
 | 
						|
	logger.SysLog("channels synced from database")
 | 
						|
}
 | 
						|
 | 
						|
func SyncChannelCache(frequency int) {
 | 
						|
	for {
 | 
						|
		time.Sleep(time.Duration(frequency) * time.Second)
 | 
						|
		logger.SysLog("syncing channels from database")
 | 
						|
		InitChannelCache()
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
 | 
						|
	if !config.MemoryCacheEnabled {
 | 
						|
		return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority)
 | 
						|
	}
 | 
						|
	channelSyncLock.RLock()
 | 
						|
	defer channelSyncLock.RUnlock()
 | 
						|
	channels := group2model2channels[group][model]
 | 
						|
	if len(channels) == 0 {
 | 
						|
		return nil, errors.New("channel not found")
 | 
						|
	}
 | 
						|
	endIdx := len(channels)
 | 
						|
	// choose by priority
 | 
						|
	firstChannel := channels[0]
 | 
						|
	if firstChannel.GetPriority() > 0 {
 | 
						|
		for i := range channels {
 | 
						|
			if channels[i].GetPriority() != firstChannel.GetPriority() {
 | 
						|
				endIdx = i
 | 
						|
				break
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
	idx := rand.Intn(endIdx)
 | 
						|
	if ignoreFirstPriority {
 | 
						|
		if endIdx < len(channels) { // which means there are more than one priority
 | 
						|
			idx = random.RandRange(endIdx, len(channels))
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return channels[idx], nil
 | 
						|
}
 |