Compare commits

...

20 Commits

Author SHA1 Message Date
CaIon
2d849e0dd6 fix: 307本地重试 2024-04-08 14:10:09 +08:00
CaIon
60d7ed3fb5 fix: distributor panic 2024-04-08 13:48:36 +08:00
CaIon
a7cfce24d0 feat: automatically ban channels that exceeded quota 2024-04-07 22:22:27 +08:00
CaIon
34bf8f8945 fix: select channel 2024-04-07 22:08:11 +08:00
CaIon
2d1d1b4631 update go-epay 2024-04-07 14:42:03 +08:00
CaIon
fbdb17022c update README.md 2024-04-06 20:49:34 +08:00
CaIon
497cc32634 update README.md 2024-04-06 20:47:16 +08:00
CaIon
462c328d4b feat: 支持未开启缓存下本地重试 2024-04-06 20:45:18 +08:00
CaIon
257cfc2390 fix: email whitelist check 2024-04-06 17:50:47 +08:00
CaIon
fed1a1d6a3 feat: 超时状态码不重试 2024-04-04 21:21:44 +08:00
CaIon
fc9f8c8e8a fix: add group tag 'unknown' 2024-04-04 21:20:54 +08:00
CaIon
f3f36dafbd chore: 优化按次计费的数据库查询次数 2024-04-04 20:10:30 +08:00
CaIon
aaf3a1f07b fix: GetRandomSatisfiedChannel 2024-04-04 19:37:33 +08:00
CaIon
c040fa229d fix bug 2024-04-04 19:18:00 +08:00
CaIon
1cd1e54be4 feat: 钱包兼容非货币形式显示额度 2024-04-04 18:21:23 +08:00
CaIon
3db64afc7f feat: 钱包兼容非货币形式显示额度 2024-04-04 18:20:38 +08:00
CaIon
bc9cfa5da0 feat: 钱包兼容非货币形式显示额度 2024-04-04 18:18:18 +08:00
CaIon
660b9b3c99 feat: able to set default test model (#138) 2024-04-04 17:29:25 +08:00
CaIon
cdf2087952 update README.md 2024-04-04 16:48:28 +08:00
CaIon
4b60528c5f feat: 本地重试 2024-04-04 16:35:44 +08:00
27 changed files with 530 additions and 238 deletions

View File

@@ -59,6 +59,16 @@
您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。 您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true``false`,未设置则默认为 `false`
+ 例子:`MEMORY_CACHE_ENABLED=true`
## 部署 ## 部署
### 基于 Docker 进行部署 ### 基于 Docker 进行部署
```shell ```shell

View File

@@ -55,7 +55,8 @@ var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false var TurnstileCheckEnabled = false
var RegisterEnabled = true var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
var EmailDomainWhitelist = []string{ var EmailDomainWhitelist = []string{
"gmail.com", "gmail.com",
"163.com", "163.com",
@@ -111,7 +112,7 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
var BatchUpdateEnabled = false var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)

View File

@@ -5,18 +5,37 @@ import (
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"strings"
) )
func UnmarshalBodyReusable(c *gin.Context, v any) error { const KeyRequestBody = "key_request_body"
func GetRequestBody(c *gin.Context) ([]byte, error) {
requestBody, _ := c.Get(KeyRequestBody)
if requestBody != nil {
return requestBody.([]byte), nil
}
requestBody, err := io.ReadAll(c.Request.Body) requestBody, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
return err return nil, err
} }
err = c.Request.Body.Close() _ = c.Request.Body.Close()
c.Set(KeyRequestBody, requestBody)
return requestBody.([]byte), nil
}
func UnmarshalBodyReusable(c *gin.Context, v any) error {
requestBody, err := GetRequestBody(c)
if err != nil { if err != nil {
return err return err
} }
err = json.Unmarshal(requestBody, &v) contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
}
if err != nil { if err != nil {
return err return err
} }

View File

@@ -236,3 +236,8 @@ func StringToByteSlice(s string) []byte {
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2)) return *(*[]byte)(unsafe.Pointer(&tmp2))
} }
func RandomSleep() {
// Sleep for 0-3000 ms
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
}

View File

@@ -27,7 +27,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
if channel.Type == common.ChannelTypeMidjourney { if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil return errors.New("midjourney channel test is not supported"), nil
} }
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{ c.Request = &http.Request{
@@ -60,12 +59,16 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
if testModel == "" { if testModel == "" {
testModel = adaptor.GetModelList()[0] if channel.TestModel != nil && *channel.TestModel != "" {
meta.UpstreamModelName = testModel testModel = *channel.TestModel
} else {
testModel = adaptor.GetModelList()[0]
}
} }
request := buildTestRequest() request := buildTestRequest()
request.Model = testModel request.Model = testModel
meta.UpstreamModelName = testModel meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
adaptor.Init(meta, *request) adaptor.Init(meta, *request)

View File

@@ -120,12 +120,17 @@ func SendEmailVerification(c *gin.Context) {
}) })
return return
} }
parts := strings.Split(email, "@")
if len(parts) != 2 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的邮箱地址",
})
return
}
localPart := parts[0]
domainPart := parts[1]
if common.EmailDomainRestrictionEnabled { if common.EmailDomainRestrictionEnabled {
parts := strings.Split(email, "@")
localPart := parts[0]
domainPart := parts[1]
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Count(localPart, ".") > 1
allowed := false allowed := false
for _, domain := range common.EmailDomainWhitelist { for _, domain := range common.EmailDomainWhitelist {
if domainPart == domain { if domainPart == domain {
@@ -133,13 +138,7 @@ func SendEmailVerification(c *gin.Context) {
break break
} }
} }
if allowed && !containsSpecialSymbols { if !allowed {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Your email address is allowed.",
})
return
} else {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.", "message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.",
@@ -147,6 +146,17 @@ func SendEmailVerification(c *gin.Context) {
return return
} }
} }
if common.EmailAliasRestrictionEnabled {
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Count(localPart, ".") > 1
if containsSpecialSymbols {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。",
})
return
}
}
if model.IsEmailAlreadyTaken(email) { if model.IsEmailAlreadyTaken(email) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@@ -1,21 +1,23 @@
package controller package controller
import ( import (
"bytes"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io"
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay" "one-api/relay"
"one-api/relay/constant" "one-api/relay/constant"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"strconv"
) )
func Relay(c *gin.Context) { func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
var err *dto.OpenAIErrorWithStatusCode var err *dto.OpenAIErrorWithStatusCode
switch relayMode { switch relayMode {
case relayconstant.RelayModeImagesGenerations: case relayconstant.RelayModeImagesGenerations:
@@ -29,33 +31,95 @@ func Relay(c *gin.Context) {
default: default:
err = relay.TextHelper(c) err = relay.TextHelper(c)
} }
if err != nil { return err
requestId := c.GetString(common.RequestIdKey) }
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr) func Relay(c *gin.Context) {
if retryTimesStr == "" { relayMode := constant.Path2RelayMode(c.Request.URL.Path)
retryTimes = common.RetryTimes retryTimes := common.RetryTimes
requestId := c.GetString(common.RequestIdKey)
channelId := c.GetInt("channel_id")
group := c.GetString("group")
originalModel := c.GetString("original_model")
openaiErr := relayHandler(c, relayMode)
retryLogStr := fmt.Sprintf("重试:%d", channelId)
if openaiErr != nil {
go processChannelError(c, channelId, openaiErr)
} else {
retryTimes = 0
}
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
} }
if retryTimes > 0 { channelId = channel.Id
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) retryLogStr += fmt.Sprintf("->%d", channel.Id)
} else { common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
if err.StatusCode == http.StatusTooManyRequests { middleware.SetupContextForSelectedChannel(c, channel, originalModel)
//err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
} requestBody, err := common.GetRequestBody(c)
err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
c.JSON(err.StatusCode, gin.H{ openaiErr = relayHandler(c, relayMode)
"error": err.Error, if openaiErr != nil {
}) go processChannelError(c, channelId, openaiErr)
} }
channelId := c.GetInt("channel_id") }
autoBan := c.GetBool("auto_ban") common.LogInfo(c.Request.Context(), retryLogStr)
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors if openaiErr != nil {
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan { if openaiErr.StatusCode == http.StatusTooManyRequests {
channelId := c.GetInt("channel_id") openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Error.Message)
} }
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}
func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
if openaiErr == nil {
return false
}
if retryTimes <= 0 {
return false
}
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if openaiErr.StatusCode == http.StatusTooManyRequests {
return true
}
if openaiErr.StatusCode == 307 {
return true
}
if openaiErr.StatusCode/100 == 5 {
// 超时不重试
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
return false
}
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
return false
}
if openaiErr.LocalError {
return false
}
if openaiErr.StatusCode/100 == 2 {
return false
}
return true
}
func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Error.Message)
} }
} }

View File

@@ -2,9 +2,10 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/Calcium-Ion/go-epay/epay"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/samber/lo" "github.com/samber/lo"
epay "github.com/star-horizon/go-epay"
"log" "log"
"net/url" "net/url"
"one-api/common" "one-api/common"
@@ -30,7 +31,7 @@ func GetEpayClient() *epay.Client {
if common.PayAddress == "" || common.EpayId == "" || common.EpayKey == "" { if common.PayAddress == "" || common.EpayId == "" || common.EpayKey == "" {
return nil return nil
} }
withUrl, err := epay.NewClientWithUrl(&epay.Config{ withUrl, err := epay.NewClient(&epay.Config{
PartnerID: common.EpayId, PartnerID: common.EpayId,
Key: common.EpayKey, Key: common.EpayKey,
}, common.PayAddress) }, common.PayAddress)
@@ -40,31 +41,46 @@ func GetEpayClient() *epay.Client {
return withUrl return withUrl
} }
func GetAmount(count float64, user model.User) float64 { func getPayMoney(amount float64, user model.User) float64 {
if !common.DisplayInCurrencyEnabled {
amount = amount / common.QuotaPerUnit
}
// 别问为什么用float64问就是这么点钱没必要 // 别问为什么用float64问就是这么点钱没必要
topupGroupRatio := common.GetTopupGroupRatio(user.Group) topupGroupRatio := common.GetTopupGroupRatio(user.Group)
if topupGroupRatio == 0 { if topupGroupRatio == 0 {
topupGroupRatio = 1 topupGroupRatio = 1
} }
amount := count * common.Price * topupGroupRatio payMoney := amount * common.Price * topupGroupRatio
return amount return payMoney
}
func getMinTopup() int {
minTopup := common.MinTopUp
if !common.DisplayInCurrencyEnabled {
minTopup = minTopup * int(common.QuotaPerUnit)
}
return minTopup
} }
func RequestEpay(c *gin.Context) { func RequestEpay(c *gin.Context) {
var req EpayRequest var req EpayRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": err.Error(), "data": 10}) c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return return
} }
if req.Amount < common.MinTopUp { if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10}) c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
payMoney := GetAmount(float64(req.Amount), *user) payMoney := getPayMoney(float64(req.Amount), *user)
if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
var payType epay.PurchaseType var payType epay.PurchaseType
if req.PaymentMethod == "zfb" { if req.PaymentMethod == "zfb" {
@@ -96,9 +112,13 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return return
} }
amount := req.Amount
if !common.DisplayInCurrencyEnabled {
amount = amount / int(common.QuotaPerUnit)
}
topUp := &model.TopUp{ topUp := &model.TopUp{
UserId: id, UserId: id,
Amount: req.Amount, Amount: amount,
Money: payMoney, Money: payMoney,
TradeNo: "A" + tradeNo, TradeNo: "A" + tradeNo,
CreateTime: time.Now().Unix(), CreateTime: time.Now().Unix(),
@@ -186,13 +206,13 @@ func EpayNotify(c *gin.Context) {
} }
//user, _ := model.GetUserById(topUp.UserId, false) //user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000 //user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000) err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
if err != nil { if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp) log.Printf("易支付回调更新用户失败: %v", topUp)
return return
} }
log.Printf("易支付回调更新用户成功 %v", topUp) log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(topUp.Amount*500000), topUp.Money)) model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
} }
} else { } else {
log.Printf("易支付异常回调: %v", verifyInfo) log.Printf("易支付异常回调: %v", verifyInfo)
@@ -206,12 +226,17 @@ func RequestAmount(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return return
} }
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)}) if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
payMoney := GetAmount(float64(req.Amount), *user) payMoney := getPayMoney(float64(req.Amount), *user)
if payMoney <= 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)}) c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
} }

View File

@@ -7,6 +7,7 @@ import (
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"sync"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -789,7 +790,11 @@ type topUpRequest struct {
Key string `json:"key"` Key string `json:"key"`
} }
var lock = sync.Mutex{}
func TopUp(c *gin.Context) { func TopUp(c *gin.Context) {
lock.Lock()
defer lock.Unlock()
req := topUpRequest{} req := topUpRequest{}
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {

View File

@@ -10,6 +10,7 @@ type OpenAIError struct {
type OpenAIErrorWithStatusCode struct { type OpenAIErrorWithStatusCode struct {
Error OpenAIError `json:"error"` Error OpenAIError `json:"error"`
StatusCode int `json:"status_code"` StatusCode int `json:"status_code"`
LocalError bool
} }
type GeneralErrorResponse struct { type GeneralErrorResponse struct {

8
go.mod
View File

@@ -4,6 +4,7 @@ module one-api
go 1.18 go 1.18
require ( require (
github.com/Calcium-Ion/go-epay v0.0.2
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/gzip v0.0.6
@@ -16,9 +17,8 @@ require (
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.6 github.com/pkoukk/tiktoken-go v0.1.6
github.com/samber/lo v1.38.1 github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible github.com/shirou/gopsutil v3.21.11+incompatible
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
golang.org/x/crypto v0.21.0 golang.org/x/crypto v0.21.0
golang.org/x/image v0.15.0 golang.org/x/image v0.15.0
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3
@@ -65,9 +65,9 @@ require (
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.21.0 // indirect golang.org/x/net v0.21.0 // indirect
golang.org/x/sync v0.1.0 // indirect golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.18.0 // indirect golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect

10
go.sum
View File

@@ -1,3 +1,7 @@
github.com/Calcium-Ion/go-epay v0.0.1 h1:cRCvwNTkPmmLM5od0p4w0cTcYcAPaAVLYr41ujseDcc=
github.com/Calcium-Ion/go-epay v0.0.1/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc=
github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
@@ -137,6 +141,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI= github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI=
@@ -175,6 +181,8 @@ golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -182,6 +190,8 @@ golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@@ -127,7 +127,7 @@ func TokenAuth() func(c *gin.Context) {
} }
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("specific_channel_id", parts[1])
} else { } else {
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return return

View File

@@ -23,7 +23,10 @@ func Distribute() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
userId := c.GetInt("id") userId := c.GetInt("id")
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("channelId") channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))
if err != nil { if err != nil {
@@ -40,72 +43,7 @@ func Distribute() func(c *gin.Context) {
return return
} }
} else { } else {
shouldSelectChannel := true
// Select a channel for the user // Select a channel for the user
var modelRequest ModelRequest
var err error
if strings.Contains(c.Request.URL.Path, "/mj/") {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
shouldSelectChannel = false
} else {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return
} else {
// task fetch, task fetch by condition, notify
shouldSelectChannel = false
}
}
modelRequest.Model = midjourneyModel
}
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
}
}
}
// check token model mapping // check token model mapping
modelLimitEnable := c.GetBool("token_model_limit_enabled") modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable { if modelLimitEnable {
@@ -128,10 +66,8 @@ func Distribute() func(c *gin.Context) {
} }
} }
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
if shouldSelectChannel { if shouldSelectChannel {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
if err != nil { if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题 // 如果错误,但是渠道不为空,说明是数据库一致性问题
@@ -147,36 +83,113 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return return
} }
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if nil != channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}
c.Set("auto_ban", ban)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
} }
} }
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
c.Next() c.Next()
} }
} }
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
var modelRequest ModelRequest
shouldSelectChannel := true
var err error
if strings.Contains(c.Request.URL.Path, "/mj/") {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
shouldSelectChannel = false
} else {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return nil, false, err
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return nil, false, fmt.Errorf(mjErr.Description)
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else {
// task fetch, task fetch by condition, notify
shouldSelectChannel = false
}
}
modelRequest.Model = midjourneyModel
}
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, err
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
}
}
}
return &modelRequest, shouldSelectChannel, nil
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set("original_model", modelName) // for retry
if channel == nil {
return
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}
c.Set("auto_ban", ban)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
}

View File

@@ -3,6 +3,7 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"gorm.io/gorm"
"one-api/common" "one-api/common"
"strings" "strings"
) )
@@ -27,8 +28,7 @@ func GetGroupModels(group string) []string {
return models return models
} }
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func getPriority(group string, model string, retry int) (int, error) {
var abilities []Ability
groupCol := "`group`" groupCol := "`group`"
trueVal := "1" trueVal := "1"
if common.UsingPostgreSQL { if common.UsingPostgreSQL {
@@ -36,9 +36,55 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
trueVal = "true" trueVal = "true"
} }
var err error = nil var priorities []int
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
if err != nil {
// 处理错误
return 0, err
}
// 确定要使用的优先级
var priorityToUse int
if retry >= len(priorities) {
// 如果重试次数大于优先级数,则使用最小的优先级
priorityToUse = priorities[len(priorities)-1]
} else {
priorityToUse = priorities[retry]
}
return priorityToUse, nil
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if retry != 0 {
priority, err := getPriority(group, model, retry)
if err != nil {
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
} else {
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
}
}
return channelQuery
}
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
var abilities []Ability
var err error = nil
channelQuery := getChannelQuery(group, model, retry)
if common.UsingSQLite || common.UsingPostgreSQL { if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("weight DESC").Find(&abilities).Error err = channelQuery.Order("weight DESC").Find(&abilities).Error
} else { } else {
@@ -52,21 +98,16 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
// Randomly choose one // Randomly choose one
weightSum := uint(0) weightSum := uint(0)
for _, ability_ := range abilities { for _, ability_ := range abilities {
weightSum += ability_.Weight weightSum += ability_.Weight + 10
} }
if weightSum == 0 { // Randomly choose one
// All weight is 0, randomly choose one weight := common.GetRandomInt(int(weightSum))
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId for _, ability_ := range abilities {
} else { weight -= int(ability_.Weight) + 10
// Randomly choose one //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
weight := common.GetRandomInt(int(weightSum)) if weight <= 0 {
for _, ability_ := range abilities { channel.Id = ability_.ChannelId
weight -= int(ability_.Weight) break
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
if weight <= 0 {
channel.Id = ability_.ChannelId
break
}
} }
} }
} else { } else {

View File

@@ -265,14 +265,14 @@ func SyncChannelCache(frequency int) {
} }
} }
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") { if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*" model = "gpt-4-gizmo-*"
} }
// if memory cache is disabled, get channel directly from database // if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled { if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model) return GetRandomSatisfiedChannel(group, model, retry)
} }
channelSyncLock.RLock() channelSyncLock.RLock()
defer channelSyncLock.RUnlock() defer channelSyncLock.RUnlock()
@@ -280,15 +280,27 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 { if len(channels) == 0 {
return nil, errors.New("channel not found") return nil, errors.New("channel not found")
} }
endIdx := len(channels)
// choose by priority uniquePriorities := make(map[int]bool)
firstChannel := channels[0] for _, channel := range channels {
if firstChannel.GetPriority() > 0 { uniquePriorities[int(channel.GetPriority())] = true
for i := range channels { }
if channels[i].GetPriority() != firstChannel.GetPriority() { var sortedUniquePriorities []int
endIdx = i for priority := range uniquePriorities {
break sortedUniquePriorities = append(sortedUniquePriorities, priority)
} }
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
if retry >= len(uniquePriorities) {
retry = len(uniquePriorities) - 1
}
targetPriority := int64(sortedUniquePriorities[retry])
// get the priority for the given retry number
var targetChannels []*Channel
for _, channel := range channels {
if channel.GetPriority() == targetPriority {
targetChannels = append(targetChannels, channel)
} }
} }
@@ -296,20 +308,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
smoothingFactor := 10 smoothingFactor := 10
// Calculate the total weight of all channels up to endIdx // Calculate the total weight of all channels up to endIdx
totalWeight := 0 totalWeight := 0
for _, channel := range channels[:endIdx] { for _, channel := range targetChannels {
totalWeight += channel.GetWeight() + smoothingFactor totalWeight += channel.GetWeight() + smoothingFactor
} }
//if totalWeight == 0 {
// // If all weights are 0, select a channel randomly
// return channels[rand.Intn(endIdx)], nil
//}
// Generate a random value in the range [0, totalWeight) // Generate a random value in the range [0, totalWeight)
randomWeight := rand.Intn(totalWeight) randomWeight := rand.Intn(totalWeight)
// Find a channel based on its weight // Find a channel based on its weight
for _, channel := range channels[:endIdx] { for _, channel := range targetChannels {
randomWeight -= channel.GetWeight() + smoothingFactor randomWeight -= channel.GetWeight() + smoothingFactor
if randomWeight < 0 { if randomWeight < 0 {
return channel, nil return channel, nil

View File

@@ -10,6 +10,7 @@ type Channel struct {
Type int `json:"type" gorm:"default:0"` Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null"` Key string `json:"key" gorm:"not null"`
OpenAIOrganization *string `json:"openai_organization"` OpenAIOrganization *string `json:"openai_organization"`
TestModel *string `json:"test_model"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"` Name string `json:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"` Weight *uint `json:"weight" gorm:"default:0"`

View File

@@ -44,6 +44,7 @@ func InitOptionMap() {
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled) common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = "" common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = "" common.OptionMap["SMTPFrom"] = ""
@@ -174,6 +175,8 @@ func updateOptionMap(key string, value string) (err error) {
common.RegisterEnabled = boolValue common.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue common.EmailDomainRestrictionEnabled = boolValue
case "EmailAliasRestrictionEnabled":
common.EmailAliasRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled": case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue common.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled": case "AutomaticEnableChannelEnabled":

View File

@@ -56,7 +56,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if common.UsingPostgreSQL { if common.UsingPostgreSQL {
keyCol = `"key"` keyCol = `"key"`
} }
common.RandomSleep()
err = DB.Transaction(func(tx *gorm.DB) error { err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil { if err != nil {

View File

@@ -31,6 +31,7 @@ type RelayInfo struct {
func GenRelayInfo(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")

View File

@@ -72,7 +72,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
textRequest, err := getAndValidateTextRequest(c, relayInfo) textRequest, err := getAndValidateTextRequest(c, relayInfo)
if err != nil { if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
} }
// map model name // map model name
@@ -82,7 +82,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
modelMap := make(map[string]string) modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap) err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
} }
if modelMap[textRequest.Model] != "" { if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model] textRequest.Model = modelMap[textRequest.Model]
@@ -103,7 +103,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
// count messages token error 计算promptTokens错误 // count messages token error 计算promptTokens错误
if err != nil { if err != nil {
if sensitiveTrigger { if sensitiveTrigger {
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
} }
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
} }
@@ -162,7 +162,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode) return service.RelayErrorHandler(resp)
} }
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
@@ -200,14 +200,14 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) { func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
if err != nil { if err != nil {
return 0, 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
if userQuota <= 0 || userQuota-preConsumedQuota < 0 { if userQuota <= 0 || userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil { if err != nil {
return 0, 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
} }
if userQuota > 100*preConsumedQuota { if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足 // 用户额度充足,判断令牌额度是否充足
@@ -229,7 +229,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota) userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
if err != nil { if err != nil {
return 0, 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }
} }
return preConsumedQuota, userQuota, nil return preConsumedQuota, userQuota, nil
@@ -288,11 +288,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
//} //}
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true) if quotaDelta != 0 {
if err != nil { err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
common.LogError(ctx, "error consuming token remain quota: "+err.Error()) if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
} }
err = model.CacheUpdateUserQuota(relayInfo.UserId) err := model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil { if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error()) common.LogError(ctx, "error update user quota cache: "+err.Error())
} }

View File

@@ -6,6 +6,7 @@ import (
"one-api/common" "one-api/common"
relaymodel "one-api/dto" relaymodel "one-api/dto"
"one-api/model" "one-api/model"
"strings"
) )
// disable & notify // disable & notify
@@ -33,7 +34,30 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
if statusCode == http.StatusUnauthorized { if statusCode == http.StatusUnauthorized {
return true return true
} }
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" { switch err.Code {
case "invalid_api_key":
return true
case "account_deactivated":
return true
case "billing_not_active":
return true
}
switch err.Type {
case "insufficient_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
return true
}
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
} else if strings.HasPrefix(err.Message, "You exceeded your current quota") {
return true return true
} }
return false return false

View File

@@ -46,6 +46,12 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError
} }
} }
func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
openaiErr := OpenAIErrorWrapper(err, code, statusCode)
openaiErr.LocalError = true
return openaiErr
}
func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) { func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,

View File

@@ -42,6 +42,7 @@ const SystemSetting = () => {
TurnstileSecretKey: '', TurnstileSecretKey: '',
RegisterEnabled: '', RegisterEnabled: '',
EmailDomainRestrictionEnabled: '', EmailDomainRestrictionEnabled: '',
EmailAliasRestrictionEnabled: '',
SMTPSSLEnabled: '', SMTPSSLEnabled: '',
EmailDomainWhitelist: [], EmailDomainWhitelist: [],
// telegram login // telegram login
@@ -99,6 +100,7 @@ const SystemSetting = () => {
case 'TelegramOAuthEnabled': case 'TelegramOAuthEnabled':
case 'TurnstileCheckEnabled': case 'TurnstileCheckEnabled':
case 'EmailDomainRestrictionEnabled': case 'EmailDomainRestrictionEnabled':
case 'EmailAliasRestrictionEnabled':
case 'SMTPSSLEnabled': case 'SMTPSSLEnabled':
case 'RegisterEnabled': case 'RegisterEnabled':
value = inputs[key] === 'true' ? 'false' : 'true'; value = inputs[key] === 'true' ? 'false' : 'true';
@@ -362,7 +364,7 @@ const SystemSetting = () => {
onChange={handleInputChange} onChange={handleInputChange}
/> />
<Form.Input <Form.Input
label='最低充值数量' label='最低充值美元数量(以美金为单位,如果使用额度请自行换算!)'
placeholder='例如2就是最低充值2$' placeholder='例如2就是最低充值2$'
value={inputs.MinTopUp} value={inputs.MinTopUp}
name='MinTopUp' name='MinTopUp'
@@ -480,6 +482,14 @@ const SystemSetting = () => {
checked={inputs.EmailDomainRestrictionEnabled === 'true'} checked={inputs.EmailDomainRestrictionEnabled === 'true'}
/> />
</Form.Group> </Form.Group>
<Form.Group widths={3}>
<Form.Checkbox
label='启用邮箱别名限制例如ab.cd@gmail.com'
name='EmailAliasRestrictionEnabled'
onChange={handleInputChange}
checked={inputs.EmailAliasRestrictionEnabled === 'true'}
/>
</Form.Group>
<Form.Group widths={2}> <Form.Group widths={2}>
<Form.Dropdown <Form.Dropdown
label='允许的邮箱域名' label='允许的邮箱域名'

View File

@@ -15,14 +15,18 @@ export function renderText(text, limit) {
*/ */
export function renderGroup(group) { export function renderGroup(group) {
if (group === '') { if (group === '') {
return <Tag size='large' key='default'>default</Tag>; return (
<Tag size='large' key='default'>
unknown
</Tag>
);
} }
const tagColors = { const tagColors = {
'vip': 'yellow', vip: 'yellow',
'pro': 'yellow', pro: 'yellow',
'svip': 'red', svip: 'red',
'premium': 'red' premium: 'red',
}; };
const groups = group.split(',').sort(); const groups = group.split(',').sort();
@@ -97,12 +101,29 @@ export function getQuotaPerUnit() {
return quotaPerUnit; return quotaPerUnit;
} }
export function renderUnitWithQuota(quota) {
let quotaPerUnit = localStorage.getItem('quota_per_unit');
quotaPerUnit = parseFloat(quotaPerUnit);
quota = parseFloat(quota);
return quotaPerUnit * quota;
}
export function getQuotaWithUnit(quota, digits = 6) { export function getQuotaWithUnit(quota, digits = 6) {
let quotaPerUnit = localStorage.getItem('quota_per_unit'); let quotaPerUnit = localStorage.getItem('quota_per_unit');
quotaPerUnit = parseFloat(quotaPerUnit); quotaPerUnit = parseFloat(quotaPerUnit);
return (quota / quotaPerUnit).toFixed(digits); return (quota / quotaPerUnit).toFixed(digits);
} }
export function renderQuotaWithAmount(amount) {
let displayInCurrency = localStorage.getItem('display_in_currency');
displayInCurrency = displayInCurrency === 'true';
if (displayInCurrency) {
return '$' + amount;
} else {
return renderUnitWithQuota(amount);
}
}
export function renderQuota(quota, digits = 2) { export function renderQuota(quota, digits = 2) {
let quotaPerUnit = localStorage.getItem('quota_per_unit'); let quotaPerUnit = localStorage.getItem('quota_per_unit');
let displayInCurrency = localStorage.getItem('display_in_currency'); let displayInCurrency = localStorage.getItem('display_in_currency');

View File

@@ -63,6 +63,7 @@ const EditChannel = (props) => {
model_mapping: '', model_mapping: '',
models: [], models: [],
auto_ban: 1, auto_ban: 1,
test_model: '',
groups: ['default'], groups: ['default'],
}; };
const [batch, setBatch] = useState(false); const [batch, setBatch] = useState(false);
@@ -669,6 +670,17 @@ const EditChannel = (props) => {
}} }}
value={inputs.openai_organization} value={inputs.openai_organization}
/> />
<div style={{ marginTop: 10 }}>
<Typography.Text strong>默认测试模型</Typography.Text>
</div>
<Input
name='test_model'
placeholder='不填则为模型列表第一个'
onChange={(value) => {
handleInputChange('test_model', value);
}}
value={inputs.test_model}
/>
<div style={{ marginTop: 10, display: 'flex' }}> <div style={{ marginTop: 10, display: 'flex' }}>
<Space> <Space>
<Checkbox <Checkbox

View File

@@ -1,6 +1,10 @@
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { API, isMobile, showError, showInfo, showSuccess } from '../../helpers'; import { API, isMobile, showError, showInfo, showSuccess } from '../../helpers';
import { renderNumber, renderQuota } from '../../helpers/render'; import {
renderNumber,
renderQuota,
renderQuotaWithAmount,
} from '../../helpers/render';
import { import {
Col, Col,
Layout, Layout,
@@ -12,6 +16,7 @@ import {
Divider, Divider,
Space, Space,
Modal, Modal,
Toast,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import Title from '@douyinfe/semi-ui/lib/es/typography/title'; import Title from '@douyinfe/semi-ui/lib/es/typography/title';
import Text from '@douyinfe/semi-ui/lib/es/typography/text'; import Text from '@douyinfe/semi-ui/lib/es/typography/text';
@@ -20,7 +25,7 @@ import { Link } from 'react-router-dom';
const TopUp = () => { const TopUp = () => {
const [redemptionCode, setRedemptionCode] = useState(''); const [redemptionCode, setRedemptionCode] = useState('');
const [topUpCode, setTopUpCode] = useState(''); const [topUpCode, setTopUpCode] = useState('');
const [topUpCount, setTopUpCount] = useState(10); const [topUpCount, setTopUpCount] = useState(0);
const [minTopupCount, setMinTopUpCount] = useState(1); const [minTopupCount, setMinTopUpCount] = useState(1);
const [amount, setAmount] = useState(0.0); const [amount, setAmount] = useState(0.0);
const [minTopUp, setMinTopUp] = useState(1); const [minTopUp, setMinTopUp] = useState(1);
@@ -76,11 +81,9 @@ const TopUp = () => {
showError('管理员未开启在线充值!'); showError('管理员未开启在线充值!');
return; return;
} }
if (amount === 0) { await getAmount();
await getAmount();
}
if (topUpCount < minTopUp) { if (topUpCount < minTopUp) {
showInfo('充值数量不能小于' + minTopUp); showError('充值数量不能小于' + minTopUp);
return; return;
} }
setPayWay(payment); setPayWay(payment);
@@ -92,7 +95,7 @@ const TopUp = () => {
await getAmount(); await getAmount();
} }
if (topUpCount < minTopUp) { if (topUpCount < minTopUp) {
showInfo('充值数量不能小于' + minTopUp); showError('充值数量不能小于' + minTopUp);
return; return;
} }
setOpen(false); setOpen(false);
@@ -189,7 +192,8 @@ const TopUp = () => {
if (message === 'success') { if (message === 'success') {
setAmount(parseFloat(data)); setAmount(parseFloat(data));
} else { } else {
showError(data); setAmount(0);
Toast.error({ content: '错误:' + data, id: 'getAmount' });
// setTopUpCount(parseInt(res.data.count)); // setTopUpCount(parseInt(res.data.count));
// setAmount(parseInt(data)); // setAmount(parseInt(data));
} }
@@ -222,7 +226,7 @@ const TopUp = () => {
size={'small'} size={'small'}
centered={true} centered={true}
> >
<p>充值数量{topUpCount}$</p> <p>充值数量{topUpCount}</p>
<p>实付金额{renderAmount()}</p> <p>实付金额{renderAmount()}</p>
<p>是否确认充值</p> <p>是否确认充值</p>
</Modal> </Modal>
@@ -274,21 +278,16 @@ const TopUp = () => {
disabled={!enableOnlineTopUp} disabled={!enableOnlineTopUp}
field={'redemptionCount'} field={'redemptionCount'}
label={'实付金额:' + renderAmount()} label={'实付金额:' + renderAmount()}
placeholder={'充值数量,最低' + minTopUp + '$'} placeholder={
'充值数量,最低 ' + renderQuotaWithAmount(minTopUp)
}
name='redemptionCount' name='redemptionCount'
type={'number'} type={'number'}
value={topUpCount} value={topUpCount}
suffix={'$'}
min={minTopUp}
defaultValue={minTopUp}
max={100000}
onChange={async (value) => { onChange={async (value) => {
if (value < 1) { if (value < 1) {
value = 1; value = 1;
} }
if (value > 100000) {
value = 100000;
}
setTopUpCount(value); setTopUpCount(value);
await getAmount(value); await getAmount(value);
}} }}