mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-17 19:13:42 +08:00
Compare commits
8 Commits
v0.2.1.0-a
...
v0.2.1.0-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aaf3a1f07b | ||
|
|
c040fa229d | ||
|
|
1cd1e54be4 | ||
|
|
3db64afc7f | ||
|
|
bc9cfa5da0 | ||
|
|
660b9b3c99 | ||
|
|
cdf2087952 | ||
|
|
4b60528c5f |
10
README.md
10
README.md
@@ -59,6 +59,16 @@
|
||||
|
||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||
|
||||
## 渠道重试
|
||||
渠道重试功能已经实现,可以在渠道管理中设置重试次数,需要开启缓存功能,否则只会使用同优先级重试。
|
||||
如果开启了缓存功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
|
||||
### 缓存设置方法
|
||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
|
||||
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
||||
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||
+ 例子:`MEMORY_CACHE_ENABLED=true`
|
||||
|
||||
|
||||
## 部署
|
||||
### 基于 Docker 进行部署
|
||||
```shell
|
||||
|
||||
@@ -111,7 +111,7 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||
|
||||
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
|
||||
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
|
||||
|
||||
var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
|
||||
@@ -5,18 +5,37 @@ import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
const KeyRequestBody = "key_request_body"
|
||||
|
||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||
requestBody, _ := c.Get(KeyRequestBody)
|
||||
if requestBody != nil {
|
||||
return requestBody.([]byte), nil
|
||||
}
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
_ = c.Request.Body.Close()
|
||||
c.Set(KeyRequestBody, requestBody)
|
||||
return requestBody.([]byte), nil
|
||||
}
|
||||
|
||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
requestBody, err := GetRequestBody(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -236,3 +236,8 @@ func StringToByteSlice(s string) []byte {
|
||||
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
|
||||
return *(*[]byte)(unsafe.Pointer(&tmp2))
|
||||
}
|
||||
|
||||
func RandomSleep() {
|
||||
// Sleep for 0-3000 ms
|
||||
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
if channel.Type == common.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
}
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = &http.Request{
|
||||
@@ -60,12 +59,16 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
if testModel == "" {
|
||||
testModel = adaptor.GetModelList()[0]
|
||||
meta.UpstreamModelName = testModel
|
||||
if channel.TestModel != nil && *channel.TestModel != "" {
|
||||
testModel = *channel.TestModel
|
||||
} else {
|
||||
testModel = adaptor.GetModelList()[0]
|
||||
}
|
||||
}
|
||||
request := buildTestRequest()
|
||||
request.Model = testModel
|
||||
meta.UpstreamModelName = testModel
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
||||
|
||||
adaptor.Init(meta, *request)
|
||||
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeImagesGenerations:
|
||||
@@ -29,33 +31,88 @@ func Relay(c *gin.Context) {
|
||||
default:
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
if err != nil {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
retryTimesStr := c.Query("retry")
|
||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
||||
if retryTimesStr == "" {
|
||||
retryTimes = common.RetryTimes
|
||||
return err
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
retryTimes := common.RetryTimes
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
channelId := c.GetInt("channel_id")
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
openaiErr := relayHandler(c, relayMode)
|
||||
retryLogStr := fmt.Sprintf("重试:%d", channelId)
|
||||
if openaiErr != nil {
|
||||
go processChannelError(c, channelId, openaiErr)
|
||||
} else {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||
break
|
||||
}
|
||||
if retryTimes > 0 {
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
||||
} else {
|
||||
if err.StatusCode == http.StatusTooManyRequests {
|
||||
//err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
|
||||
c.JSON(err.StatusCode, gin.H{
|
||||
"error": err.Error,
|
||||
})
|
||||
channelId = channel.Id
|
||||
retryLogStr += fmt.Sprintf("->%d", channel.Id)
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
|
||||
requestBody, err := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
openaiErr = relayHandler(c, relayMode)
|
||||
if openaiErr != nil {
|
||||
go processChannelError(c, channelId, openaiErr)
|
||||
}
|
||||
channelId := c.GetInt("channel_id")
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
|
||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
|
||||
channelId := c.GetInt("channel_id")
|
||||
channelName := c.GetString("channel_name")
|
||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||
}
|
||||
common.LogInfo(c.Request.Context(), retryLogStr)
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
||||
if openaiErr == nil {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
return false
|
||||
}
|
||||
if _, ok := c.Get("specific_channel_id"); ok {
|
||||
return false
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode/100 == 5 {
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||
return false
|
||||
}
|
||||
if openaiErr.LocalError {
|
||||
return false
|
||||
}
|
||||
if openaiErr.StatusCode/100 == 2 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
|
||||
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
|
||||
channelName := c.GetString("channel_name")
|
||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,31 +40,46 @@ func GetEpayClient() *epay.Client {
|
||||
return withUrl
|
||||
}
|
||||
|
||||
func GetAmount(count float64, user model.User) float64 {
|
||||
func getPayMoney(amount float64, user model.User) float64 {
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
// 别问为什么用float64,问就是这么点钱没必要
|
||||
topupGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
amount := count * common.Price * topupGroupRatio
|
||||
return amount
|
||||
payMoney := amount * common.Price * topupGroupRatio
|
||||
return payMoney
|
||||
}
|
||||
|
||||
func getMinTopup() int {
|
||||
minTopup := common.MinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||
}
|
||||
return minTopup
|
||||
}
|
||||
|
||||
func RequestEpay(c *gin.Context) {
|
||||
var req EpayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
if req.Amount < common.MinTopUp {
|
||||
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
payMoney := GetAmount(float64(req.Amount), *user)
|
||||
payMoney := getPayMoney(float64(req.Amount), *user)
|
||||
if payMoney < 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
|
||||
var payType epay.PurchaseType
|
||||
if req.PaymentMethod == "zfb" {
|
||||
@@ -96,9 +111,13 @@ func RequestEpay(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
amount := req.Amount
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / int(common.QuotaPerUnit)
|
||||
}
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: "A" + tradeNo,
|
||||
CreateTime: time.Now().Unix(),
|
||||
@@ -186,13 +205,13 @@ func EpayNotify(c *gin.Context) {
|
||||
}
|
||||
//user, _ := model.GetUserById(topUp.UserId, false)
|
||||
//user.Quota += topUp.Amount * 500000
|
||||
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000)
|
||||
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新用户失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(topUp.Amount*500000), topUp.Money))
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
|
||||
}
|
||||
} else {
|
||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||
@@ -206,12 +225,17 @@ func RequestAmount(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
if req.Amount < common.MinTopUp {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
|
||||
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
payMoney := GetAmount(float64(req.Amount), *user)
|
||||
payMoney := getPayMoney(float64(req.Amount), *user)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -789,7 +790,11 @@ type topUpRequest struct {
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
var lock = sync.Mutex{}
|
||||
|
||||
func TopUp(c *gin.Context) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
req := topUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
|
||||
@@ -10,6 +10,7 @@ type OpenAIError struct {
|
||||
type OpenAIErrorWithStatusCode struct {
|
||||
Error OpenAIError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
LocalError bool
|
||||
}
|
||||
|
||||
type GeneralErrorResponse struct {
|
||||
|
||||
@@ -127,7 +127,7 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("channelId", parts[1])
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
} else {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
|
||||
@@ -23,7 +23,7 @@ func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("channelId")
|
||||
channelId, ok := c.Get("specific_channel_id")
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
@@ -131,7 +131,7 @@ func Distribute() func(c *gin.Context) {
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
c.Set("group", userGroup)
|
||||
if shouldSelectChannel {
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
@@ -147,36 +147,41 @@ func Distribute() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||
return
|
||||
}
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
ban := true
|
||||
// parse *int to bool
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||
ban = false
|
||||
}
|
||||
if nil != channel.OpenAIOrganization {
|
||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||
}
|
||||
c.Set("auto_ban", ban)
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
//case common.ChannelTypeAIProxyLibrary:
|
||||
// c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
ban := true
|
||||
// parse *int to bool
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||
ban = false
|
||||
}
|
||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||
}
|
||||
c.Set("auto_ban", ban)
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Set("original_model", modelName) // for retry
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
//case common.ChannelTypeAIProxyLibrary:
|
||||
// c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,21 +52,16 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
// Randomly choose one
|
||||
weightSum := uint(0)
|
||||
for _, ability_ := range abilities {
|
||||
weightSum += ability_.Weight
|
||||
weightSum += ability_.Weight + 10
|
||||
}
|
||||
if weightSum == 0 {
|
||||
// All weight is 0, randomly choose one
|
||||
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
|
||||
} else {
|
||||
// Randomly choose one
|
||||
weight := common.GetRandomInt(int(weightSum))
|
||||
for _, ability_ := range abilities {
|
||||
weight -= int(ability_.Weight)
|
||||
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
|
||||
if weight <= 0 {
|
||||
channel.Id = ability_.ChannelId
|
||||
break
|
||||
}
|
||||
// Randomly choose one
|
||||
weight := common.GetRandomInt(int(weightSum))
|
||||
for _, ability_ := range abilities {
|
||||
weight -= int(ability_.Weight) + 10
|
||||
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
|
||||
if weight <= 0 {
|
||||
channel.Id = ability_.ChannelId
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -265,7 +265,7 @@ func SyncChannelCache(frequency int) {
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
@@ -280,15 +280,27 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
endIdx := len(channels)
|
||||
// choose by priority
|
||||
firstChannel := channels[0]
|
||||
if firstChannel.GetPriority() > 0 {
|
||||
for i := range channels {
|
||||
if channels[i].GetPriority() != firstChannel.GetPriority() {
|
||||
endIdx = i
|
||||
break
|
||||
}
|
||||
|
||||
uniquePriorities := make(map[int]bool)
|
||||
for _, channel := range channels {
|
||||
uniquePriorities[int(channel.GetPriority())] = true
|
||||
}
|
||||
var sortedUniquePriorities []int
|
||||
for priority := range uniquePriorities {
|
||||
sortedUniquePriorities = append(sortedUniquePriorities, priority)
|
||||
}
|
||||
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
|
||||
|
||||
if retry >= len(uniquePriorities) {
|
||||
retry = len(uniquePriorities) - 1
|
||||
}
|
||||
targetPriority := int64(sortedUniquePriorities[retry])
|
||||
|
||||
// get the priority for the given retry number
|
||||
var targetChannels []*Channel
|
||||
for _, channel := range channels {
|
||||
if channel.GetPriority() == targetPriority {
|
||||
targetChannels = append(targetChannels, channel)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,20 +308,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
smoothingFactor := 10
|
||||
// Calculate the total weight of all channels up to endIdx
|
||||
totalWeight := 0
|
||||
for _, channel := range channels[:endIdx] {
|
||||
for _, channel := range targetChannels {
|
||||
totalWeight += channel.GetWeight() + smoothingFactor
|
||||
}
|
||||
|
||||
//if totalWeight == 0 {
|
||||
// // If all weights are 0, select a channel randomly
|
||||
// return channels[rand.Intn(endIdx)], nil
|
||||
//}
|
||||
|
||||
// Generate a random value in the range [0, totalWeight)
|
||||
randomWeight := rand.Intn(totalWeight)
|
||||
|
||||
// Find a channel based on its weight
|
||||
for _, channel := range channels[:endIdx] {
|
||||
for _, channel := range targetChannels {
|
||||
randomWeight -= channel.GetWeight() + smoothingFactor
|
||||
if randomWeight < 0 {
|
||||
return channel, nil
|
||||
|
||||
@@ -10,6 +10,7 @@ type Channel struct {
|
||||
Type int `json:"type" gorm:"default:0"`
|
||||
Key string `json:"key" gorm:"not null"`
|
||||
OpenAIOrganization *string `json:"openai_organization"`
|
||||
TestModel *string `json:"test_model"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index"`
|
||||
Weight *uint `json:"weight" gorm:"default:0"`
|
||||
|
||||
@@ -56,7 +56,7 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
|
||||
common.RandomSleep()
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
|
||||
if err != nil {
|
||||
|
||||
@@ -31,6 +31,7 @@ type RelayInfo struct {
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
|
||||
@@ -72,7 +72,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// map model name
|
||||
@@ -82,7 +82,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[textRequest.Model] != "" {
|
||||
textRequest.Model = modelMap[textRequest.Model]
|
||||
@@ -103,7 +103,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
if sensitiveTrigger {
|
||||
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
}
|
||||
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -162,7 +162,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode)
|
||||
return service.RelayErrorHandler(resp)
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
@@ -200,14 +200,14 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
||||
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota <= 0 || userQuota-preConsumedQuota < 0 {
|
||||
return 0, 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// 用户额度充足,判断令牌额度是否充足
|
||||
@@ -229,7 +229,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
if preConsumedQuota > 0 {
|
||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
return preConsumedQuota, userQuota, nil
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"one-api/common"
|
||||
relaymodel "one-api/dto"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// disable & notify
|
||||
@@ -33,7 +34,28 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
|
||||
if statusCode == http.StatusUnauthorized {
|
||||
return true
|
||||
}
|
||||
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" {
|
||||
switch err.Code {
|
||||
case "invalid_api_key":
|
||||
return true
|
||||
case "account_deactivated":
|
||||
return true
|
||||
case "billing_not_active":
|
||||
return true
|
||||
}
|
||||
switch err.Type {
|
||||
case "insufficient_quota":
|
||||
return true
|
||||
// https://docs.anthropic.com/claude/reference/errors
|
||||
case "authentication_error":
|
||||
return true
|
||||
case "permission_error":
|
||||
return true
|
||||
case "forbidden":
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
|
||||
return true
|
||||
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -46,6 +46,12 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError
|
||||
}
|
||||
}
|
||||
|
||||
func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
|
||||
openaiErr := OpenAIErrorWrapper(err, code, statusCode)
|
||||
openaiErr.LocalError = true
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
|
||||
@@ -362,7 +362,7 @@ const SystemSetting = () => {
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Input
|
||||
label='最低充值数量'
|
||||
label='最低充值美元数量(以美金为单位,如果使用额度请自行换算!)'
|
||||
placeholder='例如:2,就是最低充值2$'
|
||||
value={inputs.MinTopUp}
|
||||
name='MinTopUp'
|
||||
|
||||
@@ -15,14 +15,18 @@ export function renderText(text, limit) {
|
||||
*/
|
||||
export function renderGroup(group) {
|
||||
if (group === '') {
|
||||
return <Tag size='large' key='default'>default</Tag>;
|
||||
return (
|
||||
<Tag size='large' key='default'>
|
||||
default
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
|
||||
const tagColors = {
|
||||
'vip': 'yellow',
|
||||
'pro': 'yellow',
|
||||
'svip': 'red',
|
||||
'premium': 'red'
|
||||
vip: 'yellow',
|
||||
pro: 'yellow',
|
||||
svip: 'red',
|
||||
premium: 'red',
|
||||
};
|
||||
|
||||
const groups = group.split(',').sort();
|
||||
@@ -97,12 +101,29 @@ export function getQuotaPerUnit() {
|
||||
return quotaPerUnit;
|
||||
}
|
||||
|
||||
export function renderUnitWithQuota(quota) {
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
quotaPerUnit = parseFloat(quotaPerUnit);
|
||||
quota = parseFloat(quota);
|
||||
return quotaPerUnit * quota;
|
||||
}
|
||||
|
||||
export function getQuotaWithUnit(quota, digits = 6) {
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
quotaPerUnit = parseFloat(quotaPerUnit);
|
||||
return (quota / quotaPerUnit).toFixed(digits);
|
||||
}
|
||||
|
||||
export function renderQuotaWithAmount(amount) {
|
||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||
displayInCurrency = displayInCurrency === 'true';
|
||||
if (displayInCurrency) {
|
||||
return '$' + amount;
|
||||
} else {
|
||||
return renderUnitWithQuota(amount);
|
||||
}
|
||||
}
|
||||
|
||||
export function renderQuota(quota, digits = 2) {
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||
|
||||
@@ -63,6 +63,7 @@ const EditChannel = (props) => {
|
||||
model_mapping: '',
|
||||
models: [],
|
||||
auto_ban: 1,
|
||||
test_model: '',
|
||||
groups: ['default'],
|
||||
};
|
||||
const [batch, setBatch] = useState(false);
|
||||
@@ -669,6 +670,17 @@ const EditChannel = (props) => {
|
||||
}}
|
||||
value={inputs.openai_organization}
|
||||
/>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>默认测试模型:</Typography.Text>
|
||||
</div>
|
||||
<Input
|
||||
name='test_model'
|
||||
placeholder='不填则为模型列表第一个'
|
||||
onChange={(value) => {
|
||||
handleInputChange('test_model', value);
|
||||
}}
|
||||
value={inputs.test_model}
|
||||
/>
|
||||
<div style={{ marginTop: 10, display: 'flex' }}>
|
||||
<Space>
|
||||
<Checkbox
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { API, isMobile, showError, showInfo, showSuccess } from '../../helpers';
|
||||
import { renderNumber, renderQuota } from '../../helpers/render';
|
||||
import {
|
||||
renderNumber,
|
||||
renderQuota,
|
||||
renderQuotaWithAmount,
|
||||
} from '../../helpers/render';
|
||||
import {
|
||||
Col,
|
||||
Layout,
|
||||
@@ -12,6 +16,7 @@ import {
|
||||
Divider,
|
||||
Space,
|
||||
Modal,
|
||||
Toast,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import Title from '@douyinfe/semi-ui/lib/es/typography/title';
|
||||
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
|
||||
@@ -20,7 +25,7 @@ import { Link } from 'react-router-dom';
|
||||
const TopUp = () => {
|
||||
const [redemptionCode, setRedemptionCode] = useState('');
|
||||
const [topUpCode, setTopUpCode] = useState('');
|
||||
const [topUpCount, setTopUpCount] = useState(10);
|
||||
const [topUpCount, setTopUpCount] = useState(0);
|
||||
const [minTopupCount, setMinTopUpCount] = useState(1);
|
||||
const [amount, setAmount] = useState(0.0);
|
||||
const [minTopUp, setMinTopUp] = useState(1);
|
||||
@@ -76,11 +81,9 @@ const TopUp = () => {
|
||||
showError('管理员未开启在线充值!');
|
||||
return;
|
||||
}
|
||||
if (amount === 0) {
|
||||
await getAmount();
|
||||
}
|
||||
await getAmount();
|
||||
if (topUpCount < minTopUp) {
|
||||
showInfo('充值数量不能小于' + minTopUp);
|
||||
showError('充值数量不能小于' + minTopUp);
|
||||
return;
|
||||
}
|
||||
setPayWay(payment);
|
||||
@@ -92,7 +95,7 @@ const TopUp = () => {
|
||||
await getAmount();
|
||||
}
|
||||
if (topUpCount < minTopUp) {
|
||||
showInfo('充值数量不能小于' + minTopUp);
|
||||
showError('充值数量不能小于' + minTopUp);
|
||||
return;
|
||||
}
|
||||
setOpen(false);
|
||||
@@ -189,7 +192,8 @@ const TopUp = () => {
|
||||
if (message === 'success') {
|
||||
setAmount(parseFloat(data));
|
||||
} else {
|
||||
showError(data);
|
||||
setAmount(0);
|
||||
Toast.error({ content: '错误:' + data, id: 'getAmount' });
|
||||
// setTopUpCount(parseInt(res.data.count));
|
||||
// setAmount(parseInt(data));
|
||||
}
|
||||
@@ -222,7 +226,7 @@ const TopUp = () => {
|
||||
size={'small'}
|
||||
centered={true}
|
||||
>
|
||||
<p>充值数量:{topUpCount}$</p>
|
||||
<p>充值数量:{topUpCount}</p>
|
||||
<p>实付金额:{renderAmount()}</p>
|
||||
<p>是否确认充值?</p>
|
||||
</Modal>
|
||||
@@ -274,21 +278,16 @@ const TopUp = () => {
|
||||
disabled={!enableOnlineTopUp}
|
||||
field={'redemptionCount'}
|
||||
label={'实付金额:' + renderAmount()}
|
||||
placeholder={'充值数量,最低' + minTopUp + '$'}
|
||||
placeholder={
|
||||
'充值数量,最低 ' + renderQuotaWithAmount(minTopUp)
|
||||
}
|
||||
name='redemptionCount'
|
||||
type={'number'}
|
||||
value={topUpCount}
|
||||
suffix={'$'}
|
||||
min={minTopUp}
|
||||
defaultValue={minTopUp}
|
||||
max={100000}
|
||||
onChange={async (value) => {
|
||||
if (value < 1) {
|
||||
value = 1;
|
||||
}
|
||||
if (value > 100000) {
|
||||
value = 100000;
|
||||
}
|
||||
setTopUpCount(value);
|
||||
await getAmount(value);
|
||||
}}
|
||||
|
||||
Reference in New Issue
Block a user