mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-18 14:13:43 +08:00
✨ feat: support configuration file (#117)
* ♻️ refactor: move file directory * ♻️ refactor: move file directory * ♻️ refactor: support multiple config methods * 🔥 del: remove unused code * 💩 refactor: Refactor channel management and synchronization * 💄 improve: add channel website * ✨ feat: allow recording 0 consumption
This commit is contained in:
@@ -34,7 +34,7 @@ func (cc *ChannelsChooser) Cooldowns(channelId int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (cc *ChannelsChooser) Balancer(channelIds []int) *Channel {
|
||||
func (cc *ChannelsChooser) balancer(channelIds []int) *Channel {
|
||||
nowTime := time.Now().Unix()
|
||||
totalWeight := 0
|
||||
|
||||
@@ -67,9 +67,9 @@ func (cc *ChannelsChooser) Balancer(channelIds []int) *Channel {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cc *ChannelsChooser) Next(group, model string) (*Channel, error) {
|
||||
func (cc *ChannelsChooser) Next(group, modelName string) (*Channel, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model)
|
||||
return GetRandomSatisfiedChannel(group, modelName)
|
||||
}
|
||||
cc.RLock()
|
||||
defer cc.RUnlock()
|
||||
@@ -77,17 +77,17 @@ func (cc *ChannelsChooser) Next(group, model string) (*Channel, error) {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
|
||||
if _, ok := cc.Rule[group][model]; !ok {
|
||||
if _, ok := cc.Rule[group][modelName]; !ok {
|
||||
return nil, errors.New("model not found")
|
||||
}
|
||||
|
||||
channelsPriority := cc.Rule[group][model]
|
||||
channelsPriority := cc.Rule[group][modelName]
|
||||
if len(channelsPriority) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
|
||||
for _, priority := range channelsPriority {
|
||||
channel := cc.Balancer(priority)
|
||||
channel := cc.balancer(priority)
|
||||
if channel != nil {
|
||||
return channel, nil
|
||||
}
|
||||
@@ -118,7 +118,7 @@ func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) {
|
||||
|
||||
var ChannelGroup = ChannelsChooser{}
|
||||
|
||||
func InitChannelGroup() {
|
||||
func (cc *ChannelsChooser) Load() {
|
||||
var channels []*Channel
|
||||
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
||||
|
||||
@@ -160,17 +160,9 @@ func InitChannelGroup() {
|
||||
newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds)
|
||||
}
|
||||
|
||||
ChannelGroup.Lock()
|
||||
ChannelGroup.Rule = newGroup
|
||||
ChannelGroup.Channels = newChannels
|
||||
ChannelGroup.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
}
|
||||
|
||||
func SyncChannelGroup(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
common.SysLog("syncing channels from database")
|
||||
InitChannelGroup()
|
||||
}
|
||||
cc.Lock()
|
||||
cc.Rule = newGroup
|
||||
cc.Channels = newChannels
|
||||
cc.Unlock()
|
||||
common.SysLog("channels Load success")
|
||||
}
|
||||
|
||||
@@ -9,10 +9,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
TokenCacheSeconds = common.SyncFrequency
|
||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
||||
TokenCacheSeconds = 0
|
||||
)
|
||||
|
||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
@@ -55,7 +52,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(TokenCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user group error: " + err.Error())
|
||||
}
|
||||
@@ -73,7 +70,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(TokenCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user quota error: " + err.Error())
|
||||
}
|
||||
@@ -91,7 +88,7 @@ func CacheUpdateUserQuota(id int) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(TokenCacheSeconds)*time.Second)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -120,7 +117,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
if userEnabled {
|
||||
enabled = "1"
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(TokenCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user enabled error: " + err.Error())
|
||||
}
|
||||
|
||||
@@ -117,6 +117,8 @@ func BatchInsertChannels(channels []Channel) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
go ChannelGroup.Load()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -130,6 +132,10 @@ func BatchUpdateChannelsAzureApi(params *BatchChannelsParams) (int64, error) {
|
||||
if db.Error != nil {
|
||||
return 0, db.Error
|
||||
}
|
||||
|
||||
if db.RowsAffected > 0 {
|
||||
go ChannelGroup.Load()
|
||||
}
|
||||
return db.RowsAffected, nil
|
||||
}
|
||||
|
||||
@@ -152,10 +158,14 @@ func BatchDelModelChannels(params *BatchChannelsParams) (int64, error) {
|
||||
}
|
||||
|
||||
channel.Models = strings.Join(modelsSlice, ",")
|
||||
channel.Update(false)
|
||||
channel.UpdateRaw(false)
|
||||
count++
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
go ChannelGroup.Load()
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
@@ -187,10 +197,26 @@ func (channel *Channel) Insert() error {
|
||||
return err
|
||||
}
|
||||
err = channel.AddAbilities()
|
||||
|
||||
if err == nil {
|
||||
go ChannelGroup.Load()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (channel *Channel) Update(overwrite bool) error {
|
||||
|
||||
err := channel.UpdateRaw(overwrite)
|
||||
|
||||
if err == nil {
|
||||
go ChannelGroup.Load()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (channel *Channel) UpdateRaw(overwrite bool) error {
|
||||
var err error
|
||||
|
||||
if overwrite {
|
||||
@@ -233,6 +259,9 @@ func (channel *Channel) Delete() error {
|
||||
return err
|
||||
}
|
||||
err = channel.DeleteAbilities()
|
||||
if err == nil {
|
||||
go ChannelGroup.Load()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -245,6 +274,11 @@ func UpdateChannelStatusById(id int, status int) {
|
||||
if err != nil {
|
||||
common.SysError("failed to update channel status: " + err.Error())
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
|
||||
go ChannelGroup.Load()
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateChannelUsedQuota(id int, quota int) {
|
||||
|
||||
@@ -3,10 +3,11 @@ package model
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
@@ -15,6 +16,21 @@ import (
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
func SetupDB() {
|
||||
err := InitDB()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize database: " + err.Error())
|
||||
}
|
||||
ChannelGroup.Load()
|
||||
|
||||
if viper.GetBool("BATCH_UPDATE_ENABLED") {
|
||||
common.BatchUpdateEnabled = true
|
||||
common.BatchUpdateInterval = common.GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||
InitBatchUpdater()
|
||||
}
|
||||
}
|
||||
|
||||
func createRootAccountIfNeed() error {
|
||||
var user User
|
||||
//if user.Status != common.UserStatusEnabled {
|
||||
@@ -39,8 +55,8 @@ func createRootAccountIfNeed() error {
|
||||
}
|
||||
|
||||
func chooseDB() (*gorm.DB, error) {
|
||||
if os.Getenv("SQL_DSN") != "" {
|
||||
dsn := os.Getenv("SQL_DSN")
|
||||
if viper.IsSet("SQL_DSN") {
|
||||
dsn := viper.GetString("SQL_DSN")
|
||||
if strings.HasPrefix(dsn, "postgres://") {
|
||||
// Use PostgreSQL
|
||||
common.SysLog("using PostgreSQL as database")
|
||||
@@ -61,8 +77,8 @@ func chooseDB() (*gorm.DB, error) {
|
||||
// Use SQLite
|
||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
common.UsingSQLite = true
|
||||
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
|
||||
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
|
||||
config := fmt.Sprintf("?_busy_timeout=%d", common.GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000))
|
||||
return gorm.Open(sqlite.Open(viper.GetString("sqlite_path")+config), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
}
|
||||
@@ -70,7 +86,7 @@ func chooseDB() (*gorm.DB, error) {
|
||||
func InitDB() (err error) {
|
||||
db, err := chooseDB()
|
||||
if err == nil {
|
||||
if common.DebugEnabled {
|
||||
if viper.GetBool("debug") {
|
||||
db = db.Debug()
|
||||
}
|
||||
DB = db
|
||||
@@ -78,6 +94,7 @@ func InitDB() (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
||||
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
|
||||
|
||||
@@ -27,10 +27,6 @@ func GetOption(key string) (option Option, err error) {
|
||||
func InitOptionMap() {
|
||||
common.OptionMapRWMutex.Lock()
|
||||
common.OptionMap = make(map[string]string)
|
||||
common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
|
||||
common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
|
||||
common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
|
||||
common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission)
|
||||
common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
|
||||
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
|
||||
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
|
||||
@@ -137,18 +133,14 @@ func UpdateOption(key string, value string) error {
|
||||
}
|
||||
|
||||
var optionIntMap = map[string]*int{
|
||||
"FileUploadPermission": &common.FileUploadPermission,
|
||||
"FileDownloadPermission": &common.FileDownloadPermission,
|
||||
"ImageUploadPermission": &common.ImageUploadPermission,
|
||||
"ImageDownloadPermission": &common.ImageDownloadPermission,
|
||||
"SMTPPort": &common.SMTPPort,
|
||||
"QuotaForNewUser": &common.QuotaForNewUser,
|
||||
"QuotaForInviter": &common.QuotaForInviter,
|
||||
"QuotaForInvitee": &common.QuotaForInvitee,
|
||||
"QuotaRemindThreshold": &common.QuotaRemindThreshold,
|
||||
"PreConsumedQuota": &common.PreConsumedQuota,
|
||||
"RetryTimes": &common.RetryTimes,
|
||||
"RetryCooldownSeconds": &common.RetryCooldownSeconds,
|
||||
"SMTPPort": &common.SMTPPort,
|
||||
"QuotaForNewUser": &common.QuotaForNewUser,
|
||||
"QuotaForInviter": &common.QuotaForInviter,
|
||||
"QuotaForInvitee": &common.QuotaForInvitee,
|
||||
"QuotaRemindThreshold": &common.QuotaRemindThreshold,
|
||||
"PreConsumedQuota": &common.PreConsumedQuota,
|
||||
"RetryTimes": &common.RetryTimes,
|
||||
"RetryCooldownSeconds": &common.RetryCooldownSeconds,
|
||||
}
|
||||
|
||||
var optionBoolMap = map[string]*bool{
|
||||
|
||||
Reference in New Issue
Block a user