diff --git a/controller/relay.go b/controller/relay.go index a04c85a..bc951f7 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -22,13 +22,13 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode var err *dto.OpenAIErrorWithStatusCode switch relayMode { case relayconstant.RelayModeImagesGenerations: - err = relay.RelayImageHelper(c, relayMode) + err = relay.ImageHelper(c, relayMode) case relayconstant.RelayModeAudioSpeech: fallthrough case relayconstant.RelayModeAudioTranslation: fallthrough case relayconstant.RelayModeAudioTranscription: - err = relay.AudioHelper(c, relayMode) + err = relay.AudioHelper(c) case relayconstant.RelayModeRerank: err = relay.RerankHelper(c, relayMode) default: diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 820c2bc..2aa743f 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -122,8 +122,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + return request, nil } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { @@ -142,6 +141,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom fallthrough case constant.RelayModeAudioTranscription: err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) + case constant.RelayModeImagesGenerations: + err, usage = OpenaiTTSHandler(c, resp, info) default: if info.IsStream { err, usage = OpenaiStreamHandler(c, resp, info) diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 2a0278e..b2fadcc 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -46,7 +46,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. return audioRequest, nil } -func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { +func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { relayInfo := relaycommon.GenRelayInfo(c) audioRequest, err := getAndValidAudioRequest(c, relayInfo) @@ -142,7 +142,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { return openaiErr } - postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false) + postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "") return nil } diff --git a/relay/relay-image.go b/relay/relay-image.go index 6d6e4d4..4b1fbd2 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "context" "encoding/json" "errors" "fmt" @@ -14,72 +13,71 @@ import ( "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/service" "strings" - "time" ) -func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - group := c.GetString("group") - startTime := time.Now() - - var imageRequest dto.ImageRequest - err := common.UnmarshalBodyReusable(c, &imageRequest) +func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) { + imageRequest := &dto.ImageRequest{} + err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + return nil, err } - - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-3" + if imageRequest.Prompt == "" { + return nil, errors.New("prompt is required") } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" + if strings.Contains(imageRequest.Size, "×") { + return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") } if imageRequest.N == 0 { imageRequest.N = 1 } - // Prompt validation - if imageRequest.Prompt == "" { - return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" } - - if constant.ShouldCheckPromptSensitive() { - err = service.CheckSensitiveInput(imageRequest.Prompt) - if err != nil { - return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) - } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" } - - if strings.Contains(imageRequest.Size, "×") { - return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest) + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" } // Not "256x256", "512x512", or "1024x1024" if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") } } else if imageRequest.Model == "dall-e-3" { if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { - return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") } - if imageRequest.N != 1 { - return service.OpenAIErrorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest) + //if imageRequest.N != 1 { + // return nil, errors.New("n must be 1") + //} + } + // N should between 1 and 10 + //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { + // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) + //} + if constant.ShouldCheckPromptSensitive() { + err := service.CheckSensitiveInput(imageRequest.Prompt) + if err != nil { + return nil, err } } + return imageRequest, nil +} - // N should between 1 and 10 - if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { - return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) +func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { + relayInfo := relaycommon.GenRelayInfo(c) + + imageRequest, err := getAndValidImageRequest(c, relayInfo) + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error())) + return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } // map model name modelMapping := c.GetString("model_mapping") - isModelMapped := false if modelMapping != "" { modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) @@ -88,31 +86,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC } if modelMap[imageRequest.Model] != "" { imageRequest.Model = modelMap[imageRequest.Model] - isModelMapped = true } } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure && relayMode == relayconstant.RelayModeImagesGenerations { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := relaycommon.GetAPIVersion(c) - // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion) - } - var requestBody io.Reader - if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body - jsonStr, err := json.Marshal(imageRequest) - if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } + relayInfo.UpstreamModelName = imageRequest.Model modelPrice, success := common.GetModelPrice(imageRequest.Model, true) if !success { @@ -121,8 +97,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC // per 1 modelRatio = $0.04 / 16 modelPrice = 0.0025 * modelRatio } - groupRatio := common.GetGroupRatio(group) - userQuota, err := model.CacheGetUserQuota(userId) + + groupRatio := common.GetGroupRatio(relayInfo.Group) + userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) sizeRatio := 1.0 // Size @@ -150,98 +127,60 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(relayInfo) + + var requestBody io.Reader + + convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } - token := c.Request.Header.Get("Authorization") - if channelType == common.ChannelTypeAzure { // Azure authentication - token = strings.TrimPrefix(token, "Bearer ") - req.Header.Set("api-key", token) - } else { - req.Header.Set("Authorization", token) + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) } - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) + requestBody = bytes.NewBuffer(jsonData) - resp, err := service.GetHttpClient().Do(req) + statusCodeMappingStr := c.GetString("status_code_mapping") + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - err = req.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - - if resp.StatusCode != http.StatusOK { - return service.RelayErrorHandler(resp) - } - - var textResponse dto.ImageResponse - defer func(ctx context.Context) { - useTimeSeconds := time.Now().Unix() - startTime.Unix() + if resp != nil { + relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") if resp.StatusCode != http.StatusOK { - return + openaiErr := service.RelayErrorHandler(resp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr } - err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - quality := "normal" - if imageRequest.Quality == "hd" { - quality = "hd" - } - logContent := fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelPrice, groupRatio, imageRequest.Size, quality) - other := make(map[string]interface{}) - other["model_price"] = modelPrice - other["group_ratio"] = groupRatio - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - }(c.Request.Context()) - - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + _, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + if openaiErr != nil { + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) + usage := &dto.Usage{ + PromptTokens: relayInfo.PromptTokens, + TotalTokens: relayInfo.PromptTokens, } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + quality := "standard" + if imageRequest.Quality == "hd" { + quality = "hd" } + + logContent := fmt.Sprintf(", 大小 %s, 品质 %s", imageRequest.Size, quality) + postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, modelPrice, true, logContent) + return nil } diff --git a/relay/relay-text.go b/relay/relay-text.go index 9e1b9b7..d82bd60 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -187,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess) + postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") return nil } @@ -279,7 +279,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64, usePrice bool) { + modelPrice float64, usePrice bool, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens @@ -338,6 +338,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN logModel = "gpt-4-gizmo-*" logContent += fmt.Sprintf(",模型 %s", modelName) } + if extraContent != "" { + logContent += extraContent + } other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index 2fc4854..9885fd3 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -99,6 +99,6 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success) + postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") return nil }