mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	fix: fix size not support during image generation (#1564)
Fixes #1224, #1068
This commit is contained in:
		@@ -40,78 +40,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
 | 
			
		||||
	return textRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
 | 
			
		||||
	imageRequest := &relaymodel.ImageRequest{}
 | 
			
		||||
	err := common.UnmarshalBodyReusable(c, imageRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if imageRequest.N == 0 {
 | 
			
		||||
		imageRequest.N = 1
 | 
			
		||||
	}
 | 
			
		||||
	if imageRequest.Size == "" {
 | 
			
		||||
		imageRequest.Size = "1024x1024"
 | 
			
		||||
	}
 | 
			
		||||
	if imageRequest.Model == "" {
 | 
			
		||||
		imageRequest.Model = "dall-e-2"
 | 
			
		||||
	}
 | 
			
		||||
	return imageRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isValidImageSize(model string, size string) bool {
 | 
			
		||||
	if model == "cogview-3" {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	_, ok := billingratio.ImageSizeRatios[model][size]
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getImageSizeRatio(model string, size string) float64 {
 | 
			
		||||
	ratio, ok := billingratio.ImageSizeRatios[model][size]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return 1
 | 
			
		||||
	}
 | 
			
		||||
	return ratio
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
 | 
			
		||||
	// model validation
 | 
			
		||||
	hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
 | 
			
		||||
	if !hasValidSize {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	// check prompt length
 | 
			
		||||
	if imageRequest.Prompt == "" {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	// Number of generated images validation
 | 
			
		||||
	if !isWithinRange(imageRequest.Model, imageRequest.N) {
 | 
			
		||||
		// channel not azure
 | 
			
		||||
		if meta.ChannelType != channeltype.Azure {
 | 
			
		||||
			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
 | 
			
		||||
	if imageRequest == nil {
 | 
			
		||||
		return 0, errors.New("imageRequest is nil")
 | 
			
		||||
	}
 | 
			
		||||
	imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
 | 
			
		||||
	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
 | 
			
		||||
		if imageRequest.Size == "1024x1024" {
 | 
			
		||||
			imageCostRatio *= 2
 | 
			
		||||
		} else {
 | 
			
		||||
			imageCostRatio *= 1.5
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return imageCostRatio, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case relaymode.ChatCompletions:
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
@@ -20,13 +21,84 @@ import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func isWithinRange(element string, value int) bool {
 | 
			
		||||
	if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
 | 
			
		||||
		return false
 | 
			
		||||
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
 | 
			
		||||
	imageRequest := &relaymodel.ImageRequest{}
 | 
			
		||||
	err := common.UnmarshalBodyReusable(c, imageRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	min := billingratio.ImageGenerationAmounts[element][0]
 | 
			
		||||
	max := billingratio.ImageGenerationAmounts[element][1]
 | 
			
		||||
	return value >= min && value <= max
 | 
			
		||||
	if imageRequest.N == 0 {
 | 
			
		||||
		imageRequest.N = 1
 | 
			
		||||
	}
 | 
			
		||||
	if imageRequest.Size == "" {
 | 
			
		||||
		imageRequest.Size = "1024x1024"
 | 
			
		||||
	}
 | 
			
		||||
	if imageRequest.Model == "" {
 | 
			
		||||
		imageRequest.Model = "dall-e-2"
 | 
			
		||||
	}
 | 
			
		||||
	return imageRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isValidImageSize(model string, size string) bool {
 | 
			
		||||
	if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	_, ok := billingratio.ImageSizeRatios[model][size]
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isValidImagePromptLength(model string, promptLength int) bool {
 | 
			
		||||
	maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model]
 | 
			
		||||
	return !ok || promptLength <= maxPromptLength
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isWithinRange(element string, value int) bool {
 | 
			
		||||
	amounts, ok := billingratio.ImageGenerationAmounts[element]
 | 
			
		||||
	return !ok || (value >= amounts[0] && value <= amounts[1])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getImageSizeRatio(model string, size string) float64 {
 | 
			
		||||
	if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok {
 | 
			
		||||
		return ratio
 | 
			
		||||
	}
 | 
			
		||||
	return 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
 | 
			
		||||
	// check prompt length
 | 
			
		||||
	if imageRequest.Prompt == "" {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// model validation
 | 
			
		||||
	if !isValidImageSize(imageRequest.Model, imageRequest.Size) {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Number of generated images validation
 | 
			
		||||
	if !isWithinRange(imageRequest.Model, imageRequest.N) {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
 | 
			
		||||
	if imageRequest == nil {
 | 
			
		||||
		return 0, errors.New("imageRequest is nil")
 | 
			
		||||
	}
 | 
			
		||||
	imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
 | 
			
		||||
	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
 | 
			
		||||
		if imageRequest.Size == "1024x1024" {
 | 
			
		||||
			imageCostRatio *= 2
 | 
			
		||||
		} else {
 | 
			
		||||
			imageCostRatio *= 1.5
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return imageCostRatio, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user