add: add images edits and variations API

This commit is contained in:
Martial BE
2023-12-01 18:25:05 +08:00
parent 9dd92bbddd
commit 0f038d715d
11 changed files with 302 additions and 24 deletions

View File

@@ -65,6 +65,10 @@ func relayHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode
usage, openAIErrorWithStatusCode = handleTranslations(c, provider, modelMap, quotaInfo, group)
case common.RelayModeImagesGenerations:
usage, openAIErrorWithStatusCode = handleImageGenerations(c, provider, modelMap, quotaInfo, group)
case common.RelayModeImagesEdits:
usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "edit")
case common.RelayModeImagesVariations:
usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "variation")
default:
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
}
@@ -336,7 +340,7 @@ func handleTranslations(c *gin.Context, provider providers_base.ProviderInterfac
func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
var imageRequest types.ImageRequest
isModelMapped := false
speechProvider, ok := provider.(providers_base.ImageGenerationsInterface)
imageGenerationsProvider, ok := provider.(providers_base.ImageGenerationsInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
@@ -374,5 +378,60 @@ func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInte
if quota_err != nil {
return nil, quota_err
}
return speechProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
return imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
}
func handleImageEdits(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string, imageType string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
var imageEditRequest types.ImageEditRequest
isModelMapped := false
var imageEditsProvider providers_base.ImageEditsInterface
var imageVariations providers_base.ImageVariationsInterface
var ok bool
if imageType == "edit" {
imageEditsProvider, ok = provider.(providers_base.ImageEditsInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
} else {
imageVariations, ok = provider.(providers_base.ImageVariationsInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
}
err := common.UnmarshalBodyReusable(c, &imageEditRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if imageEditRequest.Model == "" {
imageEditRequest.Model = "dall-e-2"
}
if imageEditRequest.Size == "" {
imageEditRequest.Size = "1024x1024"
}
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
imageEditRequest.Model = modelMap[imageEditRequest.Model]
isModelMapped = true
}
promptTokens, err := common.CountTokenImage(imageEditRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "count_token_image_failed", http.StatusInternalServerError)
}
quotaInfo.modelName = imageEditRequest.Model
quotaInfo.promptTokens = promptTokens
quotaInfo.initQuotaInfo(group)
quota_err := quotaInfo.preQuotaConsumption()
if quota_err != nil {
return nil, quota_err
}
if imageType == "edit" {
return imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens)
}
return imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens)
}

View File

@@ -244,6 +244,10 @@ func Relay(c *gin.Context) {
relayMode = common.RelayModeAudioTranslation
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
relayMode = common.RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
relayMode = common.RelayModeImagesEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/variations") {
relayMode = common.RelayModeImagesVariations
}
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
// relayMode = RelayModeEdits