From 3ff4210fc472779734011177e269edaf95f72ede Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 9 Nov 2023 17:08:32 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=82=E9=85=8Ddall-e-3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/relay-image.go | 36 ++- controller/relay.go | 8 +- model/redemption.go | 2 + web/src/components/LogsTable.js | 21 +- web/src/components/RedemptionsTable.js | 336 +++++++++++++------------ 5 files changed, 224 insertions(+), 179 deletions(-) diff --git a/controller/relay-image.go b/controller/relay-image.go index 735976e..7421621 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -11,11 +11,10 @@ import ( "net/http" "one-api/common" "one-api/model" + "strings" ) func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - imageModel := "dall-e" - tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") @@ -31,14 +30,21 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e" + } // Prompt validation if imageRequest.Prompt == "" { return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) } + if strings.Contains(imageRequest.Size, "×") { + return errorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest) + } // Not "256x256", "512x512", or "1024x1024" - if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) + if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" && + (imageRequest.Model == "dall-e-3" && (imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024")) { + return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) } // N should between 1 and 10 @@ -55,8 +61,8 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } - if modelMap[imageModel] != "" { - imageModel = modelMap[imageModel] + if modelMap[imageRequest.Model] != "" { + imageRequest.Model = modelMap[imageRequest.Model] isModelMapped = true } } @@ -77,7 +83,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode requestBody = c.Request.Body } - modelRatio := common.GetModelRatio(imageModel) + modelRatio := common.GetModelRatio(imageRequest.Model) groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(userId) @@ -90,8 +96,19 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode sizeRatio = 1.125 } else if imageRequest.Size == "1024x1024" { sizeRatio = 1.25 + } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" { + sizeRatio = 2.5 } - quota := int(ratio*sizeRatio*1000) * imageRequest.N + + qualityRatio := 1.0 + if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" { + qualityRatio = 2.0 + if imageRequest.Size == "1024×1792" || imageRequest.Size == "1792×1024" { + qualityRatio = 1.5 + } + } + + quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N if consumeQuota && userQuota-quota < 0 { return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) @@ -120,7 +137,6 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } var textResponse ImageResponse - defer func(ctx context.Context) { if consumeQuota { err := model.PostConsumeTokenQuota(tokenId, quota) @@ -134,7 +150,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) diff --git a/controller/relay.go b/controller/relay.go index cd9d6bf..c505c22 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -85,9 +85,11 @@ type TextRequest struct { } type ImageRequest struct { - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` + Model string `json:"model"` + Quality string `json:"quality"` + Prompt string `json:"prompt"` + N int `json:"n"` + Size string `json:"size"` } type AudioResponse struct { diff --git a/model/redemption.go b/model/redemption.go index f16412b..5e44817 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -17,6 +17,7 @@ type Redemption struct { CreatedTime int64 `json:"created_time" gorm:"bigint"` RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"` Count int `json:"count" gorm:"-:all"` // only for api request + UsedUserId int `json:"used_user_id"` } func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) { @@ -69,6 +70,7 @@ func Redeem(key string, userId int) (quota int, err error) { } redemption.RedeemedTime = common.GetTimestamp() redemption.Status = common.RedemptionCodeStatusUsed + redemption.UsedUserId = userId err = tx.Save(redemption).Error return err }) diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index 0101417..e2aa8aa 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -1,6 +1,6 @@ import React, {useEffect, useState} from 'react'; import {Label} from 'semantic-ui-react'; -import {API, isAdmin, showError, timestamp2string} from '../helpers'; +import {API, copy, isAdmin, showError, showSuccess, timestamp2string} from '../helpers'; import {Table, Avatar, Tag, Form, Button, Layout, Select, Popover, Modal} from '@douyinfe/semi-ui'; import {ITEMS_PER_PAGE} from '../constants'; @@ -106,7 +106,9 @@ const LogsTable = () => { return ( record.type === 0 || record.type === 2 ?