From d82bd2035466fdf1af40ac1353a2aa984caccd2c Mon Sep 17 00:00:00 2001 From: liuzhifei <2679431923@qq.com> Date: Tue, 13 Aug 2024 10:29:55 +0800 Subject: [PATCH] support log db --- model/log.go | 26 ++++---- model/main.go | 173 +++++++++++++++++++++++++++++++++++--------------- 2 files changed, 134 insertions(+), 65 deletions(-) diff --git a/model/log.go b/model/log.go index 1076145..79cc71b 100644 --- a/model/log.go +++ b/model/log.go @@ -38,7 +38,7 @@ const ( ) func GetLogByKey(key string) (logs []*Log, err error) { - err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error + err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error return logs, err } @@ -54,7 +54,7 @@ func RecordLog(userId int, logType int, content string) { Type: logType, Content: content, } - err := DB.Create(log).Error + err := LOG_DB.Create(log).Error if err != nil { common.SysError("failed to record log: " + err.Error()) } @@ -84,7 +84,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke IsStream: isStream, Other: otherStr, } - err := DB.Create(log).Error + err := LOG_DB.Create(log).Error if err != nil { common.LogError(ctx, "failed to record log: "+err.Error()) } @@ -98,9 +98,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { - tx = DB + tx = LOG_DB } else { - tx = DB.Where("type = ?", logType) + tx = LOG_DB.Where("type = ?", logType) } if modelName != "" { tx = tx.Where("model_name like ?", modelName) @@ -127,9 +127,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { - tx = DB.Where("user_id = ?", userId) + tx = LOG_DB.Where("user_id = ?", userId) } else { - tx = DB.Where("user_id = ? and type = ?", userId, logType) + tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType) } if modelName != "" { tx = tx.Where("model_name like ?", modelName) @@ -157,12 +157,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int } func SearchAllLogs(keyword string) (logs []*Log, err error) { - err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error + err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error return logs, err } func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { - err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error + err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error return logs, err } @@ -173,10 +173,10 @@ type Stat struct { } func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) { - tx := DB.Table("logs").Select("sum(quota) quota") + tx := LOG_DB.Table("logs").Select("sum(quota) quota") // 为rpm和tpm创建单独的查询 - rpmTpmQuery := DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") + rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") if username != "" { tx = tx.Where("username = ?", username) @@ -215,7 +215,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa } func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { - tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") + tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") if username != "" { tx = tx.Where("username = ?", username) } @@ -236,6 +236,6 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa } func DeleteOldLog(targetTimestamp int64) (int64, error) { - result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) + result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) return result.RowsAffected, result.Error } diff --git a/model/main.go b/model/main.go index a70f21b..01eb6c9 100644 --- a/model/main.go +++ b/model/main.go @@ -15,6 +15,8 @@ import ( var DB *gorm.DB +var LOG_DB *gorm.DB + func createRootAccountIfNeed() error { var user User //if user.Status != common.UserStatusEnabled { @@ -38,9 +40,9 @@ func createRootAccountIfNeed() error { return nil } -func chooseDB() (*gorm.DB, error) { - if os.Getenv("SQL_DSN") != "" { - dsn := os.Getenv("SQL_DSN") +func chooseDB(envName string) (*gorm.DB, error) { + dsn := os.Getenv(envName) + if dsn != "" { if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL common.SysLog("using PostgreSQL as database") @@ -52,6 +54,13 @@ func chooseDB() (*gorm.DB, error) { PrepareStmt: true, // precompile SQL }) } + if strings.HasPrefix(dsn, "local") { + common.SysLog("SQL_DSN not set, using SQLite as database") + common.UsingSQLite = true + return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) + } // Use MySQL common.SysLog("using MySQL as database") // check parseTime @@ -76,7 +85,7 @@ func chooseDB() (*gorm.DB, error) { } func InitDB() (err error) { - db, err := chooseDB() + db, err := chooseDB("SQL_DSN") if err == nil { if common.DebugEnabled { db = db.Debug() @@ -100,52 +109,7 @@ func InitDB() (err error) { // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded //} common.SysLog("database migration started") - err = db.AutoMigrate(&Channel{}) - if err != nil { - return err - } - err = db.AutoMigrate(&Token{}) - if err != nil { - return err - } - err = db.AutoMigrate(&User{}) - if err != nil { - return err - } - err = db.AutoMigrate(&Option{}) - if err != nil { - return err - } - err = db.AutoMigrate(&Redemption{}) - if err != nil { - return err - } - err = db.AutoMigrate(&Ability{}) - if err != nil { - return err - } - err = db.AutoMigrate(&Log{}) - if err != nil { - return err - } - err = db.AutoMigrate(&Midjourney{}) - if err != nil { - return err - } - err = db.AutoMigrate(&TopUp{}) - if err != nil { - return err - } - err = db.AutoMigrate(&QuotaData{}) - if err != nil { - return err - } - err = db.AutoMigrate(&Task{}) - if err != nil { - return err - } - common.SysLog("database migrated") - err = createRootAccountIfNeed() + err = migrateDB() return err } else { common.FatalLog(err) @@ -153,8 +117,103 @@ func InitDB() (err error) { return err } -func CloseDB() error { - sqlDB, err := DB.DB() +func InitLogDB() (err error) { + if os.Getenv("LOG_SQL_DSN") == "" { + LOG_DB = DB + return + } + db, err := chooseDB("LOG_SQL_DSN") + if err == nil { + if common.DebugEnabled { + db = db.Debug() + } + LOG_DB = db + sqlDB, err := LOG_DB.DB() + if err != nil { + return err + } + sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60))) + + if !common.IsMasterNode { + return nil + } + //if common.UsingMySQL { + // _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded + // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded + // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded + // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded + //} + common.SysLog("database migration started") + err = migrateLOGDB() + return err + } else { + common.FatalLog(err) + } + return err +} + +func migrateDB() error { + err := DB.AutoMigrate(&Channel{}) + if err != nil { + return err + } + err = DB.AutoMigrate(&Token{}) + if err != nil { + return err + } + err = DB.AutoMigrate(&User{}) + if err != nil { + return err + } + err = DB.AutoMigrate(&Option{}) + if err != nil { + return err + } + err = DB.AutoMigrate(&Redemption{}) + if err != nil { + return err + } + err = DB.AutoMigrate(&Ability{}) + if err != nil { + return err + } + err = DB.AutoMigrate(&Log{}) + if err != nil { + return err + } + err = DB.AutoMigrate(&Midjourney{}) + if err != nil { + return err + } + err = DB.AutoMigrate(&TopUp{}) + if err != nil { + return err + } + err = DB.AutoMigrate(&QuotaData{}) + if err != nil { + return err + } + err = DB.AutoMigrate(&Task{}) + if err != nil { + return err + } + common.SysLog("database migrated") + err = createRootAccountIfNeed() + return err +} + +func migrateLOGDB() error { + var err error + if err = LOG_DB.AutoMigrate(&Log{}); err != nil { + return err + } + return nil +} + +func closeDB(db *gorm.DB) error { + sqlDB, err := db.DB() if err != nil { return err } @@ -162,6 +221,16 @@ func CloseDB() error { return err } +func CloseDB() error { + if LOG_DB != DB { + err := closeDB(LOG_DB) + if err != nil { + return err + } + } + return closeDB(DB) +} + var ( lastPingTime time.Time pingMutex sync.Mutex