mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-13 03:43:44 +08:00
refactor: refactor relay part
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/common/logger"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -42,7 +43,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set token error: " + err.Error())
|
||||
logger.SysError("Redis set token error: " + err.Error())
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
@@ -62,7 +63,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user group error: " + err.Error())
|
||||
logger.SysError("Redis set user group error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return group, err
|
||||
@@ -80,7 +81,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user quota error: " + err.Error())
|
||||
logger.SysError("Redis set user quota error: " + err.Error())
|
||||
}
|
||||
return quota, err
|
||||
}
|
||||
@@ -127,7 +128,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user enabled error: " + err.Error())
|
||||
logger.SysError("Redis set user enabled error: " + err.Error())
|
||||
}
|
||||
return userEnabled, err
|
||||
}
|
||||
@@ -178,13 +179,13 @@ func InitChannelCache() {
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
logger.SysLog("channels synced from database")
|
||||
}
|
||||
|
||||
func SyncChannelCache(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
common.SysLog("syncing channels from database")
|
||||
logger.SysLog("syncing channels from database")
|
||||
InitChannelCache()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"one-api/common/logger"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
@@ -86,11 +89,17 @@ func (channel *Channel) GetBaseURL() string {
|
||||
return *channel.BaseURL
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModelMapping() string {
|
||||
if channel.ModelMapping == nil {
|
||||
return ""
|
||||
func (channel *Channel) GetModelMapping() map[string]string {
|
||||
if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" {
|
||||
return nil
|
||||
}
|
||||
return *channel.ModelMapping
|
||||
modelMapping := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping)
|
||||
if err != nil {
|
||||
logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error()))
|
||||
return nil
|
||||
}
|
||||
return modelMapping
|
||||
}
|
||||
|
||||
func (channel *Channel) Insert() error {
|
||||
@@ -120,7 +129,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
||||
ResponseTime: int(responseTime),
|
||||
}).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update response time: " + err.Error())
|
||||
logger.SysError("failed to update response time: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,7 +139,7 @@ func (channel *Channel) UpdateBalance(balance float64) {
|
||||
Balance: balance,
|
||||
}).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update balance: " + err.Error())
|
||||
logger.SysError("failed to update balance: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,11 +156,11 @@ func (channel *Channel) Delete() error {
|
||||
func UpdateChannelStatusById(id int, status int) {
|
||||
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
||||
if err != nil {
|
||||
common.SysError("failed to update ability status: " + err.Error())
|
||||
logger.SysError("failed to update ability status: " + err.Error())
|
||||
}
|
||||
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update channel status: " + err.Error())
|
||||
logger.SysError("failed to update channel status: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,7 +175,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
|
||||
func updateChannelUsedQuota(id int, quota int) {
|
||||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update channel used quota: " + err.Error())
|
||||
logger.SysError("failed to update channel used quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/common/logger"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -44,12 +45,12 @@ func RecordLog(userId int, logType int, content string) {
|
||||
}
|
||||
err := DB.Create(log).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to record log: " + err.Error())
|
||||
logger.SysError("failed to record log: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
@@ -68,7 +69,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
}
|
||||
err := DB.Create(log).Error
|
||||
if err != nil {
|
||||
common.LogError(ctx, "failed to record log: "+err.Error())
|
||||
logger.Error(ctx, "failed to record log: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"one-api/common/logger"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -18,7 +19,7 @@ func createRootAccountIfNeed() error {
|
||||
var user User
|
||||
//if user.Status != util.UserStatusEnabled {
|
||||
if err := DB.First(&user).Error; err != nil {
|
||||
common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
||||
logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
||||
hashedPassword, err := common.Password2Hash("123456")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -42,7 +43,7 @@ func chooseDB() (*gorm.DB, error) {
|
||||
dsn := os.Getenv("SQL_DSN")
|
||||
if strings.HasPrefix(dsn, "postgres://") {
|
||||
// Use PostgreSQL
|
||||
common.SysLog("using PostgreSQL as database")
|
||||
logger.SysLog("using PostgreSQL as database")
|
||||
common.UsingPostgreSQL = true
|
||||
return gorm.Open(postgres.New(postgres.Config{
|
||||
DSN: dsn,
|
||||
@@ -52,13 +53,13 @@ func chooseDB() (*gorm.DB, error) {
|
||||
})
|
||||
}
|
||||
// Use MySQL
|
||||
common.SysLog("using MySQL as database")
|
||||
logger.SysLog("using MySQL as database")
|
||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
}
|
||||
// Use SQLite
|
||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
logger.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
common.UsingSQLite = true
|
||||
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
|
||||
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
|
||||
@@ -77,14 +78,14 @@ func InitDB() (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
||||
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
|
||||
sqlDB.SetMaxIdleConns(common.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
|
||||
sqlDB.SetMaxOpenConns(common.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
|
||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
|
||||
|
||||
if !common.IsMasterNode {
|
||||
return nil
|
||||
}
|
||||
common.SysLog("database migration started")
|
||||
logger.SysLog("database migration started")
|
||||
err = db.AutoMigrate(&Channel{})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -113,11 +114,11 @@ func InitDB() (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.SysLog("database migrated")
|
||||
logger.SysLog("database migrated")
|
||||
err = createRootAccountIfNeed()
|
||||
return err
|
||||
} else {
|
||||
common.FatalLog(err)
|
||||
logger.FatalLog(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/common/logger"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -82,7 +83,7 @@ func loadOptionsFromDatabase() {
|
||||
for _, option := range options {
|
||||
err := updateOptionMap(option.Key, option.Value)
|
||||
if err != nil {
|
||||
common.SysError("failed to update option map: " + err.Error())
|
||||
logger.SysError("failed to update option map: " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -90,7 +91,7 @@ func loadOptionsFromDatabase() {
|
||||
func SyncOptions(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
common.SysLog("syncing options from database")
|
||||
logger.SysLog("syncing options from database")
|
||||
loadOptionsFromDatabase()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"one-api/common/logger"
|
||||
)
|
||||
|
||||
type Redemption struct {
|
||||
@@ -75,7 +76,7 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
if err != nil {
|
||||
return 0, errors.New("兑换失败," + err.Error())
|
||||
}
|
||||
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
|
||||
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", logger.LogQuota(redemption.Quota)))
|
||||
return redemption.Quota, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"one-api/common/logger"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
@@ -39,7 +40,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
}
|
||||
token, err = CacheGetTokenByKey(key)
|
||||
if err != nil {
|
||||
common.SysError("CacheGetTokenByKey failed: " + err.Error())
|
||||
logger.SysError("CacheGetTokenByKey failed: " + err.Error())
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("无效的令牌")
|
||||
}
|
||||
@@ -58,7 +59,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
token.Status = common.TokenStatusExpired
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
logger.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New("该令牌已过期")
|
||||
@@ -69,7 +70,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
token.Status = common.TokenStatusExhausted
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
logger.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New("该令牌额度已用尽")
|
||||
@@ -202,7 +203,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
|
||||
go func() {
|
||||
email, err := GetUserEmail(token.UserId)
|
||||
if err != nil {
|
||||
common.SysError("failed to fetch user email: " + err.Error())
|
||||
logger.SysError("failed to fetch user email: " + err.Error())
|
||||
}
|
||||
prompt := "您的额度即将用尽"
|
||||
if noMoreQuota {
|
||||
@@ -213,7 +214,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
|
||||
err = common.SendEmail(prompt, email,
|
||||
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
|
||||
if err != nil {
|
||||
common.SysError("failed to send email" + err.Error())
|
||||
logger.SysError("failed to send email" + err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"one-api/common/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -97,16 +98,16 @@ func (user *User) Insert(inviterId int) error {
|
||||
return result.Error
|
||||
}
|
||||
if common.QuotaForNewUser > 0 {
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
|
||||
}
|
||||
if inviterId != 0 {
|
||||
if common.QuotaForInvitee > 0 {
|
||||
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
|
||||
}
|
||||
if common.QuotaForInviter > 0 {
|
||||
_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
|
||||
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
|
||||
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -232,7 +233,7 @@ func IsAdmin(userId int) bool {
|
||||
var user User
|
||||
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
|
||||
if err != nil {
|
||||
common.SysError("no such user " + err.Error())
|
||||
logger.SysError("no such user " + err.Error())
|
||||
return false
|
||||
}
|
||||
return user.Role >= common.RoleAdminUser
|
||||
@@ -341,7 +342,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
||||
},
|
||||
).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update user used quota and request count: " + err.Error())
|
||||
logger.SysError("failed to update user used quota and request count: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,14 +353,14 @@ func updateUserUsedQuota(id int, quota int) {
|
||||
},
|
||||
).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update user used quota: " + err.Error())
|
||||
logger.SysError("failed to update user used quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func updateUserRequestCount(id int, count int) {
|
||||
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update user request count: " + err.Error())
|
||||
logger.SysError("failed to update user request count: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/common/logger"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -45,7 +46,7 @@ func addNewRecord(type_ int, id int, value int) {
|
||||
}
|
||||
|
||||
func batchUpdate() {
|
||||
common.SysLog("batch update started")
|
||||
logger.SysLog("batch update started")
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
store := batchUpdateStores[i]
|
||||
@@ -57,12 +58,12 @@ func batchUpdate() {
|
||||
case BatchUpdateTypeUserQuota:
|
||||
err := increaseUserQuota(key, value)
|
||||
if err != nil {
|
||||
common.SysError("failed to batch update user quota: " + err.Error())
|
||||
logger.SysError("failed to batch update user quota: " + err.Error())
|
||||
}
|
||||
case BatchUpdateTypeTokenQuota:
|
||||
err := increaseTokenQuota(key, value)
|
||||
if err != nil {
|
||||
common.SysError("failed to batch update token quota: " + err.Error())
|
||||
logger.SysError("failed to batch update token quota: " + err.Error())
|
||||
}
|
||||
case BatchUpdateTypeUsedQuota:
|
||||
updateUserUsedQuota(key, value)
|
||||
@@ -73,5 +74,5 @@ func batchUpdate() {
|
||||
}
|
||||
}
|
||||
}
|
||||
common.SysLog("batch update finished")
|
||||
logger.SysLog("batch update finished")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user