feat: support postgres

This commit is contained in:
ckt1031
2023-07-27 14:47:45 +08:00
parent c134604cee
commit bba49c959e
9 changed files with 79 additions and 15 deletions

View File

@@ -15,8 +15,16 @@ type Ability struct {
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
cmd := "`group` = ? and model = ? and enabled = 1"
if common.UsingPostgreSQL {
// Make cmd compatible with PostgreSQL
cmd = "\"group\" = ? and model = ? and enabled = true"
}
if common.UsingSQLite || common.UsingPostgreSQL {
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
} else {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
}

View File

@@ -21,13 +21,19 @@ var (
func CacheGetTokenByKey(key string) (*Token, error) {
var token Token
whereItem := "`key` = ?"
if common.UsingPostgreSQL {
// Make cmd compatible with PostgreSQL
whereItem = "\"key\" = ?"
}
if !common.RedisEnabled {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(whereItem, key).First(&token).Error
return &token, err
}
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(whereItem, key).First(&token).Error
if err != nil {
return nil, err
}

View File

@@ -1,8 +1,9 @@
package model
import (
"gorm.io/gorm"
"one-api/common"
"gorm.io/gorm"
)
type Channel struct {
@@ -37,7 +38,13 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
}
func SearchChannels(keyword string) (channels []*Channel, err error) {
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
whereItem := "id = ? or name LIKE ? or `key` = ?"
if common.UsingPostgreSQL {
whereItem = "id = ? or name LIKE ? or \"key\" = ?"
}
err = DB.Omit("key").Where(whereItem, keyword, keyword+"%", keyword).Find(&channels).Error
return channels, err
}
@@ -55,7 +62,9 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
func GetRandomChannel() (*Channel, error) {
channel := Channel{}
var err error = nil
if common.UsingSQLite {
if common.UsingPostgreSQL {
err = DB.Where("status = ? and \"group\" = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else if common.UsingSQLite {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error

View File

@@ -1,11 +1,13 @@
package model
import (
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"one-api/common"
"os"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
var DB *gorm.DB
@@ -40,7 +42,14 @@ func CountTable(tableName string) (num int64) {
func InitDB() (err error) {
var db *gorm.DB
if os.Getenv("SQL_DSN") != "" {
if os.Getenv("POSTGRES_DSN") != "" {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
db, err = gorm.Open(postgres.Open(os.Getenv("POSTGRES_DSN")), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
} else if os.Getenv("SQL_DSN") != "" {
// Use MySQL
common.SysLog("using MySQL as database")
db, err = gorm.Open(mysql.Open(os.Getenv("SQL_DSN")), &gorm.Config{

View File

@@ -3,8 +3,9 @@ package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
"gorm.io/gorm"
)
type Redemption struct {
@@ -51,7 +52,14 @@ func Redeem(key string, userId int) (quota int, err error) {
redemption := &Redemption{}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
whereItem := "`key` = ?"
if common.UsingPostgreSQL {
// Make cmd compatible with PostgreSQL
whereItem = "\"key\" = ?"
}
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(whereItem, key).First(redemption).Error
if err != nil {
return errors.New("无效的兑换码")
}

View File

@@ -267,7 +267,13 @@ func GetUserEmail(id int) (email string, err error) {
}
func GetUserGroup(id int) (group string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
selectItem := "`group`"
if common.UsingPostgreSQL {
selectItem = "\"group\""
}
err = DB.Model(&User{}).Where("id = ?", id).Select(selectItem).Find(&group).Error
return group, err
}