mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-29 14:46:38 +08:00
* ♻️ refactor: move file directory * ♻️ refactor: move file directory * ♻️ refactor: support multiple config methods * 🔥 del: remove unused code * 💩 refactor: Refactor channel management and synchronization * 💄 improve: add channel website * ✨ feat: allow recording 0 consumption
320 lines
8.2 KiB
Go
320 lines
8.2 KiB
Go
package model
|
|
|
|
import (
|
|
"one-api/common"
|
|
"strings"
|
|
|
|
"gorm.io/datatypes"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type Channel struct {
|
|
Id int `json:"id"`
|
|
Type int `json:"type" form:"type" gorm:"default:0"`
|
|
Key string `json:"key" form:"key" gorm:"type:varchar(767);not null;index"`
|
|
Status int `json:"status" form:"status" gorm:"default:1"`
|
|
Name string `json:"name" form:"name" gorm:"index"`
|
|
Weight *uint `json:"weight" gorm:"default:1"`
|
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
|
TestTime int64 `json:"test_time" gorm:"bigint"`
|
|
ResponseTime int `json:"response_time"` // in milliseconds
|
|
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
|
|
Other string `json:"other" form:"other"`
|
|
Balance float64 `json:"balance"` // in USD
|
|
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
|
Models string `json:"models" form:"models"`
|
|
Group string `json:"group" form:"group" gorm:"type:varchar(32);default:'default'"`
|
|
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
|
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
|
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
|
Proxy *string `json:"proxy" gorm:"type:varchar(255);default:''"`
|
|
TestModel string `json:"test_model" form:"test_model" gorm:"type:varchar(50);default:''"`
|
|
|
|
Plugin *datatypes.JSONType[PluginType] `json:"plugin" form:"plugin" gorm:"type:json"`
|
|
}
|
|
|
|
// PluginType is the value stored in Channel.Plugin: a two-level map whose
// outer and inner keys are strings and whose leaf values are untyped.
type PluginType map[string]map[string]any
|
|
|
|
// allowedChannelOrderFields is the set of field names that GetChannelsList
// hands to PaginateAndOrder as the permitted ordering fields.
var allowedChannelOrderFields = map[string]bool{
	"id":            true,
	"name":          true,
	"group":         true,
	"type":          true,
	"status":        true,
	"response_time": true,
	"balance":       true,
	"priority":      true,
}
|
|
|
|
type SearchChannelsParams struct {
|
|
Channel
|
|
PaginationParams
|
|
}
|
|
|
|
func GetChannelsList(params *SearchChannelsParams) (*DataResult[Channel], error) {
|
|
var channels []*Channel
|
|
|
|
db := DB.Omit("key")
|
|
|
|
if params.Type != 0 {
|
|
db = db.Where("type = ?", params.Type)
|
|
}
|
|
|
|
if params.Status != 0 {
|
|
db = db.Where("status = ?", params.Status)
|
|
}
|
|
|
|
if params.Name != "" {
|
|
db = db.Where("name LIKE ?", params.Name+"%")
|
|
}
|
|
|
|
if params.Group != "" {
|
|
db = db.Where("id IN (SELECT channel_id FROM abilities WHERE "+quotePostgresField("group")+" = ?)", params.Group)
|
|
}
|
|
|
|
if params.Models != "" {
|
|
db = db.Where("id IN (SELECT channel_id FROM abilities WHERE model IN (?))", params.Models)
|
|
}
|
|
|
|
if params.Other != "" {
|
|
db = db.Where("other LIKE ?", params.Other+"%")
|
|
}
|
|
|
|
if params.Key != "" {
|
|
db = db.Where(quotePostgresField("key")+" = ?", params.Key)
|
|
}
|
|
|
|
if params.TestModel != "" {
|
|
db = db.Where("test_model LIKE ?", params.TestModel+"%")
|
|
}
|
|
|
|
return PaginateAndOrder[Channel](db, ¶ms.PaginationParams, &channels, allowedChannelOrderFields)
|
|
}
|
|
|
|
func GetAllChannels() ([]*Channel, error) {
|
|
var channels []*Channel
|
|
err := DB.Order("id desc").Find(&channels).Error
|
|
return channels, err
|
|
}
|
|
|
|
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
|
channel := Channel{Id: id}
|
|
var err error = nil
|
|
err = DB.First(&channel, "id = ?", id).Error
|
|
|
|
return &channel, err
|
|
}
|
|
|
|
func BatchInsertChannels(channels []Channel) error {
|
|
var err error
|
|
err = DB.Omit("UsedQuota").Create(&channels).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, channel_ := range channels {
|
|
err = channel_.AddAbilities()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
go ChannelGroup.Load()
|
|
return nil
|
|
}
|
|
|
|
// BatchChannelsParams is the input of the batch channel operations: Ids
// selects the channels and Value is the operation-specific argument (the new
// "other" value for BatchUpdateChannelsAzureApi, the model name to remove for
// BatchDelModelChannels). Both fields are marked as required for binding.
type BatchChannelsParams struct {
	Value string `json:"value" form:"value" binding:"required"`
	Ids   []int  `json:"ids" form:"ids" binding:"required"`
}
|
|
|
|
func BatchUpdateChannelsAzureApi(params *BatchChannelsParams) (int64, error) {
|
|
db := DB.Model(&Channel{}).Where("id IN ?", params.Ids).Update("other", params.Value)
|
|
if db.Error != nil {
|
|
return 0, db.Error
|
|
}
|
|
|
|
if db.RowsAffected > 0 {
|
|
go ChannelGroup.Load()
|
|
}
|
|
return db.RowsAffected, nil
|
|
}
|
|
|
|
func BatchDelModelChannels(params *BatchChannelsParams) (int64, error) {
|
|
var count int64
|
|
|
|
var channels []*Channel
|
|
err := DB.Select("id, models, "+quotePostgresField("group")).Find(&channels, "id IN ?", params.Ids).Error
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
for _, channel := range channels {
|
|
modelsSlice := strings.Split(channel.Models, ",")
|
|
for i, m := range modelsSlice {
|
|
if m == params.Value {
|
|
modelsSlice = append(modelsSlice[:i], modelsSlice[i+1:]...)
|
|
break
|
|
}
|
|
}
|
|
|
|
channel.Models = strings.Join(modelsSlice, ",")
|
|
channel.UpdateRaw(false)
|
|
count++
|
|
}
|
|
|
|
if count > 0 {
|
|
go ChannelGroup.Load()
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
func (channel *Channel) GetPriority() int64 {
|
|
if channel.Priority == nil {
|
|
return 0
|
|
}
|
|
return *channel.Priority
|
|
}
|
|
|
|
func (channel *Channel) GetBaseURL() string {
|
|
if channel.BaseURL == nil {
|
|
return ""
|
|
}
|
|
return *channel.BaseURL
|
|
}
|
|
|
|
func (channel *Channel) GetModelMapping() string {
|
|
if channel.ModelMapping == nil {
|
|
return ""
|
|
}
|
|
return *channel.ModelMapping
|
|
}
|
|
|
|
func (channel *Channel) Insert() error {
|
|
var err error
|
|
err = DB.Omit("UsedQuota").Create(channel).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = channel.AddAbilities()
|
|
|
|
if err == nil {
|
|
go ChannelGroup.Load()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (channel *Channel) Update(overwrite bool) error {
|
|
|
|
err := channel.UpdateRaw(overwrite)
|
|
|
|
if err == nil {
|
|
go ChannelGroup.Load()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (channel *Channel) UpdateRaw(overwrite bool) error {
|
|
var err error
|
|
|
|
if overwrite {
|
|
err = DB.Model(channel).Select("*").Omit("UsedQuota").Updates(channel).Error
|
|
} else {
|
|
err = DB.Model(channel).Omit("UsedQuota").Updates(channel).Error
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
DB.Model(channel).First(channel, "id = ?", channel.Id)
|
|
err = channel.UpdateAbilities()
|
|
return err
|
|
}
|
|
|
|
func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
|
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
|
|
TestTime: common.GetTimestamp(),
|
|
ResponseTime: int(responseTime),
|
|
}).Error
|
|
if err != nil {
|
|
common.SysError("failed to update response time: " + err.Error())
|
|
}
|
|
}
|
|
|
|
func (channel *Channel) UpdateBalance(balance float64) {
|
|
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
|
|
BalanceUpdatedTime: common.GetTimestamp(),
|
|
Balance: balance,
|
|
}).Error
|
|
if err != nil {
|
|
common.SysError("failed to update balance: " + err.Error())
|
|
}
|
|
}
|
|
|
|
func (channel *Channel) Delete() error {
|
|
var err error
|
|
err = DB.Delete(channel).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = channel.DeleteAbilities()
|
|
if err == nil {
|
|
go ChannelGroup.Load()
|
|
}
|
|
return err
|
|
}
|
|
|
|
func UpdateChannelStatusById(id int, status int) {
|
|
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
|
if err != nil {
|
|
common.SysError("failed to update ability status: " + err.Error())
|
|
}
|
|
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
|
|
if err != nil {
|
|
common.SysError("failed to update channel status: " + err.Error())
|
|
}
|
|
|
|
if err == nil {
|
|
|
|
go ChannelGroup.Load()
|
|
}
|
|
}
|
|
|
|
func UpdateChannelUsedQuota(id int, quota int) {
|
|
if common.BatchUpdateEnabled {
|
|
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
|
|
return
|
|
}
|
|
updateChannelUsedQuota(id, quota)
|
|
}
|
|
|
|
func updateChannelUsedQuota(id int, quota int) {
|
|
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
|
if err != nil {
|
|
common.SysError("failed to update channel used quota: " + err.Error())
|
|
}
|
|
}
|
|
|
|
func DeleteChannelByStatus(status int64) (int64, error) {
|
|
result := DB.Where("status = ?", status).Delete(&Channel{})
|
|
return result.RowsAffected, result.Error
|
|
}
|
|
|
|
func DeleteDisabledChannel() (int64, error) {
|
|
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
|
|
// 同时删除Ability
|
|
DB.Where("enabled = ?", false).Delete(&Ability{})
|
|
return result.RowsAffected, result.Error
|
|
}
|
|
|
|
// ChannelStatistics is one row of the per-status channel count produced by
// GetStatisticsChannel: the number of channels having the given Status.
type ChannelStatistics struct {
	TotalChannels int `json:"total_channels"`
	Status        int `json:"status"`
}
|
|
|
|
func GetStatisticsChannel() (statistics []*ChannelStatistics, err error) {
|
|
err = DB.Table("channels").Select("count(*) as total_channels, status").Group("status").Scan(&statistics).Error
|
|
return statistics, err
|
|
}
|