@ -2,7 +2,6 @@ package relay
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@ -14,72 +13,71 @@ import (
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
"time"
)
func RelayImageHelper ( c * gin . Context , relayMode int ) * dto . OpenAIErrorWithStatusCode {
tokenId := c . GetInt ( "token_id" )
channelType := c . GetInt ( "channel" )
channelId := c . GetInt ( "channel_id" )
userId := c . GetInt ( "id" )
group := c . GetString ( "group" )
startTime := time . Now ( )
var imageRequest dto . ImageRequest
err := common . UnmarshalBodyReusable ( c , & imageRequest )
func getAndValidImageRequest ( c * gin . Context , info * relaycommon . RelayInfo ) ( * dto . ImageRequest , error ) {
imageRequest := & dto . ImageRequest { }
err := common . UnmarshalBodyReusable ( c , imageRequest )
if err != nil {
return service . OpenAIErrorWrapper ( err , "bind_request_body_failed" , http . StatusBadRequest )
return nil , err
}
if imageRequest . Model == "" {
imageRequest . Model = "dall-e-3"
if imageRequest . Prompt == "" {
return nil , errors . New ( "prompt is required" )
}
if imageRequest. Size == "" {
imageRequest . Size = "1024x1024"
if strings . Contains ( imageRequest . Size , "× " ) {
return nil , errors . New ( "size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '× '" )
}
if imageRequest . N == 0 {
imageRequest . N = 1
}
// Prompt validation
if imageRequest . Prompt == "" {
return service . OpenAIErrorWrapper ( errors . New ( "prompt is required" ) , "required_field_missing" , http . StatusBadRequest )
if imageRequest . Size == "" {
imageRequest . Size = "1024x1024"
}
if constant . ShouldCheckPromptSensitive ( ) {
err = service . CheckSensitiveInput ( imageRequest . Prompt )
if err != nil {
return service . OpenAIErrorWrapper ( err , "sensitive_words_detected" , http . StatusBadRequest )
if imageRequest . Model == "" {
imageRequest . Model = "dall-e-2"
}
}
if strings . Contains ( imageRequest . Size , "× " ) {
return service . OpenAIErrorWrapper ( errors . New ( "size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '× '" ) , "invalid_field_value" , http . StatusBadRequest )
if imageRequest . Quality == "" {
imageRequest . Quality = "standard"
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest . Model == "dall-e-2" || imageRequest . Model == "dall-e" {
if imageRequest . Size != "" && imageRequest . Size != "256x256" && imageRequest . Size != "512x512" && imageRequest . Size != "1024x1024" {
return service . OpenAIErrorWrapper ( errors . New ( "size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024" ) , "invalid_field_value" , http . StatusBadRequest )
return nil , errors . New ( "size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024" )
}
} else if imageRequest . Model == "dall-e-3" {
if imageRequest . Size != "" && imageRequest . Size != "1024x1024" && imageRequest . Size != "1024x1792" && imageRequest . Size != "1792x1024" {
return service . OpenAIErrorWrapper ( errors . New ( "size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024" ) , "invalid_field_value" , http . StatusBadRequest )
return nil , errors . New ( "size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024" )
}
if imageRequest . N != 1 {
return service . OpenAIErrorWrapper ( errors . New ( "n must be 1" ) , "invalid_field_value" , http . StatusBadRequest )
//if imageRequest.N != 1 {
// return nil, errors.New("n must be 1")
//}
}
}
// N should between 1 and 10
if imageRequest . N != 0 && ( imageRequest . N < 1 || imageRequest . N > 10 ) {
return service . OpenAIErrorWrapper ( errors . New ( "n must be between 1 and 10" ) , "invalid_field_value" , http . StatusBadRequest )
//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
//}
if constant . ShouldCheckPromptSensitive ( ) {
err := service . CheckSensitiveInput ( imageRequest . Prompt )
if err != nil {
return nil , err
}
}
return imageRequest , nil
}
func ImageHelper ( c * gin . Context , relayMode int ) * dto . OpenAIErrorWithStatusCode {
relayInfo := relaycommon . GenRelayInfo ( c )
imageRequest , err := getAndValidImageRequest ( c , relayInfo )
if err != nil {
common . LogError ( c , fmt . Sprintf ( "getAndValidImageRequest failed: %s" , err . Error ( ) ) )
return service . OpenAIErrorWrapper ( err , "invalid_image_request" , http . StatusBadRequest )
}
// map model name
modelMapping := c . GetString ( "model_mapping" )
isModelMapped := false
if modelMapping != "" {
modelMap := make ( map [ string ] string )
err := json . Unmarshal ( [ ] byte ( modelMapping ) , & modelMap )
@ -88,31 +86,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
}
if modelMap [ imageRequest . Model ] != "" {
imageRequest . Model = modelMap [ imageRequest . Model ]
isModelMapped = true
}
}
baseURL := common . ChannelBaseURLs [ channelType ]
requestURL := c . Request . URL . String ( )
if c . GetString ( "base_url" ) != "" {
baseURL = c . GetString ( "base_url" )
}
fullRequestURL := relaycommon . GetFullRequestURL ( baseURL , requestURL , channelType )
if channelType == common . ChannelTypeAzure && relayMode == relayconstant . RelayModeImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := relaycommon . GetAPIVersion ( c )
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt . Sprintf ( "%s/openai/deployments/%s/images/generations?api-version=%s" , baseURL , imageRequest . Model , apiVersion )
}
var requestBody io . Reader
if isModelMapped || channelType == common . ChannelTypeAzure { // make Azure channel request body
jsonStr , err := json . Marshal ( imageRequest )
if err != nil {
return service . OpenAIErrorWrapper ( err , "marshal_text_request_failed" , http . StatusInternalServerError )
}
requestBody = bytes . NewBuffer ( jsonStr )
} else {
requestBody = c . Request . Body
}
relayInfo . UpstreamModelName = imageRequest . Model
modelPrice , success := common . GetModelPrice ( imageRequest . Model , true )
if ! success {
@ -121,8 +97,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
// per 1 modelRatio = $0.04 / 16
modelPrice = 0.0025 * modelRatio
}
groupRatio := common . GetGroupRatio ( group )
userQuota , err := model . CacheGetUserQuota ( userId )
groupRatio := common . GetGroupRatio ( relayInfo . Group )
userQuota , err := model . CacheGetUserQuota ( relayInfo . UserId )
sizeRatio := 1.0
// Size
@ -150,98 +127,60 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
return service . OpenAIErrorWrapperLocal ( errors . New ( "user quota is not enough" ) , "insufficient_user_quota" , http . StatusForbidden )
}
req , err := http . NewRequest ( c . Request . Method , fullRequestURL , requestBody )
adaptor := GetAdaptor ( relayInfo . ApiType )
if adaptor == nil {
return service . OpenAIErrorWrapperLocal ( fmt . Errorf ( "invalid api type: %d" , relayInfo . ApiType ) , "invalid_api_type" , http . StatusBadRequest )
}
adaptor . Init ( relayInfo )
var requestBody io . Reader
convertedRequest , err := adaptor . ConvertImageRequest ( c , relayInfo , * imageRequest )
if err != nil {
return service . OpenAIErrorWrapper ( err , "new_request_failed" , http . StatusInternalServerError )
return service . OpenAIErrorWrapper Local( err , "convert _request_failed", http . StatusInternalServerError )
}
token := c . Request . Header . Get ( "Authorization" )
if channelType == common . ChannelTypeAzure { // Azure authentication
token = strings . TrimPrefix ( token , "Bearer " )
req . Header . Set ( "api-key" , token )
} else {
req . Header . Set ( "Authorization" , token )
jsonData , err := json . Marshal ( convertedRequest )
if err != nil {
return service . OpenAIErrorWrapperLocal ( err , "json_marshal_failed" , http . StatusInternalServerError )
}
req . Header . Set ( "Content-Type" , c . Request . Header . Get ( "Content-Type" ) )
req . Header . Set ( "Accept" , c . Request . Header . Get ( "Accept" ) )
requestBody = bytes . NewBuffer ( jsonData )
resp , err := service . GetHttpClient ( ) . Do ( req )
statusCodeMappingStr := c . GetString ( "status_code_mapping" )
resp , err := adaptor . DoRequest ( c , relayInfo , requestBody )
if err != nil {
return service . OpenAIErrorWrapper ( err , "do_request_failed" , http . StatusInternalServerError )
}
err = req . Body . Close ( )
if err != nil {
return service . OpenAIErrorWrapper ( err , "close_request_body_failed" , http . StatusInternalServerError )
if resp != nil {
relayInfo . IsStream = relayInfo . IsStream || strings . HasPrefix ( resp . Header . Get ( "Content-Type" ) , "text/event-stream" )
if resp . StatusCode != http . StatusOK {
openaiErr := service . RelayErrorHandler ( resp )
// reset status code 重置状态码
service . ResetStatusCode ( openaiErr , statusCodeMappingStr )
return openaiErr
}
err = c . Request . Body . Close ( )
if err != nil {
return service . OpenAIErrorWrapper ( err , "close_request_body_failed" , http . StatusInternalServerError )
}
if resp . StatusCode != http . StatusOK {
return service . RelayErrorHandler ( resp )
_ , openaiErr := adaptor . DoResponse ( c , resp , relayInfo )
if openaiErr != nil {
// reset status code 重置状态码
service . ResetStatusCode ( openaiErr , statusCodeMappingStr )
return openaiErr
}
var textResponse dto . ImageResponse
defer func ( ctx context . Context ) {
useTimeSeconds := time . Now ( ) . Unix ( ) - startTime . Unix ( )
if resp . StatusCode != http . StatusOK {
return
usage := & dto . Usage {
PromptTokens : relayInfo . PromptTokens ,
TotalTokens : relayInfo . PromptTokens ,
}
err := model . PostConsumeTokenQuota ( tokenId , userQuota , quota , 0 , true )
if err != nil {
common . SysError ( "error consuming token remain quota: " + err . Error ( ) )
}
err = model . CacheUpdateUserQuota ( userId )
if err != nil {
common . SysError ( "error update user quota cache: " + err . Error ( ) )
}
if quota != 0 {
tokenName := c . GetString ( "token_name" )
quality := "normal"
quality := "standard"
if imageRequest . Quality == "hd" {
quality = "hd"
}
logContent := fmt . Sprintf ( "模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s" , modelPrice , groupRatio , imageRequest . Size , quality )
other := make ( map [ string ] interface { } )
other [ "model_price" ] = modelPrice
other [ "group_ratio" ] = groupRatio
model . RecordConsumeLog ( ctx , userId , channelId , 0 , 0 , imageRequest . Model , tokenName , quota , logContent , tokenId , userQuota , int ( useTimeSeconds ) , false , other )
model . UpdateUserUsedQuotaAndRequestCount ( userId , quota )
channelId := c . GetInt ( "channel_id" )
model . UpdateChannelUsedQuota ( channelId , quota )
}
} ( c . Request . Context ( ) )
responseBody , err := io . ReadAll ( resp . Body )
logContent := fmt . Sprintf ( "大小 %s, 品质 %s" , imageRequest . Size , quality )
postConsumeQuota ( c , relayInfo , imageRequest . Model , usage , 0 , 0 , userQuota , 0 , groupRatio , modelPrice , true , logContent )
if err != nil {
return service . OpenAIErrorWrapper ( err , "read_response_body_failed" , http . StatusInternalServerError )
}
err = resp . Body . Close ( )
if err != nil {
return service . OpenAIErrorWrapper ( err , "close_response_body_failed" , http . StatusInternalServerError )
}
err = json . Unmarshal ( responseBody , & textResponse )
if err != nil {
return service . OpenAIErrorWrapper ( err , "unmarshal_response_body_failed" , http . StatusInternalServerError )
}
resp . Body = io . NopCloser ( bytes . NewBuffer ( responseBody ) )
for k , v := range resp . Header {
c . Writer . Header ( ) . Set ( k , v [ 0 ] )
}
c . Writer . WriteHeader ( resp . StatusCode )
_ , err = io . Copy ( c . Writer , resp . Body )
if err != nil {
return service . OpenAIErrorWrapper ( err , "copy_response_body_failed" , http . StatusInternalServerError )
}
err = resp . Body . Close ( )
if err != nil {
return service . OpenAIErrorWrapper ( err , "close_response_body_failed" , http . StatusInternalServerError )
}
return nil
}