diff --git a/README.md b/README.md index cf75255..abb2379 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,16 @@ 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 +## 渠道重试 +渠道重试功能已经实现,可以在渠道管理中设置重试次数,需要开启缓存功能,否则只会使用同优先级重试。 +如果开启了缓存功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。 +### 缓存设置方法 +1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 + + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` +2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 + + 例子:`MEMORY_CACHE_ENABLED=true` + + ## 部署 ### 基于 Docker 进行部署 ```shell diff --git a/common/constants.go b/common/constants.go index 0e4192a..98e6abd 100644 --- a/common/constants.go +++ b/common/constants.go @@ -117,7 +117,7 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var RequestInterval = time.Duration(requestInterval) * time.Second -var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second +var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second var BatchUpdateEnabled = false var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) diff --git a/common/gin.go b/common/gin.go index ffa1e21..4a909df 100644 --- a/common/gin.go +++ b/common/gin.go @@ -5,18 +5,37 @@ import ( "encoding/json" "github.com/gin-gonic/gin" "io" + "strings" ) -func UnmarshalBodyReusable(c *gin.Context, v any) error { +const KeyRequestBody = "key_request_body" + +func GetRequestBody(c *gin.Context) ([]byte, error) { + requestBody, _ := c.Get(KeyRequestBody) + if requestBody != nil { + return requestBody.([]byte), nil + } requestBody, err := io.ReadAll(c.Request.Body) if err != nil { - return err + return nil, err } - err = c.Request.Body.Close() + _ = c.Request.Body.Close() + c.Set(KeyRequestBody, requestBody) + return requestBody.([]byte), nil +} + +func UnmarshalBodyReusable(c *gin.Context, v any) error { + requestBody, err := GetRequestBody(c) if err != nil { return err } - err = json.Unmarshal(requestBody, &v) + contentType := c.Request.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "application/json") { + err = json.Unmarshal(requestBody, &v) + } else { + // skip for now + // TODO: someday non json request have variant model, we will need to implementation this + } if err != nil { return err } diff --git a/common/utils.go b/common/utils.go index eb6678a..d540c2e 100644 --- a/common/utils.go +++ b/common/utils.go @@ -236,3 +236,8 @@ func StringToByteSlice(s string) []byte { tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} return *(*[]byte)(unsafe.Pointer(&tmp2)) } + +func RandomSleep() { + // Sleep for 0-3000 ms + time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) +} diff --git a/controller/channel-test.go b/controller/channel-test.go index a4dcfe9..e407193 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -27,7 +27,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr if channel.Type == common.ChannelTypeMidjourney { return errors.New("midjourney channel test is not supported"), nil } - common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = &http.Request{ @@ -60,12 +59,16 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil } if testModel == "" { - testModel = adaptor.GetModelList()[0] - meta.UpstreamModelName = testModel + if channel.TestModel != nil && *channel.TestModel != "" { + testModel = *channel.TestModel + } else { + testModel = adaptor.GetModelList()[0] + } } request := buildTestRequest() request.Model = testModel meta.UpstreamModelName = testModel + common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) adaptor.Init(meta, *request) diff --git a/controller/misc.go b/controller/misc.go index f15fa6a..ecc1f26 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -123,17 +123,28 @@ func SendEmailVerification(c *gin.Context) { return } if common.EmailDomainRestrictionEnabled { + parts := strings.Split(email, "@") + localPart := parts[0] + domainPart := parts[1] + + containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Count(localPart, ".") > 1 allowed := false for _, domain := range common.EmailDomainWhitelist { - if strings.HasSuffix(email, "@"+domain) { + if domainPart == domain { allowed = true break } } - if !allowed { + if allowed && !containsSpecialSymbols { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中", + "message": "Your email address is allowed.", + }) + return + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.", }) return } diff --git a/controller/relay.go b/controller/relay.go index 9f866b8..c6d850d 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -1,21 +1,23 @@ package controller import ( + "bytes" "fmt" "github.com/gin-gonic/gin" + "io" "log" "net/http" "one-api/common" "one-api/dto" + "one-api/middleware" + "one-api/model" "one-api/relay" "one-api/relay/constant" relayconstant "one-api/relay/constant" "one-api/service" - "strconv" ) -func Relay(c *gin.Context) { - relayMode := constant.Path2RelayMode(c.Request.URL.Path) +func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { var err *dto.OpenAIErrorWithStatusCode switch relayMode { case relayconstant.RelayModeImagesGenerations: @@ -29,33 +31,92 @@ func Relay(c *gin.Context) { default: err = relay.TextHelper(c) } - if err != nil { - requestId := c.GetString(common.RequestIdKey) - retryTimesStr := c.Query("retry") - retryTimes, _ := strconv.Atoi(retryTimesStr) - if retryTimesStr == "" { - retryTimes = common.RetryTimes + return err +} + +func Relay(c *gin.Context) { + relayMode := constant.Path2RelayMode(c.Request.URL.Path) + retryTimes := common.RetryTimes + requestId := c.GetString(common.RequestIdKey) + channelId := c.GetInt("channel_id") + group := c.GetString("group") + originalModel := c.GetString("original_model") + openaiErr := relayHandler(c, relayMode) + retryLogStr := fmt.Sprintf("重试:%d", channelId) + if openaiErr != nil { + go processChannelError(c, channelId, openaiErr) + } else { + retryTimes = 0 + } + for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ { + channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) + if err != nil { + common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) + break } - if retryTimes > 0 { - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) - } else { - if err.StatusCode == http.StatusTooManyRequests { - //err.Error.Message = "当前分组上游负载已饱和,请稍后再试" - } - err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId) - c.JSON(err.StatusCode, gin.H{ - "error": err.Error, - }) + channelId = channel.Id + retryLogStr += fmt.Sprintf("->%d", channel.Id) + common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) + middleware.SetupContextForSelectedChannel(c, channel, originalModel) + + requestBody, err := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + openaiErr = relayHandler(c, relayMode) + if openaiErr != nil { + go processChannelError(c, channelId, openaiErr) } - channelId := c.GetInt("channel_id") - autoBan := c.GetBool("auto_ban") - common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message)) - // https://platform.openai.com/docs/guides/error-codes/api-errors - if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan { - channelId := c.GetInt("channel_id") - channelName := c.GetString("channel_name") - service.DisableChannel(channelId, channelName, err.Error.Message) + } + common.LogInfo(c.Request.Context(), retryLogStr) + + if openaiErr != nil { + if openaiErr.StatusCode == http.StatusTooManyRequests { + openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" } + openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId) + c.JSON(openaiErr.StatusCode, gin.H{ + "error": openaiErr.Error, + }) + } +} + +func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool { + if openaiErr == nil { + return false + } + if retryTimes <= 0 { + return false + } + if _, ok := c.Get("specific_channel_id"); ok { + return false + } + if openaiErr.StatusCode == http.StatusTooManyRequests { + return true + } + if openaiErr.StatusCode/100 == 5 { + // 超时不重试 + if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 { + return false + } + return true + } + if openaiErr.StatusCode == http.StatusBadRequest { + return false + } + if openaiErr.LocalError { + return false + } + if openaiErr.StatusCode/100 == 2 { + return false + } + return true +} + +func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) { + autoBan := c.GetBool("auto_ban") + common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message)) + if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan { + channelName := c.GetString("channel_name") + service.DisableChannel(channelId, channelName, err.Error.Message) } } diff --git a/controller/user.go b/controller/user.go index b5a9e48..c305cd4 100644 --- a/controller/user.go +++ b/controller/user.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/model" "strconv" + "sync" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" @@ -739,7 +740,7 @@ func ManageUser(c *gin.Context) { user.Role = common.RoleCommonUser } - if err := user.Update(false); err != nil { + if err := user.UpdateAll(false); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -804,7 +805,11 @@ type topUpRequest struct { Key string `json:"key"` } +var lock = sync.Mutex{} + func TopUp(c *gin.Context) { + lock.Lock() + defer lock.Unlock() req := topUpRequest{} err := c.ShouldBindJSON(&req) if err != nil { diff --git a/dto/error.go b/dto/error.go index e82e051..b347f6a 100644 --- a/dto/error.go +++ b/dto/error.go @@ -10,6 +10,7 @@ type OpenAIError struct { type OpenAIErrorWithStatusCode struct { Error OpenAIError `json:"error"` StatusCode int `json:"status_code"` + LocalError bool } type GeneralErrorResponse struct { diff --git a/middleware/auth.go b/middleware/auth.go index fc6098d..67ac701 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -146,7 +146,7 @@ func TokenAuth() func(c *gin.Context) { } if len(parts) > 1 { if model.IsAdmin(token.UserId) { - c.Set("channelId", parts[1]) + c.Set("specific_channel_id", parts[1]) } else { abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return diff --git a/middleware/distributor.go b/middleware/distributor.go index 10696a9..35cb6df 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -23,7 +23,7 @@ func Distribute() func(c *gin.Context) { return func(c *gin.Context) { userId := c.GetInt("id") var channel *model.Channel - channelId, ok := c.Get("channelId") + channelId, ok := c.Get("specific_channel_id") if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { @@ -131,7 +131,7 @@ func Distribute() func(c *gin.Context) { userGroup, _ := model.CacheGetUserGroup(userId) c.Set("group", userGroup) if shouldSelectChannel { - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) // 如果错误,但是渠道不为空,说明是数据库一致性问题 @@ -147,36 +147,41 @@ func Distribute() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) return } - c.Set("channel", channel.Type) - c.Set("channel_id", channel.Id) - c.Set("channel_name", channel.Name) - ban := true - // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { - ban = false - } - if nil != channel.OpenAIOrganization { - c.Set("channel_organization", *channel.OpenAIOrganization) - } - c.Set("auto_ban", ban) - c.Set("model_mapping", channel.GetModelMapping()) - c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.GetBaseURL()) - // TODO: api_version统一 - switch channel.Type { - case common.ChannelTypeAzure: - c.Set("api_version", channel.Other) - case common.ChannelTypeXunfei: - c.Set("api_version", channel.Other) - //case common.ChannelTypeAIProxyLibrary: - // c.Set("library_id", channel.Other) - case common.ChannelTypeGemini: - c.Set("api_version", channel.Other) - case common.ChannelTypeAli: - c.Set("plugin", channel.Other) - } + SetupContextForSelectedChannel(c, channel, modelRequest.Model) } } c.Next() } } + +func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { + c.Set("channel", channel.Type) + c.Set("channel_id", channel.Id) + c.Set("channel_name", channel.Name) + ban := true + // parse *int to bool + if channel.AutoBan != nil && *channel.AutoBan == 0 { + ban = false + } + if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { + c.Set("channel_organization", *channel.OpenAIOrganization) + } + c.Set("auto_ban", ban) + c.Set("model_mapping", channel.GetModelMapping()) + c.Set("original_model", modelName) // for retry + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + c.Set("base_url", channel.GetBaseURL()) + // TODO: api_version统一 + switch channel.Type { + case common.ChannelTypeAzure: + c.Set("api_version", channel.Other) + case common.ChannelTypeXunfei: + c.Set("api_version", channel.Other) + //case common.ChannelTypeAIProxyLibrary: + // c.Set("library_id", channel.Other) + case common.ChannelTypeGemini: + c.Set("api_version", channel.Other) + case common.ChannelTypeAli: + c.Set("plugin", channel.Other) + } +} diff --git a/model/ability.go b/model/ability.go index b79978d..f522967 100644 --- a/model/ability.go +++ b/model/ability.go @@ -52,21 +52,16 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { // Randomly choose one weightSum := uint(0) for _, ability_ := range abilities { - weightSum += ability_.Weight + weightSum += ability_.Weight + 10 } - if weightSum == 0 { - // All weight is 0, randomly choose one - channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId - } else { - // Randomly choose one - weight := common.GetRandomInt(int(weightSum)) - for _, ability_ := range abilities { - weight -= int(ability_.Weight) - //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight) - if weight <= 0 { - channel.Id = ability_.ChannelId - break - } + // Randomly choose one + weight := common.GetRandomInt(int(weightSum)) + for _, ability_ := range abilities { + weight -= int(ability_.Weight) + 10 + //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight) + if weight <= 0 { + channel.Id = ability_.ChannelId + break } } } else { diff --git a/model/cache.go b/model/cache.go index a0449bc..01245c9 100644 --- a/model/cache.go +++ b/model/cache.go @@ -289,7 +289,7 @@ func SyncChannelCache(frequency int) { } } -func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { +func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { if strings.HasPrefix(model, "gpt-4-gizmo") { model = "gpt-4-gizmo-*" } @@ -304,15 +304,27 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error if len(channels) == 0 { return nil, errors.New("channel not found") } - endIdx := len(channels) - // choose by priority - firstChannel := channels[0] - if firstChannel.GetPriority() > 0 { - for i := range channels { - if channels[i].GetPriority() != firstChannel.GetPriority() { - endIdx = i - break - } + + uniquePriorities := make(map[int]bool) + for _, channel := range channels { + uniquePriorities[int(channel.GetPriority())] = true + } + var sortedUniquePriorities []int + for priority := range uniquePriorities { + sortedUniquePriorities = append(sortedUniquePriorities, priority) + } + sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities))) + + if retry >= len(uniquePriorities) { + retry = len(uniquePriorities) - 1 + } + targetPriority := int64(sortedUniquePriorities[retry]) + + // get the priority for the given retry number + var targetChannels []*Channel + for _, channel := range channels { + if channel.GetPriority() == targetPriority { + targetChannels = append(targetChannels, channel) } } @@ -320,20 +332,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error smoothingFactor := 10 // Calculate the total weight of all channels up to endIdx totalWeight := 0 - for _, channel := range channels[:endIdx] { + for _, channel := range targetChannels { totalWeight += channel.GetWeight() + smoothingFactor } - - //if totalWeight == 0 { - // // If all weights are 0, select a channel randomly - // return channels[rand.Intn(endIdx)], nil - //} - // Generate a random value in the range [0, totalWeight) randomWeight := rand.Intn(totalWeight) // Find a channel based on its weight - for _, channel := range channels[:endIdx] { + for _, channel := range targetChannels { randomWeight -= channel.GetWeight() + smoothingFactor if randomWeight < 0 { return channel, nil diff --git a/model/channel.go b/model/channel.go index b06f578..3e30ad4 100644 --- a/model/channel.go +++ b/model/channel.go @@ -10,6 +10,7 @@ type Channel struct { Type int `json:"type" gorm:"default:0"` Key string `json:"key" gorm:"not null"` OpenAIOrganization *string `json:"openai_organization"` + TestModel *string `json:"test_model"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` Weight *uint `json:"weight" gorm:"default:0"` diff --git a/model/redemption.go b/model/redemption.go index 122661f..00ec76b 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -56,7 +56,7 @@ func Redeem(key string, userId int) (quota int, err error) { if common.UsingPostgreSQL { keyCol = `"key"` } - + common.RandomSleep() err = DB.Transaction(func(tx *gorm.DB) error { err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error if err != nil { diff --git a/model/user.go b/model/user.go index 22258d9..aa9060d 100644 --- a/model/user.go +++ b/model/user.go @@ -246,6 +246,27 @@ func (user *User) Update(updatePassword bool) error { if err == nil { if common.RedisEnabled { _ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second) + _ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + } + } + return err +} + +func (user *User) UpdateAll(updatePassword bool) error { + var err error + if updatePassword { + user.Password, err = common.Password2Hash(user.Password) + if err != nil { + return err + } + } + newUser := *user + DB.First(&user, user.Id) + err = DB.Model(user).Select("*").Updates(newUser).Error + if err == nil { + if common.RedisEnabled { + _ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second) + _ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) } } return err diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 4de8dc0..2b5d3d2 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -34,6 +34,7 @@ func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR StopSequences: nil, Temperature: textRequest.Temperature, TopP: textRequest.TopP, + TopK: textRequest.TopK, Stream: textRequest.Stream, } if claudeRequest.MaxTokensToSample == 0 { @@ -63,6 +64,7 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR StopSequences: nil, Temperature: textRequest.Temperature, TopP: textRequest.TopP, + TopK: textRequest.TopK, Stream: textRequest.Stream, } if claudeRequest.MaxTokens == 0 { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 27ed9a9..7ae9dd4 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -31,6 +31,7 @@ type RelayInfo struct { func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") + tokenId := c.GetInt("token_id") userId := c.GetInt("id") group := c.GetString("group") diff --git a/relay/relay-text.go b/relay/relay-text.go index ff653ff..71a47c2 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -72,7 +72,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { textRequest, err := getAndValidateTextRequest(c, relayInfo) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) + return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) } // map model name @@ -82,7 +82,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[textRequest.Model] != "" { textRequest.Model = modelMap[textRequest.Model] @@ -103,7 +103,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { // count messages token error 计算promptTokens错误 if err != nil { if sensitiveTrigger { - return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) + return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest) } return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) } @@ -162,7 +162,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if resp.StatusCode != http.StatusOK { returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) - return service.OpenAIErrorWrapper(fmt.Errorf("bad response status code: %d", resp.StatusCode), "bad_response_status_code", resp.StatusCode) + return service.RelayErrorHandler(resp) } usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) @@ -200,14 +200,14 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) { userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) if err != nil { - return 0, 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) } if userQuota <= 0 || userQuota-preConsumedQuota < 0 { - return 0, 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { - return 0, 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) } if userQuota > 100*preConsumedQuota { // 用户额度充足,判断令牌额度是否充足 @@ -229,7 +229,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if preConsumedQuota > 0 { userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota) if err != nil { - return 0, 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } return preConsumedQuota, userQuota, nil @@ -288,11 +288,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) //} quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + if quotaDelta != 0 { + err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } } - err = model.CacheUpdateUserQuota(relayInfo.UserId) + err := model.CacheUpdateUserQuota(relayInfo.UserId) if err != nil { common.LogError(ctx, "error update user quota cache: "+err.Error()) } diff --git a/service/channel.go b/service/channel.go index b9a7627..6ce444d 100644 --- a/service/channel.go +++ b/service/channel.go @@ -6,6 +6,7 @@ import ( "one-api/common" relaymodel "one-api/dto" "one-api/model" + "strings" ) // disable & notify @@ -33,7 +34,28 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool { if statusCode == http.StatusUnauthorized { return true } - if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" { + switch err.Code { + case "invalid_api_key": + return true + case "account_deactivated": + return true + case "billing_not_active": + return true + } + switch err.Type { + case "insufficient_quota": + return true + // https://docs.anthropic.com/claude/reference/errors + case "authentication_error": + return true + case "permission_error": + return true + case "forbidden": + return true + } + if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic + return true + } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { return true } return false diff --git a/service/error.go b/service/error.go index cda26b3..39eb0f9 100644 --- a/service/error.go +++ b/service/error.go @@ -46,6 +46,12 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError } } +func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { + openaiErr := OpenAIErrorWrapper(err, code, statusCode) + openaiErr.LocalError = true + return openaiErr +} + func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) { errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, diff --git a/web/.prettierrc.mjs b/web/.prettierrc.mjs new file mode 100644 index 0000000..ecae84d --- /dev/null +++ b/web/.prettierrc.mjs @@ -0,0 +1 @@ +module.exports = require("@so1ve/prettier-config"); diff --git a/web/.prettierrc.mjs b/web/.prettierrc.mjs deleted file mode 100644 index 7890fda..0000000 --- a/web/.prettierrc.mjs +++ /dev/null @@ -1 +0,0 @@ -module.exports = require("@so1ve/prettier-config"); \ No newline at end of file diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index 804c7f5..5ac6a6c 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -471,10 +471,10 @@ const LogsTable = () => { }); }; - const refresh = async (localLogType) => { + const refresh = async () => { // setLoading(true); setActivePage(1); - await loadLogs(0, pageSize, localLogType); + await loadLogs(0, pageSize, logType); }; const copyText = async (text) => { @@ -637,7 +637,7 @@ const LogsTable = () => { style={{ width: 120 }} onChange={(value) => { setLogType(parseInt(value)); - refresh(parseInt(value)).then(); + loadLogs(0, pageSize, parseInt(value)); }} > 全部 diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index a71215e..b76b6c8 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -8,39 +8,41 @@ export function renderText(text, limit) { return text; } +/** + * Render group tags based on the input group string + * @param {string} group - The input group string + * @returns {JSX.Element} - The rendered group tags + */ export function renderGroup(group) { if (group === '') { - return default; + return ( + + unknown + + ); } - let groups = group.split(','); - groups.sort(); + + const tagColors = { + vip: 'yellow', + pro: 'yellow', + svip: 'red', + premium: 'red', + }; + + const groups = group.split(',').sort(); + return ( - <> - {groups.map((group) => { - if (group === 'vip' || group === 'pro') { - return ( - - {group} - - ); - } else if (group === 'svip' || group === 'premium') { - return ( - - {group} - - ); - } - if (group === 'default') { - return {group}; - } else { - return ( - - {group} - - ); - } - })} - + + {groups.map((group) => ( + + {group} + + ))} + ); } @@ -99,12 +101,29 @@ export function getQuotaPerUnit() { return quotaPerUnit; } +export function renderUnitWithQuota(quota) { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + quotaPerUnit = parseFloat(quotaPerUnit); + quota = parseFloat(quota); + return quotaPerUnit * quota; +} + export function getQuotaWithUnit(quota, digits = 6) { let quotaPerUnit = localStorage.getItem('quota_per_unit'); quotaPerUnit = parseFloat(quotaPerUnit); return (quota / quotaPerUnit).toFixed(digits); } +export function renderQuotaWithAmount(amount) { + let displayInCurrency = localStorage.getItem('display_in_currency'); + displayInCurrency = displayInCurrency === 'true'; + if (displayInCurrency) { + return '$' + amount; + } else { + return renderUnitWithQuota(amount); + } +} + export function renderQuota(quota, digits = 2) { let quotaPerUnit = localStorage.getItem('quota_per_unit'); let displayInCurrency = localStorage.getItem('display_in_currency'); diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 9b98de2..0fe6e2b 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -63,6 +63,7 @@ const EditChannel = (props) => { model_mapping: '', models: [], auto_ban: 1, + test_model: '', groups: ['default'], }; const [batch, setBatch] = useState(false); @@ -669,6 +670,17 @@ const EditChannel = (props) => { }} value={inputs.openai_organization} /> +
+ 默认测试模型: +
+ { + handleInputChange('test_model', value); + }} + value={inputs.test_model} + />
{ + const navigate = useNavigate(); + const location = useLocation(); + const [tabActiveKey, setTabActiveKey] = useState('1'); let panes = [ { tab: '个人设置', content: , - itemKey: '1', + itemKey: 'personal', }, ]; @@ -19,28 +23,44 @@ const Setting = () => { panes.push({ tab: '运营设置', content: , - itemKey: '2', + itemKey: 'operation', }); panes.push({ tab: '系统设置', content: , - itemKey: '3', + itemKey: 'system', }); panes.push({ tab: '其他设置', content: , - itemKey: '4', + itemKey: 'other', }); } - + const onChangeTab = (key) => { + setTabActiveKey(key); + navigate(`?tab=${key}`); + }; + useEffect(() => { + const searchParams = new URLSearchParams(window.location.search); + const tab = searchParams.get('tab'); + if (tab) { + setTabActiveKey(tab); + } else { + onChangeTab('personal'); + } + }, [location.search]); return (
- + onChangeTab(key)} + > {panes.map((pane) => ( - {pane.content} + {tabActiveKey === pane.itemKey && pane.content} ))}