add: add images api

This commit is contained in:
Martial BE
2023-12-01 17:20:22 +08:00
parent 5b70ee3407
commit 9dd92bbddd
19 changed files with 296 additions and 12 deletions

View File

@@ -63,6 +63,8 @@ func relayHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode
usage, openAIErrorWithStatusCode = handleTranscriptions(c, provider, modelMap, quotaInfo, group)
case common.RelayModeAudioTranslation:
usage, openAIErrorWithStatusCode = handleTranslations(c, provider, modelMap, quotaInfo, group)
case common.RelayModeImagesGenerations:
usage, openAIErrorWithStatusCode = handleImageGenerations(c, provider, modelMap, quotaInfo, group)
default:
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
}
@@ -330,3 +332,47 @@ func handleTranslations(c *gin.Context, provider providers_base.ProviderInterfac
}
return speechProvider.TranslationAction(&audioRequest, isModelMapped, promptTokens)
}
func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
var imageRequest types.ImageRequest
isModelMapped := false
speechProvider, ok := provider.(providers_base.ImageGenerationsInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
if modelMap != nil && modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model]
isModelMapped = true
}
promptTokens, err := common.CountTokenImage(imageRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "count_token_image_failed", http.StatusInternalServerError)
}
quotaInfo.modelName = imageRequest.Model
quotaInfo.promptTokens = promptTokens
quotaInfo.initQuotaInfo(group)
quota_err := quotaInfo.preQuotaConsumption()
if quota_err != nil {
return nil, quota_err
}
return speechProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
}

View File

@@ -242,9 +242,9 @@ func Relay(c *gin.Context) {
relayMode = common.RelayModeAudioTranscription
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
relayMode = common.RelayModeAudioTranslation
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
relayMode = common.RelayModeImagesGenerations
}
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
// relayMode = RelayModeImagesGenerations
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
// relayMode = RelayModeEdits