package model import ( "one-api/common" "strings" "gorm.io/datatypes" "gorm.io/gorm" ) type Channel struct { Id int `json:"id"` Type int `json:"type" form:"type" gorm:"default:0"` Key string `json:"key" form:"key" gorm:"type:varchar(767);not null;index"` Status int `json:"status" form:"status" gorm:"default:1"` Name string `json:"name" form:"name" gorm:"index"` Weight *uint `json:"weight" gorm:"default:1"` CreatedTime int64 `json:"created_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"` ResponseTime int `json:"response_time"` // in milliseconds BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` Other string `json:"other" form:"other"` Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models" form:"models"` Group string `json:"group" form:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` Proxy *string `json:"proxy" gorm:"type:varchar(255);default:''"` TestModel string `json:"test_model" form:"test_model" gorm:"type:varchar(50);default:''"` Plugin *datatypes.JSONType[PluginType] `json:"plugin" form:"plugin" gorm:"type:json"` } type PluginType map[string]map[string]interface{} var allowedChannelOrderFields = map[string]bool{ "id": true, "name": true, "group": true, "type": true, "status": true, "response_time": true, "balance": true, "priority": true, } type SearchChannelsParams struct { Channel PaginationParams } func GetChannelsList(params *SearchChannelsParams) (*DataResult[Channel], error) { var channels []*Channel db := DB.Omit("key") if params.Type != 0 { db = db.Where("type = ?", params.Type) } if params.Status != 0 { db = db.Where("status = ?", params.Status) } if params.Name != "" { db = db.Where("name LIKE ?", params.Name+"%") } if params.Group != "" { db = db.Where("id IN (SELECT channel_id FROM abilities WHERE "+quotePostgresField("group")+" = ?)", params.Group) } if params.Models != "" { db = db.Where("id IN (SELECT channel_id FROM abilities WHERE model IN (?))", params.Models) } if params.Other != "" { db = db.Where("other LIKE ?", params.Other+"%") } if params.Key != "" { db = db.Where(quotePostgresField("key")+" = ?", params.Key) } if params.TestModel != "" { db = db.Where("test_model LIKE ?", params.TestModel+"%") } return PaginateAndOrder[Channel](db, ¶ms.PaginationParams, &channels, allowedChannelOrderFields) } func GetAllChannels() ([]*Channel, error) { var channels []*Channel err := DB.Order("id desc").Find(&channels).Error return channels, err } func GetChannelById(id int, selectAll bool) (*Channel, error) { channel := Channel{Id: id} var err error = nil err = DB.First(&channel, "id = ?", id).Error return &channel, err } func BatchInsertChannels(channels []Channel) error { var err error err = DB.Omit("UsedQuota").Create(&channels).Error if err != nil { return err } for _, channel_ := range channels { err = channel_.AddAbilities() if err != nil { return err } } go ChannelGroup.Load() return nil } type BatchChannelsParams struct { Value string `json:"value" form:"value" binding:"required"` Ids []int `json:"ids" form:"ids" binding:"required"` } func BatchUpdateChannelsAzureApi(params *BatchChannelsParams) (int64, error) { db := DB.Model(&Channel{}).Where("id IN ?", params.Ids).Update("other", params.Value) if db.Error != nil { return 0, db.Error } if db.RowsAffected > 0 { go ChannelGroup.Load() } return db.RowsAffected, nil } func BatchDelModelChannels(params *BatchChannelsParams) (int64, error) { var count int64 var channels []*Channel err := DB.Select("id, models, "+quotePostgresField("group")).Find(&channels, "id IN ?", params.Ids).Error if err != nil { return 0, err } for _, channel := range channels { modelsSlice := strings.Split(channel.Models, ",") for i, m := range modelsSlice { if m == params.Value { modelsSlice = append(modelsSlice[:i], modelsSlice[i+1:]...) break } } channel.Models = strings.Join(modelsSlice, ",") channel.UpdateRaw(false) count++ } if count > 0 { go ChannelGroup.Load() } return count, nil } func (channel *Channel) GetPriority() int64 { if channel.Priority == nil { return 0 } return *channel.Priority } func (channel *Channel) GetBaseURL() string { if channel.BaseURL == nil { return "" } return *channel.BaseURL } func (channel *Channel) GetModelMapping() string { if channel.ModelMapping == nil { return "" } return *channel.ModelMapping } func (channel *Channel) Insert() error { var err error err = DB.Omit("UsedQuota").Create(channel).Error if err != nil { return err } err = channel.AddAbilities() if err == nil { go ChannelGroup.Load() } return err } func (channel *Channel) Update(overwrite bool) error { err := channel.UpdateRaw(overwrite) if err == nil { go ChannelGroup.Load() } return err } func (channel *Channel) UpdateRaw(overwrite bool) error { var err error if overwrite { err = DB.Model(channel).Select("*").Omit("UsedQuota").Updates(channel).Error } else { err = DB.Model(channel).Omit("UsedQuota").Updates(channel).Error } if err != nil { return err } DB.Model(channel).First(channel, "id = ?", channel.Id) err = channel.UpdateAbilities() return err } func (channel *Channel) UpdateResponseTime(responseTime int64) { err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ TestTime: common.GetTimestamp(), ResponseTime: int(responseTime), }).Error if err != nil { common.SysError("failed to update response time: " + err.Error()) } } func (channel *Channel) UpdateBalance(balance float64) { err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ BalanceUpdatedTime: common.GetTimestamp(), Balance: balance, }).Error if err != nil { common.SysError("failed to update balance: " + err.Error()) } } func (channel *Channel) Delete() error { var err error err = DB.Delete(channel).Error if err != nil { return err } err = channel.DeleteAbilities() if err == nil { go ChannelGroup.Load() } return err } func (channel *Channel) StatusToStr() string { switch channel.Status { case common.ChannelStatusEnabled: return "启用" case common.ChannelStatusAutoDisabled: return "自动禁用" case common.ChannelStatusManuallyDisabled: return "手动禁用" } return "禁用" } func UpdateChannelStatusById(id int, status int) { err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) if err != nil { common.SysError("failed to update ability status: " + err.Error()) } err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error if err != nil { common.SysError("failed to update channel status: " + err.Error()) } if err == nil { go ChannelGroup.Load() } } func UpdateChannelUsedQuota(id int, quota int) { if common.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) return } updateChannelUsedQuota(id, quota) } func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { common.SysError("failed to update channel used quota: " + err.Error()) } } func DeleteChannelByStatus(status int64) (int64, error) { result := DB.Where("status = ?", status).Delete(&Channel{}) return result.RowsAffected, result.Error } func DeleteDisabledChannel() (int64, error) { result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) // 同时删除Ability DB.Where("enabled = ?", false).Delete(&Ability{}) return result.RowsAffected, result.Error } type ChannelStatistics struct { TotalChannels int `json:"total_channels"` Status int `json:"status"` } func GetStatisticsChannel() (statistics []*ChannelStatistics, err error) { err = DB.Table("channels").Select("count(*) as total_channels, status").Group("status").Scan(&statistics).Error return statistics, err }