适配dall-e-3

This commit is contained in:
CaIon
2023-11-09 17:08:32 +08:00
parent 7237d20ec4
commit 3ff4210fc4
5 changed files with 224 additions and 179 deletions

View File

@@ -11,11 +11,10 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"strings"
)
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
imageModel := "dall-e"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
@@ -31,14 +30,21 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e"
}
// Prompt validation
if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
}
if strings.Contains(imageRequest.Size, "×") {
return errorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" &&
(imageRequest.Model == "dall-e-3" && (imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024")) {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
}
// N should between 1 and 10
@@ -55,8 +61,8 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
if modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model]
isModelMapped = true
}
}
@@ -77,7 +83,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageModel)
modelRatio := common.GetModelRatio(imageRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
@@ -90,8 +96,19 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
sizeRatio = 1.125
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
sizeRatio = 2.5
}
quota := int(ratio*sizeRatio*1000) * imageRequest.N
qualityRatio := 1.0
if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
qualityRatio = 2.0
if imageRequest.Size == "1024×1792" || imageRequest.Size == "1792×1024" {
qualityRatio = 1.5
}
}
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@@ -120,7 +137,6 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var textResponse ImageResponse
defer func(ctx context.Context) {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, quota)
@@ -134,7 +150,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)