diff --git a/README.md b/README.md
index cf75255..abb2379 100644
--- a/README.md
+++ b/README.md
@@ -59,6 +59,16 @@
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
+## 渠道重试
+渠道重试功能已经实现,可以在渠道管理中设置重试次数,需要开启缓存功能,否则只会使用同优先级重试。
+如果开启了缓存功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
+### 缓存设置方法
+1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
+2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ + 例子:`MEMORY_CACHE_ENABLED=true`
+
+
## 部署
### 基于 Docker 进行部署
```shell
diff --git a/common/constants.go b/common/constants.go
index 0e4192a..98e6abd 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -117,7 +117,7 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
-var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
+var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
diff --git a/common/gin.go b/common/gin.go
index ffa1e21..4a909df 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -5,18 +5,37 @@ import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
+ "strings"
)
-func UnmarshalBodyReusable(c *gin.Context, v any) error {
+const KeyRequestBody = "key_request_body"
+
+func GetRequestBody(c *gin.Context) ([]byte, error) {
+ requestBody, _ := c.Get(KeyRequestBody)
+ if requestBody != nil {
+ return requestBody.([]byte), nil
+ }
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
- return err
+ return nil, err
}
- err = c.Request.Body.Close()
+ _ = c.Request.Body.Close()
+ c.Set(KeyRequestBody, requestBody)
+ return requestBody.([]byte), nil
+}
+
+func UnmarshalBodyReusable(c *gin.Context, v any) error {
+ requestBody, err := GetRequestBody(c)
if err != nil {
return err
}
- err = json.Unmarshal(requestBody, &v)
+ contentType := c.Request.Header.Get("Content-Type")
+ if strings.HasPrefix(contentType, "application/json") {
+ err = json.Unmarshal(requestBody, &v)
+ } else {
+ // skip for now
+ // TODO: if non-JSON requests ever carry variant model fields, implement parsing for them here
+ }
if err != nil {
return err
}
diff --git a/common/utils.go b/common/utils.go
index eb6678a..d540c2e 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -236,3 +236,8 @@ func StringToByteSlice(s string) []byte {
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}
+
+func RandomSleep() {
+ // Sleep for 0-3000 ms
+ time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
+}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index a4dcfe9..e407193 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -27,7 +27,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
}
- common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
@@ -60,12 +59,16 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
if testModel == "" {
- testModel = adaptor.GetModelList()[0]
- meta.UpstreamModelName = testModel
+ if channel.TestModel != nil && *channel.TestModel != "" {
+ testModel = *channel.TestModel
+ } else {
+ testModel = adaptor.GetModelList()[0]
+ }
}
request := buildTestRequest()
request.Model = testModel
meta.UpstreamModelName = testModel
+ common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
adaptor.Init(meta, *request)
diff --git a/controller/misc.go b/controller/misc.go
index f15fa6a..ecc1f26 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -123,17 +123,22 @@ func SendEmailVerification(c *gin.Context) {
return
}
if common.EmailDomainRestrictionEnabled {
+ parts := strings.Split(email, "@")
+ localPart := parts[0]
+ domainPart := parts[1]
+
+ containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Count(localPart, ".") > 1
allowed := false
for _, domain := range common.EmailDomainWhitelist {
- if strings.HasSuffix(email, "@"+domain) {
+ if domainPart == domain {
allowed = true
break
}
}
- if !allowed {
+ if !allowed || containsSpecialSymbols {
c.JSON(http.StatusOK, gin.H{
"success": false,
- "message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
+ "message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.",
})
return
}
diff --git a/controller/relay.go b/controller/relay.go
index 9f866b8..c6d850d 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -1,21 +1,23 @@
package controller
import (
+ "bytes"
"fmt"
"github.com/gin-gonic/gin"
+ "io"
"log"
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/middleware"
+ "one-api/model"
"one-api/relay"
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/service"
- "strconv"
)
-func Relay(c *gin.Context) {
- relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
case relayconstant.RelayModeImagesGenerations:
@@ -29,33 +31,92 @@ func Relay(c *gin.Context) {
default:
err = relay.TextHelper(c)
}
- if err != nil {
- requestId := c.GetString(common.RequestIdKey)
- retryTimesStr := c.Query("retry")
- retryTimes, _ := strconv.Atoi(retryTimesStr)
- if retryTimesStr == "" {
- retryTimes = common.RetryTimes
+ return err
+}
+
+func Relay(c *gin.Context) {
+ relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+ retryTimes := common.RetryTimes
+ requestId := c.GetString(common.RequestIdKey)
+ channelId := c.GetInt("channel_id")
+ group := c.GetString("group")
+ originalModel := c.GetString("original_model")
+ openaiErr := relayHandler(c, relayMode)
+ retryLogStr := fmt.Sprintf("重试:%d", channelId)
+ if openaiErr != nil {
+ go processChannelError(c, channelId, openaiErr)
+ } else {
+ retryTimes = 0
+ }
+ for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
+ channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+ if err != nil {
+ common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+ break
}
- if retryTimes > 0 {
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
- } else {
- if err.StatusCode == http.StatusTooManyRequests {
- //err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
- }
- err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
- c.JSON(err.StatusCode, gin.H{
- "error": err.Error,
- })
+ channelId = channel.Id
+ retryLogStr += fmt.Sprintf("->%d", channel.Id)
+ common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryTimes-i))
+ middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+
+ requestBody, err := common.GetRequestBody(c)
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+ openaiErr = relayHandler(c, relayMode)
+ if openaiErr != nil {
+ go processChannelError(c, channelId, openaiErr)
}
- channelId := c.GetInt("channel_id")
- autoBan := c.GetBool("auto_ban")
- common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
- // https://platform.openai.com/docs/guides/error-codes/api-errors
- if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
- channelId := c.GetInt("channel_id")
- channelName := c.GetString("channel_name")
- service.DisableChannel(channelId, channelName, err.Error.Message)
+ }
+ common.LogInfo(c.Request.Context(), retryLogStr)
+
+ if openaiErr != nil {
+ if openaiErr.StatusCode == http.StatusTooManyRequests {
+ openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
+ openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
+ c.JSON(openaiErr.StatusCode, gin.H{
+ "error": openaiErr.Error,
+ })
+ }
+}
+
+func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
+ if openaiErr == nil {
+ return false
+ }
+ if retryTimes <= 0 {
+ return false
+ }
+ if _, ok := c.Get("specific_channel_id"); ok {
+ return false
+ }
+ if openaiErr.StatusCode == http.StatusTooManyRequests {
+ return true
+ }
+ if openaiErr.StatusCode/100 == 5 {
+ // 超时不重试
+ if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
+ return false
+ }
+ return true
+ }
+ if openaiErr.StatusCode == http.StatusBadRequest {
+ return false
+ }
+ if openaiErr.LocalError {
+ return false
+ }
+ if openaiErr.StatusCode/100 == 2 {
+ return false
+ }
+ return true
+}
+
+func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
+ autoBan := c.GetBool("auto_ban")
+ common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
+ if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
+ channelName := c.GetString("channel_name")
+ service.DisableChannel(channelId, channelName, err.Error.Message)
}
}
diff --git a/controller/user.go b/controller/user.go
index b5a9e48..c305cd4 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -7,6 +7,7 @@ import (
"one-api/common"
"one-api/model"
"strconv"
+ "sync"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
@@ -739,7 +740,7 @@ func ManageUser(c *gin.Context) {
user.Role = common.RoleCommonUser
}
- if err := user.Update(false); err != nil {
+ if err := user.UpdateAll(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -804,7 +805,11 @@ type topUpRequest struct {
Key string `json:"key"`
}
+var lock = sync.Mutex{}
+
func TopUp(c *gin.Context) {
+ lock.Lock()
+ defer lock.Unlock()
req := topUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {
diff --git a/dto/error.go b/dto/error.go
index e82e051..b347f6a 100644
--- a/dto/error.go
+++ b/dto/error.go
@@ -10,6 +10,7 @@ type OpenAIError struct {
type OpenAIErrorWithStatusCode struct {
Error OpenAIError `json:"error"`
StatusCode int `json:"status_code"`
+ LocalError bool
}
type GeneralErrorResponse struct {
diff --git a/middleware/auth.go b/middleware/auth.go
index fc6098d..67ac701 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -146,7 +146,7 @@ func TokenAuth() func(c *gin.Context) {
}
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
- c.Set("channelId", parts[1])
+ c.Set("specific_channel_id", parts[1])
} else {
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 10696a9..35cb6df 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -23,7 +23,7 @@ func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
userId := c.GetInt("id")
var channel *model.Channel
- channelId, ok := c.Get("channelId")
+ channelId, ok := c.Get("specific_channel_id")
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -131,7 +131,7 @@ func Distribute() func(c *gin.Context) {
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
if shouldSelectChannel {
- channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+ channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
@@ -147,36 +147,41 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return
}
- c.Set("channel", channel.Type)
- c.Set("channel_id", channel.Id)
- c.Set("channel_name", channel.Name)
- ban := true
- // parse *int to bool
- if channel.AutoBan != nil && *channel.AutoBan == 0 {
- ban = false
- }
- if nil != channel.OpenAIOrganization {
- c.Set("channel_organization", *channel.OpenAIOrganization)
- }
- c.Set("auto_ban", ban)
- c.Set("model_mapping", channel.GetModelMapping())
- c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
- c.Set("base_url", channel.GetBaseURL())
- // TODO: api_version统一
- switch channel.Type {
- case common.ChannelTypeAzure:
- c.Set("api_version", channel.Other)
- case common.ChannelTypeXunfei:
- c.Set("api_version", channel.Other)
- //case common.ChannelTypeAIProxyLibrary:
- // c.Set("library_id", channel.Other)
- case common.ChannelTypeGemini:
- c.Set("api_version", channel.Other)
- case common.ChannelTypeAli:
- c.Set("plugin", channel.Other)
- }
+ SetupContextForSelectedChannel(c, channel, modelRequest.Model)
}
}
c.Next()
}
}
+
+func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
+ c.Set("channel", channel.Type)
+ c.Set("channel_id", channel.Id)
+ c.Set("channel_name", channel.Name)
+ ban := true
+ // parse *int to bool
+ if channel.AutoBan != nil && *channel.AutoBan == 0 {
+ ban = false
+ }
+ if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
+ c.Set("channel_organization", *channel.OpenAIOrganization)
+ }
+ c.Set("auto_ban", ban)
+ c.Set("model_mapping", channel.GetModelMapping())
+ c.Set("original_model", modelName) // for retry
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+ c.Set("base_url", channel.GetBaseURL())
+ // TODO: api_version统一
+ switch channel.Type {
+ case common.ChannelTypeAzure:
+ c.Set("api_version", channel.Other)
+ case common.ChannelTypeXunfei:
+ c.Set("api_version", channel.Other)
+ //case common.ChannelTypeAIProxyLibrary:
+ // c.Set("library_id", channel.Other)
+ case common.ChannelTypeGemini:
+ c.Set("api_version", channel.Other)
+ case common.ChannelTypeAli:
+ c.Set("plugin", channel.Other)
+ }
+}
diff --git a/model/ability.go b/model/ability.go
index b79978d..f522967 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -52,21 +52,16 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
// Randomly choose one
weightSum := uint(0)
for _, ability_ := range abilities {
- weightSum += ability_.Weight
+ weightSum += ability_.Weight + 10
}
- if weightSum == 0 {
- // All weight is 0, randomly choose one
- channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
- } else {
- // Randomly choose one
- weight := common.GetRandomInt(int(weightSum))
- for _, ability_ := range abilities {
- weight -= int(ability_.Weight)
- //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
- if weight <= 0 {
- channel.Id = ability_.ChannelId
- break
- }
+ // Randomly choose one
+ weight := common.GetRandomInt(int(weightSum))
+ for _, ability_ := range abilities {
+ weight -= int(ability_.Weight) + 10
+ //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
+ if weight <= 0 {
+ channel.Id = ability_.ChannelId
+ break
}
}
} else {
diff --git a/model/cache.go b/model/cache.go
index a0449bc..01245c9 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -289,7 +289,7 @@ func SyncChannelCache(frequency int) {
}
}
-func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}
@@ -304,15 +304,27 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
- endIdx := len(channels)
- // choose by priority
- firstChannel := channels[0]
- if firstChannel.GetPriority() > 0 {
- for i := range channels {
- if channels[i].GetPriority() != firstChannel.GetPriority() {
- endIdx = i
- break
- }
+
+ uniquePriorities := make(map[int]bool)
+ for _, channel := range channels {
+ uniquePriorities[int(channel.GetPriority())] = true
+ }
+ var sortedUniquePriorities []int
+ for priority := range uniquePriorities {
+ sortedUniquePriorities = append(sortedUniquePriorities, priority)
+ }
+ sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
+
+ if retry >= len(uniquePriorities) {
+ retry = len(uniquePriorities) - 1
+ }
+ targetPriority := int64(sortedUniquePriorities[retry])
+
+ // get the priority for the given retry number
+ var targetChannels []*Channel
+ for _, channel := range channels {
+ if channel.GetPriority() == targetPriority {
+ targetChannels = append(targetChannels, channel)
}
}
@@ -320,20 +332,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
smoothingFactor := 10
// Calculate the total weight of all channels up to endIdx
totalWeight := 0
- for _, channel := range channels[:endIdx] {
+ for _, channel := range targetChannels {
totalWeight += channel.GetWeight() + smoothingFactor
}
-
- //if totalWeight == 0 {
- // // If all weights are 0, select a channel randomly
- // return channels[rand.Intn(endIdx)], nil
- //}
-
// Generate a random value in the range [0, totalWeight)
randomWeight := rand.Intn(totalWeight)
// Find a channel based on its weight
- for _, channel := range channels[:endIdx] {
+ for _, channel := range targetChannels {
randomWeight -= channel.GetWeight() + smoothingFactor
if randomWeight < 0 {
return channel, nil
diff --git a/model/channel.go b/model/channel.go
index b06f578..3e30ad4 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -10,6 +10,7 @@ type Channel struct {
Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null"`
OpenAIOrganization *string `json:"openai_organization"`
+ TestModel *string `json:"test_model"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"`
diff --git a/model/redemption.go b/model/redemption.go
index 122661f..00ec76b 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -56,7 +56,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if common.UsingPostgreSQL {
keyCol = `"key"`
}
-
+ common.RandomSleep()
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil {
diff --git a/model/user.go b/model/user.go
index 22258d9..aa9060d 100644
--- a/model/user.go
+++ b/model/user.go
@@ -246,6 +246,27 @@ func (user *User) Update(updatePassword bool) error {
if err == nil {
if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
+ _ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
+ }
+ }
+ return err
+}
+
+func (user *User) UpdateAll(updatePassword bool) error {
+ var err error
+ if updatePassword {
+ user.Password, err = common.Password2Hash(user.Password)
+ if err != nil {
+ return err
+ }
+ }
+ newUser := *user
+ DB.First(&user, user.Id)
+ err = DB.Model(user).Select("*").Updates(newUser).Error
+ if err == nil {
+ if common.RedisEnabled {
+ _ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
+ _ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
}
}
return err
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index 4de8dc0..2b5d3d2 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -34,6 +34,7 @@ func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
+ TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
@@ -63,6 +64,7 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
+ TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokens == 0 {
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index 27ed9a9..7ae9dd4 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -31,6 +31,7 @@ type RelayInfo struct {
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
+
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
diff --git a/relay/relay-text.go b/relay/relay-text.go
index ff653ff..71a47c2 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -72,7 +72,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
textRequest, err := getAndValidateTextRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
- return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
+ return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
// map model name
@@ -82,7 +82,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+ return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
@@ -103,7 +103,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
// count messages token error 计算promptTokens错误
if err != nil {
if sensitiveTrigger {
- return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
+ return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
}
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}
@@ -162,7 +162,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
- return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode)
+ return service.RelayErrorHandler(resp)
}
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
@@ -200,14 +200,14 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
if err != nil {
- return 0, 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+ return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota <= 0 || userQuota-preConsumedQuota < 0 {
- return 0, 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+ return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
- return 0, 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+ return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
@@ -229,7 +229,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
if err != nil {
- return 0, 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+ return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
return preConsumedQuota, userQuota, nil
@@ -288,11 +288,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
//}
quotaDelta := quota - preConsumedQuota
- err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
- if err != nil {
- common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ if quotaDelta != 0 {
+ err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
+ if err != nil {
+ common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ }
}
- err = model.CacheUpdateUserQuota(relayInfo.UserId)
+ err := model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
diff --git a/service/channel.go b/service/channel.go
index b9a7627..6ce444d 100644
--- a/service/channel.go
+++ b/service/channel.go
@@ -6,6 +6,7 @@ import (
"one-api/common"
relaymodel "one-api/dto"
"one-api/model"
+ "strings"
)
// disable & notify
@@ -33,7 +34,28 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
if statusCode == http.StatusUnauthorized {
return true
}
- if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" {
+ switch err.Code {
+ case "invalid_api_key":
+ return true
+ case "account_deactivated":
+ return true
+ case "billing_not_active":
+ return true
+ }
+ switch err.Type {
+ case "insufficient_quota":
+ return true
+ // https://docs.anthropic.com/claude/reference/errors
+ case "authentication_error":
+ return true
+ case "permission_error":
+ return true
+ case "forbidden":
+ return true
+ }
+ if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
+ return true
+ } else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
return false
diff --git a/service/error.go b/service/error.go
index cda26b3..39eb0f9 100644
--- a/service/error.go
+++ b/service/error.go
@@ -46,6 +46,12 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError
}
}
+func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
+ openaiErr := OpenAIErrorWrapper(err, code, statusCode)
+ openaiErr.LocalError = true
+ return openaiErr
+}
+
func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
diff --git a/web/.prettierrc.mjs b/web/.prettierrc.mjs
new file mode 100644
index 0000000..ecae84d
--- /dev/null
+++ b/web/.prettierrc.mjs
@@ -0,0 +1 @@
+export { default } from "@so1ve/prettier-config";
diff --git a/web/.prettierrc.mjs b/web/.prettierrc.mjs
deleted file mode 100644
index 7890fda..0000000
--- a/web/.prettierrc.mjs
+++ /dev/null
@@ -1 +0,0 @@
-module.exports = require("@so1ve/prettier-config");
\ No newline at end of file
diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js
index 804c7f5..5ac6a6c 100644
--- a/web/src/components/LogsTable.js
+++ b/web/src/components/LogsTable.js
@@ -471,10 +471,10 @@ const LogsTable = () => {
});
};
- const refresh = async (localLogType) => {
+ const refresh = async () => {
// setLoading(true);
setActivePage(1);
- await loadLogs(0, pageSize, localLogType);
+ await loadLogs(0, pageSize, logType);
};
const copyText = async (text) => {
@@ -637,7 +637,7 @@ const LogsTable = () => {
style={{ width: 120 }}
onChange={(value) => {
setLogType(parseInt(value));
- refresh(parseInt(value)).then();
+ loadLogs(0, pageSize, parseInt(value));
}}
>