diff --git a/README.md b/README.md
index 971871a..b561cda 100644
--- a/README.md
+++ b/README.md
@@ -64,7 +64,7 @@
 - `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
 - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。
 - `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
-- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用模型:版本指定,","分隔,例如:-e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
+- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
 
 ## 部署
 
 ### 部署要求
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 68b2d65..3bdd5f7 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -32,13 +32,14 @@ var defaultModelRatio = map[string]float64{
 	"gpt-4-32k": 30,
 	//"gpt-4-32k-0314": 30, //deprecated
 	"gpt-4-32k-0613": 30,
-	"gpt-4-1106-preview":        5,   // $0.01 / 1K tokens
-	"gpt-4-0125-preview":        5,   // $0.01 / 1K tokens
-	"gpt-4-turbo-preview":       5,   // $0.01 / 1K tokens
-	"gpt-4-vision-preview":      5,   // $0.01 / 1K tokens
-	"gpt-4-1106-vision-preview": 5,   // $0.01 / 1K tokens
-	"gpt-4o":                    2.5, // $0.01 / 1K tokens
-	"gpt-4o-2024-05-13":         2.5, // $0.01 / 1K tokens
+	"gpt-4-1106-preview":        5,    // $0.01 / 1K tokens
+	"gpt-4-0125-preview":        5,    // $0.01 / 1K tokens
+	"gpt-4-turbo-preview":       5,    // $0.01 / 1K tokens
+	"gpt-4-vision-preview":      5,    // $0.01 / 1K tokens
+	"gpt-4-1106-vision-preview": 5,    // $0.01 / 1K tokens
+	"gpt-4o":                    2.5,  // $0.01 / 1K tokens
+	"gpt-4o-2024-05-13":         2.5,  // $0.01 / 1K tokens
+	"gpt-4o-2024-08-06":         1.25, // $0.01 / 1K tokens
 	"gpt-4o-mini": 0.075,
 	"gpt-4o-mini-2024-07-18": 0.075,
 	"gpt-4-turbo": 5, // $0.01 / 1K tokens
@@ -326,7 +327,7 @@ func GetCompletionRatio(name string) float64 {
 		return 3
 	}
 	if strings.HasPrefix(name, "gpt-4o") {
-		if strings.Contains(name, "mini") {
+		if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" {
 			return 4
 		}
 		return 3
diff --git a/controller/channel-test.go b/controller/channel-test.go
index fe27978..95c4a60 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -240,7 +240,7 @@ func testAllChannels(notify bool) error {
 			}
 			// parse *int to bool
-			if channel.AutoBan != nil && *channel.AutoBan == 0 {
+			if !channel.GetAutoBan() {
 				ban = false
 			}
diff --git a/controller/relay.go b/controller/relay.go
index 0c79015..13fbde0 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -2,6 +2,7 @@ package controller
 
 import (
 	"bytes"
+	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
@@ -39,44 +40,35 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 
 func Relay(c *gin.Context) {
 	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
-	retryTimes := common.RetryTimes
 	requestId := c.GetString(common.RequestIdKey)
-	channelId := c.GetInt("channel_id")
-	channelType := c.GetInt("channel_type")
-	channelName := c.GetString("channel_name")
 	group := c.GetString("group")
 	originalModel := c.GetString("original_model")
-	openaiErr := relayHandler(c, relayMode)
-	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
-	if openaiErr != nil {
-		go processChannelError(c, channelId, channelType, channelName, openaiErr)
-	} else {
-		retryTimes = 0
-	}
-	for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
-		channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+	var openaiErr *dto.OpenAIErrorWithStatusCode
+
+	for i := 0; i <= common.RetryTimes; i++ {
+		channel, err := getChannel(c, group, originalModel, i)
 		if err != nil {
-			common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+			common.LogError(c, err.Error())
+			openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
 			break
 		}
-		channelId = channel.Id
-		useChannel := c.GetStringSlice("use_channel")
-		useChannel = append(useChannel, fmt.Sprintf("%d", channel.Id))
-		c.Set("use_channel", useChannel)
-		common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
-		middleware.SetupContextForSelectedChannel(c, channel, originalModel)
-		requestBody, err := common.GetRequestBody(c)
-		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
-		openaiErr = relayHandler(c, relayMode)
-		if openaiErr != nil {
-			go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr)
+		openaiErr = relayRequest(c, relayMode, channel)
+
+		if openaiErr == nil {
+			return // 成功处理请求,直接返回
+		}
+
+		go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
+
+		if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
+			break
 		}
 	}
 	useChannel := c.GetStringSlice("use_channel")
 	if len(useChannel) > 1 {
 		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
-		common.LogInfo(c.Request.Context(), retryLogStr)
+		common.LogInfo(c, retryLogStr)
 	}
 
 	if openaiErr != nil {
@@ -90,7 +82,42 @@
 	}
 }
 
-func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
+func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
+	addUsedChannel(c, channel.Id)
+	requestBody, _ := common.GetRequestBody(c)
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+	return relayHandler(c, relayMode)
+}
+
+func addUsedChannel(c *gin.Context, channelId int) {
+	useChannel := c.GetStringSlice("use_channel")
+	useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
+	c.Set("use_channel", useChannel)
+}
+
+func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
+	if retryCount == 0 {
+		autoBan := c.GetBool("auto_ban")
+		autoBanInt := 1
+		if !autoBan {
+			autoBanInt = 0
+		}
+		return &model.Channel{
+			Id:      c.GetInt("channel_id"),
+			Type:    c.GetInt("channel_type"),
+			Name:    c.GetString("channel_name"),
+			AutoBan: &autoBanInt,
+		}, nil
+	}
+	channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
+	if err != nil {
+		return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
+	}
+	middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+	return channel, nil
+}
+
+func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
 	if openaiErr == nil {
 		return false
 	}
@@ -114,6 +141,10 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithSt
 		return true
 	}
 	if openaiErr.StatusCode == http.StatusBadRequest {
+		channelType := c.GetInt("channel_type")
+		if channelType == common.ChannelTypeAnthropic {
+			return true
+		}
 		return false
 	}
 	if openaiErr.StatusCode == 408 {
@@ -129,9 +160,10 @@
 	return true
 }
 
-func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, err *dto.OpenAIErrorWithStatusCode) {
-	autoBan := c.GetBool("auto_ban")
-	common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
+func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
+	// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
+	// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
+	common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
 	if service.ShouldDisableChannel(channelType, err) && autoBan {
 		service.DisableChannel(channelId, channelName, err.Error.Message)
 	}
@@ -208,14 +240,14 @@ func RelayTask(c *gin.Context) {
 	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
 		channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
 		if err != nil {
-			common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+			common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
 			break
 		}
 		channelId = channel.Id
 		useChannel := c.GetStringSlice("use_channel")
 		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
 		c.Set("use_channel", useChannel)
-		common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+		common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
 		middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 
 		requestBody, err := common.GetRequestBody(c)
@@ -225,7 +257,7 @@
 	useChannel := c.GetStringSlice("use_channel")
 	if len(useChannel) > 1 {
 		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
-		common.LogInfo(c.Request.Context(), retryLogStr)
+		common.LogInfo(c, retryLogStr)
 	}
 	if taskErr != nil {
 		if taskErr.StatusCode == http.StatusTooManyRequests {
diff --git a/controller/topup.go b/controller/topup.go
index 87c68c3..c4b1aa9 100644
--- a/controller/topup.go
+++ b/controller/topup.go
@@ -41,12 +41,12 @@ func GetEpayClient() *epay.Client {
 	return withUrl
 }
 
-func getPayMoney(amount float64, user model.User) float64 {
+func getPayMoney(amount float64, group string) float64 {
 	if !common.DisplayInCurrencyEnabled {
 		amount = amount / common.QuotaPerUnit
 	}
 	// 别问为什么用float64,问就是这么点钱没必要
-	topupGroupRatio := common.GetTopupGroupRatio(user.Group)
+	topupGroupRatio := common.GetTopupGroupRatio(group)
 	if topupGroupRatio == 0 {
 		topupGroupRatio = 1
 	}
@@ -75,8 +75,12 @@
 	}
 	id := c.GetInt("id")
-	user, _ := model.GetUserById(id, false)
-	payMoney := getPayMoney(float64(req.Amount), *user)
+	group, err := model.CacheGetUserGroup(id)
+	if err != nil {
+		c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
+		return
+	}
+	payMoney := getPayMoney(float64(req.Amount), group)
 	if payMoney < 0.01 {
 		c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
 		return
 	}
@@ -94,6 +98,7 @@
 	returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
 	notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
 	tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
+	tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
 	client := GetEpayClient()
 	if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"}) @@ -101,8 +106,8 @@ func RequestEpay(c *gin.Context) { } uri, params, err := client.Purchase(&epay.PurchaseArgs{ Type: payType, - ServiceTradeNo: "A" + tradeNo, - Name: "B" + tradeNo, + ServiceTradeNo: tradeNo, + Name: fmt.Sprintf("TUC%d", req.Amount), Money: strconv.FormatFloat(payMoney, 'f', 2, 64), Device: epay.PC, NotifyUrl: notifyUrl, @@ -120,7 +125,7 @@ func RequestEpay(c *gin.Context) { UserId: id, Amount: amount, Money: payMoney, - TradeNo: "A" + tradeNo, + TradeNo: tradeNo, CreateTime: time.Now().Unix(), Status: "pending", } @@ -232,8 +237,12 @@ func RequestAmount(c *gin.Context) { return } id := c.GetInt("id") - user, _ := model.GetUserById(id, false) - payMoney := getPayMoney(float64(req.Amount), *user) + group, err := model.CacheGetUserGroup(id) + if err != nil { + c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"}) + return + } + payMoney := getPayMoney(float64(req.Amount), group) if payMoney <= 0.01 { c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) return diff --git a/controller/user.go b/controller/user.go index a6798eb..6faec2b 100644 --- a/controller/user.go +++ b/controller/user.go @@ -791,11 +791,11 @@ type topUpRequest struct { Key string `json:"key"` } -var lock = sync.Mutex{} +var topUpLock = sync.Mutex{} func TopUp(c *gin.Context) { - lock.Lock() - defer lock.Unlock() + topUpLock.Lock() + defer topUpLock.Unlock() req := topUpRequest{} err := c.ShouldBindJSON(&req) if err != nil { diff --git a/dto/text_request.go b/dto/text_request.go index 2170e71..a804e63 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -7,31 +7,31 @@ type ResponseFormat struct { } type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - StreamOptions *StreamOptions `json:"stream_options,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Stop any `json:"stop,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - Tools []ToolCall `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` - LogProbs bool `json:"logprobs,omitempty"` - TopLogProbs int `json:"top_logprobs,omitempty"` - Dimensions int `json:"dimensions,omitempty"` + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt any `json:"prompt,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Stop any `json:"stop,omitempty"` + N int `json:"n,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` + Functions any `json:"functions,omitempty"` + FrequencyPenalty float64 
`json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` + Seed float64 `json:"seed,omitempty"` + Tools []ToolCall `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + User string `json:"user,omitempty"` + LogProbs bool `json:"logprobs,omitempty"` + TopLogProbs int `json:"top_logprobs,omitempty"` + Dimensions int `json:"dimensions,omitempty"` } type OpenAITools struct { diff --git a/middleware/auth.go b/middleware/auth.go index edd15de..f9a5900 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -143,6 +143,12 @@ func TokenAuth() func(c *gin.Context) { key = parts[0] } token, err := model.ValidateUserToken(key) + if token != nil { + id := c.GetInt("id") + if id == 0 { + c.Set("id", token.UserId) + } + } if err != nil { abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return diff --git a/middleware/distributor.go b/middleware/distributor.go index f150b41..1be3b31 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -184,19 +184,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode if channel == nil { return } - c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("channel_type", channel.Type) - ban := true - // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { - ban = false - } if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { c.Set("channel_organization", *channel.OpenAIOrganization) } - c.Set("auto_ban", ban) + c.Set("auto_ban", channel.GetAutoBan()) c.Set("model_mapping", channel.GetModelMapping()) c.Set("status_code_mapping", channel.GetStatusCodeMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) diff --git a/model/channel.go b/model/channel.go index 7db3f07..34aae68 100644 --- a/model/channel.go +++ b/model/channel.go @@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { channel.OtherInfo = string(otherInfoBytes) } +func (channel *Channel) GetAutoBan() bool { + if channel.AutoBan == nil { + return false + } + return *channel.AutoBan == 1 +} + func (channel *Channel) Save() error { return DB.Save(channel).Error } @@ -99,16 +106,23 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err // 构造WHERE子句 var whereClause string var args []interface{} - if group != "" { - whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " = ? AND " + modelsCol + " LIKE ?" - args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, group, "%"+model+"%") + if group != "" && group != "null" { + var groupCondition string + if common.UsingMySQL { + groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?` + } else { + // sqlite, PostgreSQL + groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?` + } + whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition + args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%", "%,"+group+",%") } else { whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?" 
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
 	}
 
 	// 执行查询
-	err := baseQuery.Where(whereClause, args...).Find(&channels).Error
+	err := baseQuery.Where(whereClause, args...).Order("priority desc").Find(&channels).Error
 	if err != nil {
 		return nil, err
 	}
diff --git a/model/log.go b/model/log.go
index f907f43..c373d6d 100644
--- a/model/log.go
+++ b/model/log.go
@@ -201,8 +201,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 		tx = tx.Where("created_at <= ?", endTimestamp)
 	}
 	if modelName != "" {
-		tx = tx.Where("model_name = ?", modelName)
-		rpmTpmQuery = rpmTpmQuery.Where("model_name = ?", modelName)
+		tx = tx.Where("model_name like ?", modelName)
+		rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName)
 	}
 	if channel != 0 {
 		tx = tx.Where("channel_id = ?", channel)
diff --git a/model/token.go b/model/token.go
index 27907af..272c573 100644
--- a/model/token.go
+++ b/model/token.go
@@ -51,12 +51,12 @@ func ValidateUserToken(key string) (token *Token, err error) {
 	if token.Status == common.TokenStatusExhausted {
 		keyPrefix := key[:3]
 		keySuffix := key[len(key)-3:]
-		return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
+		return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
 	} else if token.Status == common.TokenStatusExpired {
-		return nil, errors.New("该令牌已过期")
+		return token, errors.New("该令牌已过期")
 	}
 	if token.Status != common.TokenStatusEnabled {
-		return nil, errors.New("该令牌状态不可用")
+		return token, errors.New("该令牌状态不可用")
 	}
 	if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
 		if !common.RedisEnabled {
@@ -66,7 +66,7 @@
 				common.SysError("failed to update token status" + err.Error())
 			}
 		}
-		return nil, errors.New("该令牌已过期")
+		return token, errors.New("该令牌已过期")
 	}
 	if !token.UnlimitedQuota && token.RemainQuota <= 0 {
 		if !common.RedisEnabled {
@@ -79,7 +79,7 @@
 		}
 		keyPrefix := key[:3]
 		keySuffix := key[len(key)-3:]
-		return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
+		return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
 	}
 	return token, nil
 }
diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go
index 4f99a24..fac6b7f 100644
--- a/relay/channel/ollama/dto.go
+++ b/relay/channel/ollama/dto.go
@@ -3,18 +3,18 @@ package ollama
 import "one-api/dto"
 
 type OllamaRequest struct {
-	Model            string              `json:"model,omitempty"`
-	Messages         []dto.Message       `json:"messages,omitempty"`
-	Stream           bool                `json:"stream,omitempty"`
-	Temperature      float64             `json:"temperature,omitempty"`
-	Seed             float64             `json:"seed,omitempty"`
-	Topp             float64             `json:"top_p,omitempty"`
-	TopK             int                 `json:"top_k,omitempty"`
-	Stop             any                 `json:"stop,omitempty"`
-	Tools            []dto.ToolCall      `json:"tools,omitempty"`
-	ResponseFormat   *dto.ResponseFormat `json:"response_format,omitempty"`
-	FrequencyPenalty float64             `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64             `json:"presence_penalty,omitempty"`
+	Model            string         `json:"model,omitempty"`
+	Messages         []dto.Message  `json:"messages,omitempty"`
+	Stream           bool           `json:"stream,omitempty"`
+	Temperature      float64        `json:"temperature,omitempty"`
+	Seed             float64        `json:"seed,omitempty"`
+	Topp             float64        `json:"top_p,omitempty"`
+	TopK             int            `json:"top_k,omitempty"`
+	Stop             any            `json:"stop,omitempty"`
+	Tools            []dto.ToolCall `json:"tools,omitempty"`
+	ResponseFormat   any            `json:"response_format,omitempty"`
+	FrequencyPenalty float64        `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float64        `json:"presence_penalty,omitempty"`
 }
 
 type OllamaEmbeddingRequest struct {
diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go
index 50abc2e..81eb93c 100644
--- a/relay/channel/openai/constant.go
+++ b/relay/channel/openai/constant.go
@@ -8,7 +8,7 @@ var ModelList = []string{
 	"gpt-4-32k", "gpt-4-32k-0613",
 	"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
 	"gpt-4-vision-preview",
-	"gpt-4o", "gpt-4o-2024-05-13",
+	"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06",
 	"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
 	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
 	"text-curie-001", "text-babbage-001", "text-ada-001",
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index 564a7ad..3ed5ee3 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -33,7 +33,7 @@ type RelayInfo struct {
 }
 
 func GenRelayInfo(c *gin.Context) *RelayInfo {
-	channelType := c.GetInt("channel")
+	channelType := c.GetInt("channel_type")
 	channelId := c.GetInt("channel_id")
 
 	tokenId := c.GetInt("token_id")
@@ -112,7 +112,7 @@ type TaskRelayInfo struct {
 }
 
 func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
-	channelType := c.GetInt("channel")
+	channelType := c.GetInt("channel_type")
 	channelId := c.GetInt("channel_id")
 
 	tokenId := c.GetInt("token_id")
diff --git a/relay/relay-audio.go b/relay/relay-audio.go
index b2fadcc..5dbc323 100644
--- a/relay/relay-audio.go
+++ b/relay/relay-audio.go
@@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
 	if userQuota-preConsumedQuota < 0 {
-		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+		return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
 	}
 	err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
 	if err != nil {
diff --git a/relay/relay-image.go b/relay/relay-image.go
index 83c7538..74d6c30 100644
--- a/relay/relay-image.go
+++ b/relay/relay-image.go
@@ -125,7 +125,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
 
 	if userQuota-quota < 0 {
-		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+		return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("image pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, quota)), "insufficient_user_quota", http.StatusBadRequest)
 	}
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 73ea468..4dd81c5 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -549,7 +549,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 			if err != nil {
 				common.SysError("get_channel_null: " + err.Error())
 			}
-			if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
+			if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
 				model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
 			}
 		}
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 636be56..93d202d 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -238,9 +238,12 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 	if err != nil {
 		return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
-	if userQuota <= 0 || userQuota-preConsumedQuota < 0 {
+	if userQuota <= 0 {
 		return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
+	if userQuota-preConsumedQuota < 0 {
+		return 0, 0, service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
+	}
 	err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
 	if err != nil {
 		return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
@@ -253,13 +256,13 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 			if tokenQuota > 100*preConsumedQuota {
 				// 令牌额度充足,信任令牌
 				preConsumedQuota = 0
-				common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
+				common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
 			}
 		} else {
 			// in this case, we do not pre-consume quota
 			// because the user has enough quota
 			preConsumedQuota = 0
-			common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
+			common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
 		}
 	}
 	if preConsumedQuota > 0 {
diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js
index f2c7518..2942a0b 100644
--- a/web/src/components/ChannelsTable.js
+++ b/web/src/components/ChannelsTable.js
@@ -745,7 +745,8 @@ const ChannelsTable = () => {