mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-09 18:23:40 +08:00
Merge remote-tracking branch 'origin/upstream/main' into patch/images-edits
This commit is contained in:
@@ -7,6 +7,10 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/client"
|
||||
@@ -21,9 +25,6 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||
@@ -53,7 +54,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
}
|
||||
|
||||
modelRatio := billingratio.GetModelRatio(audioModel)
|
||||
modelRatio := billingratio.GetModelRatio(audioModel, channelType)
|
||||
groupRatio := billingratio.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
var quota int64
|
||||
|
||||
@@ -4,6 +4,10 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
@@ -16,9 +20,6 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
|
||||
@@ -40,78 +41,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
|
||||
imageRequest := &relaymodel.ImageRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
if imageRequest.Size == "" {
|
||||
imageRequest.Size = "1024x1024"
|
||||
}
|
||||
if imageRequest.Model == "" {
|
||||
imageRequest.Model = "dall-e-2"
|
||||
}
|
||||
return imageRequest, nil
|
||||
}
|
||||
|
||||
func isValidImageSize(model string, size string) bool {
|
||||
if model == "cogview-3" {
|
||||
return true
|
||||
}
|
||||
_, ok := billingratio.ImageSizeRatios[model][size]
|
||||
return ok
|
||||
}
|
||||
|
||||
func getImageSizeRatio(model string, size string) float64 {
|
||||
ratio, ok := billingratio.ImageSizeRatios[model][size]
|
||||
if !ok {
|
||||
return 1
|
||||
}
|
||||
return ratio
|
||||
}
|
||||
|
||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
|
||||
// model validation
|
||||
hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
|
||||
if !hasValidSize {
|
||||
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
|
||||
}
|
||||
// check prompt length
|
||||
if imageRequest.Prompt == "" {
|
||||
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
||||
}
|
||||
if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
|
||||
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
|
||||
}
|
||||
// Number of generated images validation
|
||||
if !isWithinRange(imageRequest.Model, imageRequest.N) {
|
||||
// channel not azure
|
||||
if meta.ChannelType != channeltype.Azure {
|
||||
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
|
||||
if imageRequest == nil {
|
||||
return 0, errors.New("imageRequest is nil")
|
||||
}
|
||||
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
|
||||
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
|
||||
if imageRequest.Size == "1024x1024" {
|
||||
imageCostRatio *= 2
|
||||
} else {
|
||||
imageCostRatio *= 1.5
|
||||
}
|
||||
}
|
||||
return imageCostRatio, nil
|
||||
}
|
||||
|
||||
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
||||
switch relayMode {
|
||||
case relaymode.ChatCompletions:
|
||||
@@ -167,7 +96,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
|
||||
return
|
||||
}
|
||||
var quota int64
|
||||
completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
|
||||
completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
@@ -22,13 +23,84 @@ import (
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
func isWithinRange(element string, value int) bool {
|
||||
if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
|
||||
return false
|
||||
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
|
||||
imageRequest := &relaymodel.ImageRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
min := billingratio.ImageGenerationAmounts[element][0]
|
||||
max := billingratio.ImageGenerationAmounts[element][1]
|
||||
return value >= min && value <= max
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
if imageRequest.Size == "" {
|
||||
imageRequest.Size = "1024x1024"
|
||||
}
|
||||
if imageRequest.Model == "" {
|
||||
imageRequest.Model = "dall-e-2"
|
||||
}
|
||||
return imageRequest, nil
|
||||
}
|
||||
|
||||
func isValidImageSize(model string, size string) bool {
|
||||
if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil {
|
||||
return true
|
||||
}
|
||||
_, ok := billingratio.ImageSizeRatios[model][size]
|
||||
return ok
|
||||
}
|
||||
|
||||
func isValidImagePromptLength(model string, promptLength int) bool {
|
||||
maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model]
|
||||
return !ok || promptLength <= maxPromptLength
|
||||
}
|
||||
|
||||
func isWithinRange(element string, value int) bool {
|
||||
amounts, ok := billingratio.ImageGenerationAmounts[element]
|
||||
return !ok || (value >= amounts[0] && value <= amounts[1])
|
||||
}
|
||||
|
||||
func getImageSizeRatio(model string, size string) float64 {
|
||||
if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok {
|
||||
return ratio
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
|
||||
// check prompt length
|
||||
if imageRequest.Prompt == "" {
|
||||
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// model validation
|
||||
if !isValidImageSize(imageRequest.Model, imageRequest.Size) {
|
||||
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) {
|
||||
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// Number of generated images validation
|
||||
if !isWithinRange(imageRequest.Model, imageRequest.N) {
|
||||
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
|
||||
if imageRequest == nil {
|
||||
return 0, errors.New("imageRequest is nil")
|
||||
}
|
||||
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
|
||||
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
|
||||
if imageRequest.Size == "1024x1024" {
|
||||
imageCostRatio *= 2
|
||||
} else {
|
||||
imageCostRatio *= 1.5
|
||||
}
|
||||
}
|
||||
return imageCostRatio, nil
|
||||
}
|
||||
|
||||
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||
@@ -97,7 +169,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
}
|
||||
|
||||
modelRatio := billingratio.GetModelRatio(imageModel)
|
||||
modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
|
||||
groupRatio := billingratio.GetGroupRatio(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
||||
|
||||
@@ -4,6 +4,9 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay"
|
||||
@@ -14,8 +17,6 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
@@ -35,7 +36,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
|
||||
meta.ActualModelName = textRequest.Model
|
||||
// get model ratio & group ratio
|
||||
modelRatio := billingratio.GetModelRatio(textRequest.Model)
|
||||
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
|
||||
groupRatio := billingratio.GetGroupRatio(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
// pre-consume quota
|
||||
|
||||
Reference in New Issue
Block a user