Merge remote-tracking branch 'upstream/main'

This commit is contained in:
ckt1031
2023-09-20 23:28:55 +08:00
8 changed files with 55 additions and 28 deletions

View File

@@ -13,7 +13,7 @@ type Ability struct {
Enabled bool `json:"enabled"`
AllowStreaming int `json:"allow_streaming" gorm:"default:1"`
AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"`
Priority int64 `json:"priority" gorm:"bigint;default:0"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
}
func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
@@ -33,10 +33,12 @@ func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channe
cmd += fmt.Sprintf(" and allow_non_streaming = %d", common.ChannelAllowNonStreamEnabled)
}
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(cmd, group, model)
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = DB.Where(cmd, group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {
err = DB.Where(cmd, group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error
err = channelQuery.Order("RAND()").First(&ability).Error
}
if err != nil {
return nil, err

View File

@@ -171,7 +171,7 @@ func InitChannelCache() {
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].Priority > channels[j].Priority
return channels[i].GetPriority() > channels[j].GetPriority()
})
newGroup2model2channels[group][model] = channels
}
@@ -209,12 +209,19 @@ func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*C
}
}
endIdx := len(channels)
// choose by priority
firstChannel := filteredChannels[0]
if firstChannel.Priority > 0 {
return firstChannel, nil
firstChannel := channels[0]
if firstChannel.GetPriority() > 0 {
for i := range channels {
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
}
}
}
idx := rand.Intn(len(filteredChannels))
idx := rand.Intn(endIdx)
return filteredChannels[idx], nil
}

View File

@@ -16,17 +16,17 @@ type Channel struct {
CreatedTime int64 `json:"created_time" gorm:"bigint"`
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL string `json:"base_url" gorm:"column:base_url"`
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"`
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AllowStreaming int `json:"allow_streaming" gorm:"default:1"`
AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"`
Priority int64 `json:"priority" gorm:"bigint;default:0"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -90,6 +90,27 @@ func BatchInsertChannels(channels []Channel) error {
return nil
}
func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil {
return 0
}
return *channel.Priority
}
func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil {
return ""
}
return *channel.BaseURL
}
func (channel *Channel) GetModelMapping() string {
if channel.ModelMapping == nil {
return ""
}
return *channel.ModelMapping
}
func (channel *Channel) Insert() error {
var err error
err = DB.Create(channel).Error

View File

@@ -135,7 +135,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
tx := DB.Table("logs").Select("sum(quota)")
tx := DB.Table("logs").Select("ifnull(sum(quota),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}
@@ -159,7 +159,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)")
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}