mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 16:06:38 +08:00
feat: 本地重试
This commit is contained in:
parent
9025756b56
commit
4b60528c5f
@ -5,18 +5,37 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
const KeyRequestBody = "key_request_body"
|
||||||
|
|
||||||
|
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||||
|
requestBody, _ := c.Get(KeyRequestBody)
|
||||||
|
if requestBody != nil {
|
||||||
|
return requestBody.([]byte), nil
|
||||||
|
}
|
||||||
requestBody, err := io.ReadAll(c.Request.Body)
|
requestBody, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = c.Request.Body.Close()
|
_ = c.Request.Body.Close()
|
||||||
|
c.Set(KeyRequestBody, requestBody)
|
||||||
|
return requestBody.([]byte), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||||
|
requestBody, err := GetRequestBody(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
err = json.Unmarshal(requestBody, &v)
|
err = json.Unmarshal(requestBody, &v)
|
||||||
|
} else {
|
||||||
|
// skip for now
|
||||||
|
// TODO: someday non json request have variant model, we will need to implementation this
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1,21 +1,23 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/middleware"
|
||||||
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
|
||||||
var err *dto.OpenAIErrorWithStatusCode
|
var err *dto.OpenAIErrorWithStatusCode
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeImagesGenerations:
|
case relayconstant.RelayModeImagesGenerations:
|
||||||
@ -29,34 +31,89 @@ func Relay(c *gin.Context) {
|
|||||||
default:
|
default:
|
||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c)
|
||||||
}
|
}
|
||||||
if err != nil {
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func Relay(c *gin.Context) {
|
||||||
|
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||||
|
retryTimes := common.RetryTimes
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
retryTimesStr := c.Query("retry")
|
channelId := c.GetInt("channel_id")
|
||||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
group := c.GetString("group")
|
||||||
if retryTimesStr == "" {
|
originalModel := c.GetString("original_model")
|
||||||
retryTimes = common.RetryTimes
|
openaiErr := relayHandler(c, relayMode)
|
||||||
}
|
retryLogStr := fmt.Sprintf("重试:%d", channelId)
|
||||||
if retryTimes > 0 {
|
if openaiErr != nil {
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
go processChannelError(c, channelId, openaiErr)
|
||||||
} else {
|
} else {
|
||||||
if err.StatusCode == http.StatusTooManyRequests {
|
retryTimes = 0
|
||||||
//err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
|
||||||
}
|
}
|
||||||
err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
|
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
|
||||||
c.JSON(err.StatusCode, gin.H{
|
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
||||||
"error": err.Error,
|
if err != nil {
|
||||||
|
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
channelId = channel.Id
|
||||||
|
retryLogStr += fmt.Sprintf("->%d", channel.Id)
|
||||||
|
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
|
requestBody, err := common.GetRequestBody(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
openaiErr = relayHandler(c, relayMode)
|
||||||
|
if openaiErr != nil {
|
||||||
|
go processChannelError(c, channelId, openaiErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
common.LogInfo(c.Request.Context(), retryLogStr)
|
||||||
|
|
||||||
|
if openaiErr != nil {
|
||||||
|
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||||
|
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||||
|
}
|
||||||
|
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||||
|
c.JSON(openaiErr.StatusCode, gin.H{
|
||||||
|
"error": openaiErr.Error,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
channelId := c.GetInt("channel_id")
|
}
|
||||||
|
|
||||||
|
func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
||||||
|
if openaiErr == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if retryTimes <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, ok := c.Get("specific_channel_id"); ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if openaiErr.StatusCode/100 == 5 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if openaiErr.LocalError {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if openaiErr.StatusCode/100 == 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
autoBan := c.GetBool("auto_ban")
|
autoBan := c.GetBool("auto_ban")
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
|
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
|
||||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
|
||||||
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
|
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
channelName := c.GetString("channel_name")
|
channelName := c.GetString("channel_name")
|
||||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourney(c *gin.Context) {
|
func RelayMidjourney(c *gin.Context) {
|
||||||
|
@ -10,6 +10,7 @@ type OpenAIError struct {
|
|||||||
type OpenAIErrorWithStatusCode struct {
|
type OpenAIErrorWithStatusCode struct {
|
||||||
Error OpenAIError `json:"error"`
|
Error OpenAIError `json:"error"`
|
||||||
StatusCode int `json:"status_code"`
|
StatusCode int `json:"status_code"`
|
||||||
|
LocalError bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeneralErrorResponse struct {
|
type GeneralErrorResponse struct {
|
||||||
|
@ -127,7 +127,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set("channelId", parts[1])
|
c.Set("specific_channel_id", parts[1])
|
||||||
} else {
|
} else {
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
return
|
return
|
||||||
|
@ -23,7 +23,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
var channel *model.Channel
|
var channel *model.Channel
|
||||||
channelId, ok := c.Get("channelId")
|
channelId, ok := c.Get("specific_channel_id")
|
||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -131,7 +131,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||||
c.Set("group", userGroup)
|
c.Set("group", userGroup)
|
||||||
if shouldSelectChannel {
|
if shouldSelectChannel {
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||||
@ -147,6 +147,14 @@ func Distribute() func(c *gin.Context) {
|
|||||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||||
c.Set("channel", channel.Type)
|
c.Set("channel", channel.Type)
|
||||||
c.Set("channel_id", channel.Id)
|
c.Set("channel_id", channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
c.Set("channel_name", channel.Name)
|
||||||
@ -160,6 +168,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
c.Set("auto_ban", ban)
|
c.Set("auto_ban", ban)
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
|
c.Set("original_model", modelName) // for retry
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
// TODO: api_version统一
|
// TODO: api_version统一
|
||||||
@ -175,8 +184,4 @@ func Distribute() func(c *gin.Context) {
|
|||||||
case common.ChannelTypeAli:
|
case common.ChannelTypeAli:
|
||||||
c.Set("plugin", channel.Other)
|
c.Set("plugin", channel.Other)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -52,12 +52,8 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
|||||||
// Randomly choose one
|
// Randomly choose one
|
||||||
weightSum := uint(0)
|
weightSum := uint(0)
|
||||||
for _, ability_ := range abilities {
|
for _, ability_ := range abilities {
|
||||||
weightSum += ability_.Weight
|
weightSum += ability_.Weight + 10
|
||||||
}
|
}
|
||||||
if weightSum == 0 {
|
|
||||||
// All weight is 0, randomly choose one
|
|
||||||
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
|
|
||||||
} else {
|
|
||||||
// Randomly choose one
|
// Randomly choose one
|
||||||
weight := common.GetRandomInt(int(weightSum))
|
weight := common.GetRandomInt(int(weightSum))
|
||||||
for _, ability_ := range abilities {
|
for _, ability_ := range abilities {
|
||||||
@ -68,7 +64,6 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return nil, errors.New("channel not found")
|
return nil, errors.New("channel not found")
|
||||||
}
|
}
|
||||||
|
@ -265,7 +265,7 @@ func SyncChannelCache(frequency int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||||
model = "gpt-4-gizmo-*"
|
model = "gpt-4-gizmo-*"
|
||||||
}
|
}
|
||||||
@ -280,15 +280,27 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
|||||||
if len(channels) == 0 {
|
if len(channels) == 0 {
|
||||||
return nil, errors.New("channel not found")
|
return nil, errors.New("channel not found")
|
||||||
}
|
}
|
||||||
endIdx := len(channels)
|
|
||||||
// choose by priority
|
uniquePriorities := make(map[int]bool)
|
||||||
firstChannel := channels[0]
|
for _, channel := range channels {
|
||||||
if firstChannel.GetPriority() > 0 {
|
uniquePriorities[int(channel.GetPriority())] = true
|
||||||
for i := range channels {
|
|
||||||
if channels[i].GetPriority() != firstChannel.GetPriority() {
|
|
||||||
endIdx = i
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
var sortedUniquePriorities []int
|
||||||
|
for priority := range uniquePriorities {
|
||||||
|
sortedUniquePriorities = append(sortedUniquePriorities, priority)
|
||||||
|
}
|
||||||
|
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
|
||||||
|
|
||||||
|
if retry >= len(uniquePriorities) {
|
||||||
|
retry = len(uniquePriorities) - 1
|
||||||
|
}
|
||||||
|
targetPriority := int64(sortedUniquePriorities[retry])
|
||||||
|
|
||||||
|
// get the priority for the given retry number
|
||||||
|
var targetChannels []*Channel
|
||||||
|
for _, channel := range channels {
|
||||||
|
if channel.GetPriority() == targetPriority {
|
||||||
|
targetChannels = append(targetChannels, channel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -296,20 +308,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
|||||||
smoothingFactor := 10
|
smoothingFactor := 10
|
||||||
// Calculate the total weight of all channels up to endIdx
|
// Calculate the total weight of all channels up to endIdx
|
||||||
totalWeight := 0
|
totalWeight := 0
|
||||||
for _, channel := range channels[:endIdx] {
|
for _, channel := range targetChannels {
|
||||||
totalWeight += channel.GetWeight() + smoothingFactor
|
totalWeight += channel.GetWeight() + smoothingFactor
|
||||||
}
|
}
|
||||||
|
|
||||||
//if totalWeight == 0 {
|
|
||||||
// // If all weights are 0, select a channel randomly
|
|
||||||
// return channels[rand.Intn(endIdx)], nil
|
|
||||||
//}
|
|
||||||
|
|
||||||
// Generate a random value in the range [0, totalWeight)
|
// Generate a random value in the range [0, totalWeight)
|
||||||
randomWeight := rand.Intn(totalWeight)
|
randomWeight := rand.Intn(totalWeight)
|
||||||
|
|
||||||
// Find a channel based on its weight
|
// Find a channel based on its weight
|
||||||
for _, channel := range channels[:endIdx] {
|
for _, channel := range targetChannels {
|
||||||
randomWeight -= channel.GetWeight() + smoothingFactor
|
randomWeight -= channel.GetWeight() + smoothingFactor
|
||||||
if randomWeight < 0 {
|
if randomWeight < 0 {
|
||||||
return channel, nil
|
return channel, nil
|
||||||
|
@ -31,6 +31,7 @@ type RelayInfo struct {
|
|||||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
|
@ -72,7 +72,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||||
return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// map model name
|
// map model name
|
||||||
@ -82,7 +82,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
modelMap := make(map[string]string)
|
modelMap := make(map[string]string)
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if modelMap[textRequest.Model] != "" {
|
if modelMap[textRequest.Model] != "" {
|
||||||
textRequest.Model = modelMap[textRequest.Model]
|
textRequest.Model = modelMap[textRequest.Model]
|
||||||
@ -103,7 +103,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
// count messages token error 计算promptTokens错误
|
// count messages token error 计算promptTokens错误
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if sensitiveTrigger {
|
if sensitiveTrigger {
|
||||||
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@ -162,7 +162,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||||
return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode)
|
return service.RelayErrorHandler(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||||
@ -200,14 +200,14 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
|||||||
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
|
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
|
||||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if userQuota <= 0 || userQuota-preConsumedQuota < 0 {
|
if userQuota <= 0 || userQuota-preConsumedQuota < 0 {
|
||||||
return 0, 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if userQuota > 100*preConsumedQuota {
|
if userQuota > 100*preConsumedQuota {
|
||||||
// 用户额度充足,判断令牌额度是否充足
|
// 用户额度充足,判断令牌额度是否充足
|
||||||
@ -229,7 +229,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
if preConsumedQuota > 0 {
|
if preConsumedQuota > 0 {
|
||||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
|
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return preConsumedQuota, userQuota, nil
|
return preConsumedQuota, userQuota, nil
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
relaymodel "one-api/dto"
|
relaymodel "one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// disable & notify
|
// disable & notify
|
||||||
@ -33,7 +34,28 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
|
|||||||
if statusCode == http.StatusUnauthorized {
|
if statusCode == http.StatusUnauthorized {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" {
|
switch err.Code {
|
||||||
|
case "invalid_api_key":
|
||||||
|
return true
|
||||||
|
case "account_deactivated":
|
||||||
|
return true
|
||||||
|
case "billing_not_active":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch err.Type {
|
||||||
|
case "insufficient_quota":
|
||||||
|
return true
|
||||||
|
// https://docs.anthropic.com/claude/reference/errors
|
||||||
|
case "authentication_error":
|
||||||
|
return true
|
||||||
|
case "permission_error":
|
||||||
|
return true
|
||||||
|
case "forbidden":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
|
||||||
|
return true
|
||||||
|
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
@ -46,6 +46,12 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
|
||||||
|
openaiErr := OpenAIErrorWrapper(err, code, statusCode)
|
||||||
|
openaiErr.LocalError = true
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
|
||||||
func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||||
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
|
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
|
Loading…
Reference in New Issue
Block a user