mirror of
https://github.com/songquanpeng/one-api.git
synced 2026-02-17 19:34:26 +08:00
feat: support postgres
This commit is contained in:
@@ -15,8 +15,16 @@ type Ability struct {
|
||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
ability := Ability{}
|
||||
var err error = nil
|
||||
if common.UsingSQLite {
|
||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
|
||||
|
||||
cmd := "`group` = ? and model = ? and enabled = 1"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
// Make cmd compatible with PostgreSQL
|
||||
cmd = "\"group\" = ? and model = ? and enabled = true"
|
||||
}
|
||||
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
|
||||
} else {
|
||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
|
||||
}
|
||||
|
||||
@@ -21,13 +21,19 @@ var (
|
||||
|
||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
var token Token
|
||||
whereItem := "`key` = ?"
|
||||
if common.UsingPostgreSQL {
|
||||
// Make cmd compatible with PostgreSQL
|
||||
whereItem = "\"key\" = ?"
|
||||
}
|
||||
|
||||
if !common.RedisEnabled {
|
||||
err := DB.Where("`key` = ?", key).First(&token).Error
|
||||
err := DB.Where(whereItem, key).First(&token).Error
|
||||
return &token, err
|
||||
}
|
||||
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
err := DB.Where("`key` = ?", key).First(&token).Error
|
||||
err := DB.Where(whereItem, key).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
@@ -37,7 +38,13 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||
}
|
||||
|
||||
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
||||
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
|
||||
whereItem := "id = ? or name LIKE ? or `key` = ?"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
whereItem = "id = ? or name LIKE ? or \"key\" = ?"
|
||||
}
|
||||
|
||||
err = DB.Omit("key").Where(whereItem, keyword, keyword+"%", keyword).Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
@@ -55,7 +62,9 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
func GetRandomChannel() (*Channel, error) {
|
||||
channel := Channel{}
|
||||
var err error = nil
|
||||
if common.UsingSQLite {
|
||||
if common.UsingPostgreSQL {
|
||||
err = DB.Where("status = ? and \"group\" = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
|
||||
} else if common.UsingSQLite {
|
||||
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
|
||||
} else {
|
||||
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"os"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
@@ -40,7 +42,14 @@ func CountTable(tableName string) (num int64) {
|
||||
|
||||
func InitDB() (err error) {
|
||||
var db *gorm.DB
|
||||
if os.Getenv("SQL_DSN") != "" {
|
||||
if os.Getenv("POSTGRES_DSN") != "" {
|
||||
// Use PostgreSQL
|
||||
common.SysLog("using PostgreSQL as database")
|
||||
common.UsingPostgreSQL = true
|
||||
db, err = gorm.Open(postgres.Open(os.Getenv("POSTGRES_DSN")), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
} else if os.Getenv("SQL_DSN") != "" {
|
||||
// Use MySQL
|
||||
common.SysLog("using MySQL as database")
|
||||
db, err = gorm.Open(mysql.Open(os.Getenv("SQL_DSN")), &gorm.Config{
|
||||
|
||||
@@ -3,8 +3,9 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Redemption struct {
|
||||
@@ -51,7 +52,14 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
redemption := &Redemption{}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
|
||||
whereItem := "`key` = ?"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
// Make cmd compatible with PostgreSQL
|
||||
whereItem = "\"key\" = ?"
|
||||
}
|
||||
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(whereItem, key).First(redemption).Error
|
||||
if err != nil {
|
||||
return errors.New("无效的兑换码")
|
||||
}
|
||||
|
||||
@@ -267,7 +267,13 @@ func GetUserEmail(id int) (email string, err error) {
|
||||
}
|
||||
|
||||
func GetUserGroup(id int) (group string, err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
|
||||
selectItem := "`group`"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
selectItem = "\"group\""
|
||||
}
|
||||
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(selectItem).Find(&group).Error
|
||||
return group, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user