feat: 更新代码

This commit is contained in:
suziheng
2024-12-24 10:03:50 +08:00
328 changed files with 13287 additions and 4832 deletions

View File

@@ -1,7 +1,10 @@
package model
import (
"context"
"github.com/songquanpeng/one-api/common"
"gorm.io/gorm"
"sort"
"strings"
)
@@ -13,7 +16,7 @@ type Ability struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
ability := Ability{}
groupCol := "`group`"
trueVal := "1"
@@ -23,8 +26,13 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
}
var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
var channelQuery *gorm.DB
if ignoreFirstPriority {
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
} else {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
}
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {
@@ -49,7 +57,7 @@ func (channel *Channel) AddAbilities() error {
Group: group,
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Enabled: channel.Status == ChannelStatusEnabled,
Priority: channel.Priority,
}
abilities = append(abilities, ability)
@@ -82,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error {
func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
func GetGroupModels(ctx context.Context, group string) ([]string, error) {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var models []string
err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error
if err != nil {
return nil, err
}
sort.Strings(models)
return models, err
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"math/rand"
"sort"
"strconv"
@@ -21,6 +22,7 @@ var (
UserId2GroupCacheSeconds = config.SyncFrequency
UserId2QuotaCacheSeconds = config.SyncFrequency
UserId2StatusCacheSeconds = config.SyncFrequency
GroupModelsCacheSeconds = config.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
@@ -146,13 +148,32 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
if !common.RedisEnabled {
return GetGroupModels(ctx, group)
}
modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
if err == nil {
return strings.Split(modelsStr, ","), nil
}
models, err := GetGroupModels(ctx, group)
if err != nil {
return nil, err
}
err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
if err != nil {
logger.SysError("Redis set group models error: " + err.Error())
}
return models, nil
}
var group2model2channels map[string]map[string][]*Channel
var channelSyncLock sync.RWMutex
func InitChannelCache() {
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
for _, channel := range channels {
newChannelId2channel[channel.Id] = channel
}
@@ -205,7 +226,7 @@ func SyncChannelCache(frequency int) {
func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
if !config.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
@@ -227,7 +248,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPrior
idx := rand.Intn(endIdx)
if ignoreFirstPriority {
if endIdx < len(channels) { // which means there are more than one priority
idx = common.RandRange(endIdx, len(channels))
idx = random.RandRange(endIdx, len(channels))
}
}
return channels[idx], nil

View File

@@ -3,13 +3,20 @@ package model
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
)
const (
ChannelStatusUnknown = 0
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
ChannelStatusManuallyDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
)
type Channel struct {
Id int `json:"id"`
Type int `json:"type" gorm:"default:0"`
@@ -21,7 +28,7 @@ type Channel struct {
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"` // DEPRECATED: please save config to field Config
Other *string `json:"other"` // DEPRECATED: please save config to field Config
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
@@ -30,6 +37,19 @@ type Channel struct {
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"`
SystemPrompt *string `json:"system_prompt" gorm:"type:text"`
}
type ChannelConfig struct {
Region string `json:"region,omitempty"`
SK string `json:"sk,omitempty"`
AK string `json:"ak,omitempty"`
UserID string `json:"user_id,omitempty"`
APIVersion string `json:"api_version,omitempty"`
LibraryID string `json:"library_id,omitempty"`
Plugin string `json:"plugin,omitempty"`
VertexAIProjectID string `json:"vertex_ai_project_id,omitempty"`
VertexAIADC string `json:"vertex_ai_adc,omitempty"`
}
func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
@@ -39,7 +59,7 @@ func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
case "all":
err = DB.Order("id desc").Find(&channels).Error
case "disabled":
err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error
err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error
default:
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
}
@@ -155,20 +175,20 @@ func (channel *Channel) Delete() error {
return err
}
func (channel *Channel) LoadConfig() (map[string]string, error) {
func (channel *Channel) LoadConfig() (ChannelConfig, error) {
var cfg ChannelConfig
if channel.Config == "" {
return nil, nil
return cfg, nil
}
cfg := make(map[string]string)
err := json.Unmarshal([]byte(channel.Config), &cfg)
if err != nil {
return nil, err
return cfg, err
}
return cfg, nil
}
func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
err := UpdateAbilityStatus(id, status == ChannelStatusEnabled)
if err != nil {
logger.SysError("failed to update ability status: " + err.Error())
}
@@ -199,6 +219,6 @@ func DeleteChannelByStatus(status int64) (int64, error) {
}
func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error
}

View File

@@ -3,11 +3,11 @@ package model
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
)
@@ -51,6 +51,21 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordTopupLog(userId int, content string, quota int) {
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
CreatedAt: helper.GetTimestamp(),
Type: LogTypeTopup,
Content: content,
Quota: quota,
}
err := LOG_DB.Create(log).Error
if err != nil {
logger.SysError("failed to record log: " + err.Error())
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled {
@@ -138,7 +153,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
ifnull := "ifnull"
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
if username != "" {
tx = tx.Where("username = ?", username)
}
@@ -162,7 +181,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
ifnull := "ifnull"
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull))
if username != "" {
tx = tx.Where("username = ?", username)
}

View File

@@ -1,12 +1,14 @@
package model
import (
"database/sql"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/env"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
@@ -28,13 +30,17 @@ func CreateRootAccountIfNeed() error {
if err != nil {
return err
}
accessToken := random.GetUUID()
if config.InitialRootAccessToken != "" {
accessToken = config.InitialRootAccessToken
}
rootUser := User{
Username: "root",
Password: hashedPassword,
Role: common.RoleRootUser,
Status: common.UserStatusEnabled,
Role: RoleRootUser,
Status: UserStatusEnabled,
DisplayName: "Root User",
AccessToken: helper.GetUUID(),
AccessToken: accessToken,
Quota: 500000000000000,
}
DB.Create(&rootUser)
@@ -44,7 +50,7 @@ func CreateRootAccountIfNeed() error {
Id: 1,
UserId: rootUser.Id,
Key: config.InitialRootToken,
Status: common.TokenStatusEnabled,
Status: TokenStatusEnabled,
Name: "Initial Root Token",
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
@@ -59,92 +65,153 @@ func CreateRootAccountIfNeed() error {
}
func chooseDB(envName string) (*gorm.DB, error) {
if os.Getenv(envName) != "" {
dsn := os.Getenv(envName)
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
}), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
dsn := os.Getenv(envName)
switch {
case strings.HasPrefix(dsn, "postgres://"):
// Use PostgreSQL
return openPostgreSQL(dsn)
case dsn != "":
// Use MySQL
logger.SysLog("using MySQL as database")
common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
return openMySQL(dsn)
default:
// Use SQLite
return openSQLite()
}
// Use SQLite
logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
}
func openPostgreSQL(dsn string) (*gorm.DB, error) {
logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
}), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
func InitDB(envName string) (db *gorm.DB, err error) {
db, err = chooseDB(envName)
if err == nil {
if config.DebugSQLEnabled {
db = db.Debug()
}
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
func openMySQL(dsn string) (*gorm.DB, error) {
logger.SysLog("using MySQL as database")
common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
if !config.IsMasterNode {
return db, err
}
if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
}
if env.Bool("StartSqlMigration", false) {
logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Token{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&User{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Option{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Redemption{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return nil, err
}
logger.SysLog("database migrated")
}
return db, err
} else {
logger.FatalLog(err)
func openSQLite() (*gorm.DB, error) {
logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
func InitDB() {
var err error
DB, err = chooseDB("SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize database: " + err.Error())
return
}
return db, err
}
if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
}
logger.SysLog("database migration started")
if err = migrateDB(); err != nil {
logger.FatalLog("failed to migrate database: " + err.Error())
return
}
logger.SysLog("database migrated")
}
func migrateDB() error {
var err error
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Token{}); err != nil {
return err
}
if err = DB.AutoMigrate(&User{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Option{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Redemption{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Ability{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Log{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
}
return nil
}
func InitLogDB() {
if os.Getenv("LOG_SQL_DSN") == "" {
LOG_DB = DB
return
}
logger.SysLog("using secondary database for table logs")
var err error
LOG_DB, err = chooseDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
return
}
setDBConns(LOG_DB)
if !config.IsMasterNode {
return
}
logger.SysLog("secondary database migration started")
err = migrateLOGDB()
if err != nil {
logger.FatalLog("failed to migrate secondary database: " + err.Error())
return
}
logger.SysLog("secondary database migrated")
}
func migrateLOGDB() error {
var err error
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
return err
}
return nil
}
func setDBConns(db *gorm.DB) *sql.DB {
if config.DebugSQLEnabled {
db = db.Debug()
}
sqlDB, err := db.DB()
if err != nil {
logger.FatalLog("failed to connect database: " + err.Error())
return nil
}
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
return sqlDB
}
func closeDB(db *gorm.DB) error {

View File

@@ -1,9 +1,9 @@
package model
import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"strconv"
"strings"
"time"
@@ -28,6 +28,7 @@ func InitOptionMap() {
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled)
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
@@ -66,9 +67,9 @@ func InitOptionMap() {
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString()
config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString()
config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString()
config.OptionMap["TopUpLink"] = config.TopUpLink
config.OptionMap["ChatLink"] = config.ChatLink
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
@@ -82,7 +83,7 @@ func loadOptionsFromDatabase() {
options, _ := AllOption()
for _, option := range options {
if option.Key == "ModelRatio" {
option.Value = common.AddNewMissingRatio(option.Value)
option.Value = billingratio.AddNewMissingRatio(option.Value)
}
err := updateOptionMap(option.Key, option.Value)
if err != nil {
@@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) {
config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
config.GitHubOAuthEnabled = boolValue
case "OidcEnabled":
config.OidcEnabled = boolValue
case "WeChatAuthEnabled":
config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled":
@@ -172,6 +175,22 @@ func updateOptionMap(key string, value string) (err error) {
config.GitHubClientId = value
case "GitHubClientSecret":
config.GitHubClientSecret = value
case "LarkClientId":
config.LarkClientId = value
case "LarkClientSecret":
config.LarkClientSecret = value
case "OidcClientId":
config.OidcClientId = value
case "OidcClientSecret":
config.OidcClientSecret = value
case "OidcWellKnown":
config.OidcWellKnown = value
case "OidcAuthorizationEndpoint":
config.OidcAuthorizationEndpoint = value
case "OidcTokenEndpoint":
config.OidcTokenEndpoint = value
case "OidcUserinfoEndpoint":
config.OidcUserinfoEndpoint = value
case "Footer":
config.Footer = value
case "SystemName":
@@ -205,11 +224,11 @@ func updateOptionMap(key string, value string) (err error) {
case "RetryTimes":
config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
err = billingratio.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
err = billingratio.UpdateGroupRatioByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
err = billingratio.UpdateCompletionRatioByJSONString(value)
case "TopUpLink":
config.TopUpLink = value
case "ChatLink":

View File

@@ -8,6 +8,12 @@ import (
"gorm.io/gorm"
)
const (
RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
RedemptionCodeStatusDisabled = 2 // also don't use 0
RedemptionCodeStatusUsed = 3 // also don't use 0
)
type Redemption struct {
Id int `json:"id"`
UserId int `json:"user_id"`
@@ -61,7 +67,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
if err != nil {
return errors.New("无效的兑换码")
}
if redemption.Status != common.RedemptionCodeStatusEnabled {
if redemption.Status != RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用")
}
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
@@ -69,7 +75,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
return err
}
redemption.RedeemedTime = helper.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed
redemption.Status = RedemptionCodeStatusUsed
err = tx.Save(redemption).Error
return err
})

View File

@@ -11,25 +11,34 @@ import (
"gorm.io/gorm"
)
const (
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
TokenStatusDisabled = 2 // also don't use 0
TokenStatusExpired = 3
TokenStatusExhausted = 4
)
type Token struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
Models *string `json:"models" gorm:"type:text"` // allowed models
Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
}
func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
var tokens []*Token
var err error
query := DB.Where("user_id = ?", userId)
switch order {
case "remain_quota":
query = query.Order("unlimited_quota desc, remain_quota desc")
@@ -38,7 +47,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token
default:
query = query.Order("id desc")
}
err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
return tokens, err
}
@@ -60,17 +69,17 @@ func ValidateUserToken(key string) (token *Token, err error) {
}
return nil, errors.New("令牌验证失败")
}
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("令牌额度已用尽")
} else if token.Status == common.TokenStatusExpired {
if token.Status == TokenStatusExhausted {
return nil, fmt.Errorf("令牌 %s#%d额度已用尽", token.Name, token.Id)
} else if token.Status == TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
if token.Status != common.TokenStatusEnabled {
if token.Status != TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
if !common.RedisEnabled {
token.Status = common.TokenStatusExpired
token.Status = TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
logger.SysError("failed to update token status" + err.Error())
@@ -81,7 +90,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted
token.Status = common.TokenStatusExhausted
token.Status = TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
logger.SysError("failed to update token status" + err.Error())
@@ -112,30 +121,40 @@ func GetTokenById(id int) (*Token, error) {
return &token, err
}
func (token *Token) Insert() error {
func (t *Token) Insert() error {
var err error
err = DB.Create(token).Error
err = DB.Create(t).Error
return err
}
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
func (t *Token) Update() error {
var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error
return err
}
func (token *Token) SelectUpdate() error {
func (t *Token) SelectUpdate() error {
// This can update zero values
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
return DB.Model(t).Select("accessed_time", "status").Updates(t).Error
}
func (token *Token) Delete() error {
func (t *Token) Delete() error {
var err error
err = DB.Delete(token).Error
err = DB.Delete(t).Error
return err
}
func (t *Token) GetModels() string {
if t == nil {
return ""
}
if t.Models == nil {
return ""
}
return *t.Models
}
func DeleteTokenById(id int, userId int) (err error) {
// Why we need userId here? In case user want to delete other's token.
if id == 0 || userId == 0 {
@@ -245,14 +264,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
token, err := GetTokenById(tokenId)
if err != nil {
return err
}
if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota)
} else {
err = IncreaseUserQuota(token.UserId, -quota)
}
if err != nil {
return err
}
if !token.UnlimitedQuota {
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)

View File

@@ -8,10 +8,24 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm"
"strings"
)
const (
RoleGuestUser = 0
RoleCommonUser = 1
RoleAdminUser = 10
RoleRootUser = 100
)
const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0
UserStatusDeleted = 3
)
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
// Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct {
@@ -24,6 +38,8 @@ type User struct {
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int64 `json:"quota" gorm:"bigint;default:0"`
@@ -41,21 +57,21 @@ func GetMaxUserId() int {
}
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
switch order {
case "quota":
query = query.Order("quota desc")
case "used_quota":
query = query.Order("used_quota desc")
case "request_count":
query = query.Order("request_count desc")
default:
query = query.Order("id desc")
}
err = query.Find(&users).Error
return users, err
query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted)
switch order {
case "quota":
query = query.Order("quota desc")
case "used_quota":
query = query.Order("used_quota desc")
case "request_count":
query = query.Order("request_count desc")
default:
query = query.Order("id desc")
}
err = query.Find(&users).Error
return users, err
}
func SearchUsers(keyword string) (users []*User, err error) {
@@ -107,8 +123,8 @@ func (user *User) Insert(inviterId int) error {
}
}
user.Quota = config.QuotaForNewUser
user.AccessToken = helper.GetUUID()
user.AffCode = helper.GetRandomString(4)
user.AccessToken = random.GetUUID()
user.AffCode = random.GetRandomString(4)
result := DB.Create(user)
if result.Error != nil {
return result.Error
@@ -126,6 +142,22 @@ func (user *User) Insert(inviterId int) error {
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
}
}
// create default token
cleanToken := Token{
UserId: user.Id,
Name: "default",
Key: random.GenerateKey(),
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: -1,
RemainQuota: -1,
UnlimitedQuota: true,
}
result.Error = cleanToken.Insert()
if result.Error != nil {
// do not block
logger.SysError(fmt.Sprintf("create default token for user %d failed: %s", user.Id, result.Error.Error()))
}
return nil
}
@@ -137,9 +169,9 @@ func (user *User) Update(updatePassword bool) error {
return err
}
}
if user.Status == common.UserStatusDisabled {
if user.Status == UserStatusDisabled {
blacklist.BanUser(user.Id)
} else if user.Status == common.UserStatusEnabled {
} else if user.Status == UserStatusEnabled {
blacklist.UnbanUser(user.Id)
}
err = DB.Model(user).Updates(user).Error
@@ -151,8 +183,8 @@ func (user *User) Delete() error {
return errors.New("id 为空!")
}
blacklist.BanUser(user.Id)
user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID())
user.Status = common.UserStatusDeleted
user.Username = fmt.Sprintf("deleted_%s", random.GetUUID())
user.Status = UserStatusDeleted
err := DB.Model(user).Updates(user).Error
return err
}
@@ -176,7 +208,7 @@ func (user *User) ValidateAndFill() (err error) {
}
}
okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled {
if !okay || user.Status != UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")
}
return nil
@@ -206,6 +238,22 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
func (user *User) FillUserByLarkId() error {
if user.LarkId == "" {
return errors.New("lark id 为空!")
}
DB.Where(User{LarkId: user.LarkId}).First(user)
return nil
}
func (user *User) FillUserByOidcId() error {
if user.OidcId == "" {
return errors.New("oidc id 为空!")
}
DB.Where(User{OidcId: user.OidcId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -234,6 +282,14 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsLarkIdAlreadyTaken(githubId string) bool {
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsOidcIdAlreadyTaken(oidcId string) bool {
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
@@ -260,7 +316,7 @@ func IsAdmin(userId int) bool {
logger.SysError("no such user " + err.Error())
return false
}
return user.Role >= common.RoleAdminUser
return user.Role >= RoleAdminUser
}
func IsUserEnabled(userId int) (bool, error) {
@@ -272,7 +328,7 @@ func IsUserEnabled(userId int) (bool, error) {
if err != nil {
return false, err
}
return user.Status == common.UserStatusEnabled, nil
return user.Status == UserStatusEnabled, nil
}
func ValidateAccessToken(token string) (user *User) {
@@ -345,7 +401,7 @@ func decreaseUserQuota(id int, quota int64) (err error) {
}
func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email)
return email
}