diff --git a/common/database.go b/common/database.go index c7e9fd5..c8f59a8 100644 --- a/common/database.go +++ b/common/database.go @@ -3,4 +3,4 @@ package common var UsingSQLite = false var UsingPostgreSQL = false -var SQLitePath = "one-api.db" +var SQLitePath = "one-api.db?_busy_timeout=5000" diff --git a/controller/relay-openai.go b/controller/relay-openai.go index 9b08f85..f06d8b7 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -88,30 +88,28 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O return nil, responseText } -func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { +func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { var textResponse TextResponse - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if textResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: textResponse.Error, - StatusCode: resp.StatusCode, - }, nil - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if textResponse.Error.Type != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: textResponse.Error, + StatusCode: resp.StatusCode, + }, nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) // We shouldn't set the header before we parse the response body, because the parse part may fail. // And then we will have to send an error response, but in this case, the header has already been set. // So the httpClient will be confused by the response. @@ -120,7 +118,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp c.Writer.Header().Set(k, v[0]) } c.Writer.WriteHeader(resp.StatusCode) - _, err := io.Copy(c.Writer, resp.Body) + _, err = io.Copy(c.Writer, resp.Body) if err != nil { return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } diff --git a/controller/relay-text.go b/controller/relay-text.go index a009267..9b68504 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -50,14 +50,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { channelId := c.GetInt("channel_id") tokenId := c.GetInt("token_id") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") + startTime := time.Now() var textRequest GeneralOpenAIRequest - if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { - err := common.UnmarshalBodyReusable(c, &textRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } + + err := common.UnmarshalBodyReusable(c, &textRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } if relayMode == RelayModeModerations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" @@ -199,7 +198,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } var promptTokens int var completionTokens int - var err error switch relayMode { case RelayModeChatCompletions: promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model) @@ -236,7 +234,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { preConsumedQuota = 0 //common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) } - if consumeQuota && preConsumedQuota > 0 { + if preConsumedQuota > 0 { userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) @@ -420,39 +418,40 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { defer func(ctx context.Context) { // c.Writer.Flush() go func() { - if consumeQuota { - quota := 0 - completionRatio := common.GetCompletionRatio(textRequest.Model) - promptTokens = textResponse.Usage.PromptTokens - completionTokens = textResponse.Usage.CompletionTokens + quota := 0 + completionRatio := common.GetCompletionRatio(textRequest.Model) + promptTokens = textResponse.Usage.PromptTokens + completionTokens = textResponse.Usage.CompletionTokens - quota = promptTokens + int(float64(completionTokens)*completionRatio) - quota = int(float64(quota) * ratio) - if ratio != 0 && quota <= 0 { - quota = 1 - } - totalTokens := promptTokens + completionTokens - if totalTokens == 0 { - // in this case, must be some error happened - // we cannot just return, because we may have to return the pre-consumed quota - quota = 0 - } - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } - if quota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) - } + quota = promptTokens + int(float64(completionTokens)*completionRatio) + quota = int(float64(quota) * ratio) + if ratio != 0 && quota <= 0 { + quota = 1 } + totalTokens := promptTokens + completionTokens + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + } + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.LogError(ctx, "error update user quota cache: "+err.Error()) + } + // record all the consume log even if quota is 0 + useTimeSeconds := time.Now().Unix() - startTime.Unix() + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds) + model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + model.UpdateChannelUsedQuota(channelId, quota) + //if quota != 0 { + // + //} }() }(c.Request.Context()) switch apiType { @@ -466,7 +465,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) return nil } else { - err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) + err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) if err != nil { return err } diff --git a/model/ability.go b/model/ability.go index 2e8766a..d115724 100644 --- a/model/ability.go +++ b/model/ability.go @@ -15,8 +15,8 @@ type Ability struct { func GetGroupModels(group string) []string { var abilities []Ability - //去重 - DB.Where("`group` = ?", group).Distinct("model").Find(&abilities) + //去重 enabled = true + DB.Where("`group` = ? and enabled = ?", group, true).Find(&abilities) models := make([]string, 0, len(abilities)) for _, ability := range abilities { models = append(models, ability.Model)