mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-16 21:23:44 +08:00
Merge branch 'channel-stream-mode' into refactor-main
This commit is contained in:
@@ -12,13 +12,22 @@ type Ability struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
|
||||
ability := Ability{}
|
||||
var err error = nil
|
||||
if common.UsingSQLite {
|
||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
|
||||
|
||||
cmd := "`group` = ? and model = ? and enabled = 1"
|
||||
|
||||
if stream {
|
||||
cmd += " and allow_streaming = 1"
|
||||
} else {
|
||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
|
||||
cmd += " and allow_non_streaming = 1"
|
||||
}
|
||||
|
||||
if common.UsingSQLite {
|
||||
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
|
||||
} else {
|
||||
err = DB.Where(cmd, group, model).Order("RAND()").Limit(1).First(&ability).Error
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -160,9 +160,9 @@ func SyncChannelCache(frequency int) {
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model)
|
||||
return GetRandomSatisfiedChannel(group, model, stream)
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
@@ -170,6 +170,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
idx := rand.Intn(len(channels))
|
||||
return channels[idx], nil
|
||||
|
||||
var filteredChannels []*Channel
|
||||
for _, channel := range channels {
|
||||
if (stream && channel.AllowStreaming) || (!stream && channel.AllowNonStreaming) {
|
||||
filteredChannels = append(filteredChannels, channel)
|
||||
}
|
||||
}
|
||||
|
||||
idx := rand.Intn(len(filteredChannels))
|
||||
return filteredChannels[idx], nil
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
@@ -23,6 +25,8 @@ type Channel struct {
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
AllowStreaming bool `json:"allow_streaming"`
|
||||
AllowNonStreaming bool `json:"allow_non_streaming"`
|
||||
}
|
||||
|
||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||
@@ -80,7 +84,19 @@ func BatchInsertChannels(channels []Channel) error {
|
||||
|
||||
func (channel *Channel) Insert() error {
|
||||
var err error
|
||||
err = DB.Create(channel).Error
|
||||
// turn channel into a map
|
||||
channelMap := make(map[string]interface{})
|
||||
|
||||
// Convert channel struct to a map
|
||||
channelBytes, err := json.Marshal(channel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(channelBytes, &channelMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.Create(channelMap).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -90,11 +106,24 @@ func (channel *Channel) Insert() error {
|
||||
|
||||
func (channel *Channel) Update() error {
|
||||
var err error
|
||||
err = DB.Model(channel).Updates(channel).Error
|
||||
// turn channel into a map
|
||||
channelMap := make(map[string]interface{})
|
||||
|
||||
// Convert channel struct to a map
|
||||
channelBytes, err := json.Marshal(channel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
DB.Model(channel).First(channel, "id = ?", channel.Id)
|
||||
err = json.Unmarshal(channelBytes, &channelMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = DB.Model(channel).Updates(channelMap).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
DB.Model(channel).First(channelMap, "id = ?", channel.Id)
|
||||
err = channel.UpdateAbilities()
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user