Merge remote-tracking branch 'origin/main'

# Conflicts:
#	controller/log.go
#	controller/relay-audio.go
#	controller/relay-image.go
#	controller/relay-text.go
#	controller/relay.go
#	middleware/distributor.go
#	model/log.go
#	web/src/components/OperationSetting.js
This commit is contained in:
CaIon
2023-09-17 20:59:12 +08:00
30 changed files with 561 additions and 319 deletions

View File

@@ -10,15 +10,16 @@ type Ability struct {
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority int64 `json:"priority" gorm:"bigint;default:0"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error
} else {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error
}
if err != nil {
return nil, err
@@ -40,6 +41,7 @@ func (channel *Channel) AddAbilities() error {
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
}
abilities = append(abilities, ability)
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"math/rand"
"one-api/common"
"sort"
"strconv"
"strings"
"sync"
@@ -159,6 +160,17 @@ func InitChannelCache() {
}
}
}
// sort by priority
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].Priority > channels[j].Priority
})
newGroup2model2channels[group][model] = channels
}
}
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
@@ -183,6 +195,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
// choose by priority
firstChannel := channels[0]
if firstChannel.Priority > 0 {
return firstChannel, nil
}
idx := rand.Intn(len(channels))
return channels[idx], nil
}

View File

@@ -24,6 +24,7 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority int64 `json:"priority" gorm:"bigint;default:0"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {

View File

@@ -1,6 +1,8 @@
package model
import (
"context"
"fmt"
"gorm.io/gorm"
"one-api/common"
"strings"
@@ -19,6 +21,7 @@ type Log struct {
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Channel int `json:"channel" gorm:"default:0"`
}
const (
@@ -51,7 +54,8 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int) {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
@@ -66,15 +70,16 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
Channel: channelId,
TokenId: tokenId,
}
err := DB.Create(log).Error
if err != nil {
common.SysError("failed to record log: " + err.Error())
common.LogError(ctx, "failed to record log: "+err.Error())
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB
@@ -96,6 +101,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
if channel != 0 {
tx = tx.Where("channel = ?", channel)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err
}
@@ -139,7 +147,7 @@ type Stat struct {
Tpm int `json:"tpm"`
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (stat Stat) {
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if username != "" {
tx = tx.Where("username = ?", username)
@@ -156,6 +164,10 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
}
if channel != 0 {
tx = tx.Where("channel = ?", channel)
}
tx.Where("type = ?", LogTypeConsume).Scan(&stat)
return stat
}
@@ -180,3 +192,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx.Where("type = ?", LogTypeConsume).Scan(&token)
return token
}
func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
}