feat: support vertex imagen3

This commit is contained in:
Laisky.Cai 2025-01-12 04:18:57 +00:00
parent 3915ce9814
commit 009337ccf3
30 changed files with 247 additions and 49 deletions

View File

@ -21,4 +21,5 @@ const (
AvailableModels = "available_models" AvailableModels = "available_models"
KeyRequestBody = "key_request_body" KeyRequestBody = "key_request_body"
SystemPrompt = "system_prompt" SystemPrompt = "system_prompt"
Meta = "meta"
) )

View File

@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return aiProxyLibraryRequest, nil return aiProxyLibraryRequest, nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -67,7 +67,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
} }
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil return ConvertRequest(*request), nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -72,7 +72,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
return nil return nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -39,7 +39,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
return nil return nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -109,7 +109,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
} }
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -19,7 +19,7 @@ type Adaptor struct {
} }
// ConvertImageRequest implements adaptor.Adaptor. // ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }

View File

@ -15,7 +15,7 @@ import (
type Adaptor struct{} type Adaptor struct{}
// ConvertImageRequest implements adaptor.Adaptor. // ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }

View File

@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil return ConvertRequest(*request), nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return convertedRequest, nil return convertedRequest, nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -65,7 +65,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
} }
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -13,7 +13,7 @@ type Adaptor interface {
GetRequestURL(meta *meta.Meta) (string, error) GetRequestURL(meta *meta.Meta) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
ConvertImageRequest(request *model.ImageRequest) (any, error) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string GetModelList() []string

View File

@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
} }
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -85,7 +85,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return request, nil return request, nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil return ConvertRequest(*request), nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -80,7 +80,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
return nil return nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
return nil, errors.Errorf("not implement") return nil, errors.Errorf("not implement")
} }

View File

@ -23,7 +23,29 @@ type Adaptor struct {
} }
// ConvertImageRequest implements adaptor.Adaptor. // ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
return nil, errors.New("should call replicate.ConvertImageRequest instead")
}
func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
meta := meta.GetByContext(c)
if request.ResponseFormat != "b64_json" {
return nil, errors.New("only support b64_json response format")
}
if request.N != 1 && request.N != 0 {
return nil, errors.New("only support N=1")
}
switch meta.Mode {
case relaymode.ImagesGenerations:
return convertImageCreateRequest(request)
default:
return nil, errors.New("not implemented")
}
}
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
return DrawImageRequest{ return DrawImageRequest{
Input: ImageInput{ Input: ImageInput{
Steps: 25, Steps: 25,

View File

@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return tencentRequest, nil return tencentRequest, nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -1,18 +1,20 @@
package vertexai package vertexai
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"slices"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor" channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
relaymodel "github.com/songquanpeng/one-api/relay/model" relayModel "github.com/songquanpeng/one-api/relay/model"
) )
var _ adaptor.Adaptor = new(Adaptor) var _ adaptor.Adaptor = new(Adaptor)
@ -24,14 +26,27 @@ type Adaptor struct{}
func (a *Adaptor) Init(meta *meta.Meta) { func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { meta := meta.GetByContext(c)
return nil, errors.New("request is nil")
if request.ResponseFormat != "b64_json" {
return nil, errors.New("only support b64_json response format")
} }
adaptor := GetAdaptor(request.Model) adaptor := GetAdaptor(meta.ActualModelName)
if adaptor == nil { if adaptor == nil {
return nil, errors.New("adaptor not found") return nil, errors.Errorf("cannot found vertex image adaptor for model %s", meta.ActualModelName)
}
return adaptor.ConvertImageRequest(c, request)
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
meta := meta.GetByContext(c)
adaptor := GetAdaptor(meta.ActualModelName)
if adaptor == nil {
return nil, errors.Errorf("cannot found vertex chat adaptor for model %s", meta.ActualModelName)
} }
return adaptor.ConvertRequest(c, relayMode, request) return adaptor.ConvertRequest(c, relayMode, request)
@ -40,9 +55,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
adaptor := GetAdaptor(meta.ActualModelName) adaptor := GetAdaptor(meta.ActualModelName)
if adaptor == nil { if adaptor == nil {
return nil, &relaymodel.ErrorWithStatusCode{ return nil, &relayModel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError, StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{ Error: relayModel.Error{
Message: "adaptor not found", Message: "adaptor not found",
}, },
} }
@ -60,14 +75,19 @@ func (a *Adaptor) GetChannelName() string {
} }
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
suffix := "" var suffix string
if strings.HasPrefix(meta.ActualModelName, "gemini") { switch {
case strings.HasPrefix(meta.ActualModelName, "gemini"):
if meta.IsStream { if meta.IsStream {
suffix = "streamGenerateContent?alt=sse" suffix = "streamGenerateContent?alt=sse"
} else { } else {
suffix = "generateContent" suffix = "generateContent"
} }
} else { case slices.Contains(imagen.ModelList, meta.ActualModelName):
return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/imagen-3.0-generate-001:predict",
meta.Config.Region, meta.Config.VertexAIProjectID, meta.Config.Region,
), nil
default:
if meta.IsStream { if meta.IsStream {
suffix = "streamRawPredict?alt=sse" suffix = "streamRawPredict?alt=sse"
} else { } else {
@ -85,6 +105,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
suffix, suffix,
), nil ), nil
} }
return fmt.Sprintf( return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
meta.Config.Region, meta.Config.Region,
@ -105,13 +126,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
return nil return nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody) return channelhelper.DoRequestHelper(a, c, meta, requestBody)
} }

View File

@ -50,6 +50,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return req, nil return req, nil
} }
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
return nil, errors.New("not support image request")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
err, usage = anthropic.StreamHandler(c, resp) err, usage = anthropic.StreamHandler(c, resp)

View File

@ -35,6 +35,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return geminiRequest, nil return geminiRequest, nil
} }
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
return nil, errors.New("not support image request")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream { if meta.IsStream {
var responseText string var responseText string

View File

@ -0,0 +1,105 @@
package imagen
import (
"encoding/json"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
var ModelList = []string{
"imagen-3.0-generate-001",
}
type Adaptor struct {
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
meta := meta.GetByContext(c)
if request.ResponseFormat != "b64_json" {
return nil, errors.New("only support b64_json response format")
}
if request.N <= 0 {
return nil, errors.New("n must be greater than 0")
}
switch meta.Mode {
case relaymode.ImagesGenerations:
return convertImageCreateRequest(request)
default:
return nil, errors.New("not implemented")
}
}
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
return CreateImageRequest{
Instances: []createImageInstance{
{
Prompt: request.Prompt,
},
},
Parameters: createImageParameters{
SampleCount: request.N,
},
}, nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, openai.ErrorWrapper(
errors.Wrap(err, "failed to read response body"),
"read_response_body",
http.StatusInternalServerError,
)
}
if resp.StatusCode != http.StatusOK {
return nil, openai.ErrorWrapper(
errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)),
"upstream_response",
http.StatusInternalServerError,
)
}
imagenResp := new(CreateImageResponse)
if err := json.Unmarshal(respBody, imagenResp); err != nil {
return nil, openai.ErrorWrapper(
errors.Wrap(err, "failed to decode response body"),
"unmarshal_upstream_response",
http.StatusInternalServerError,
)
}
if len(imagenResp.Predictions) == 0 {
return nil, openai.ErrorWrapper(
errors.New("empty predictions"),
"empty_predictions",
http.StatusInternalServerError,
)
}
oaiResp := openai.ImageResponse{
Created: time.Now().Unix(),
}
for _, prediction := range imagenResp.Predictions {
oaiResp.Data = append(oaiResp.Data, openai.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
c.JSON(http.StatusOK, oaiResp)
return nil, nil
}

View File

@ -0,0 +1,23 @@
package imagen
type CreateImageRequest struct {
Instances []createImageInstance `json:"instances" binding:"required,min=1"`
Parameters createImageParameters `json:"parameters" binding:"required"`
}
type createImageInstance struct {
Prompt string `json:"prompt"`
}
type createImageParameters struct {
SampleCount int `json:"sample_count" binding:"required,min=1"`
}
type CreateImageResponse struct {
Predictions []createImageResponsePrediction `json:"predictions"`
}
type createImageResponsePrediction struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
}

View File

@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude" claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude"
gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini" gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini"
"github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
) )
@ -13,8 +14,9 @@ import (
type VertexAIModelType int type VertexAIModelType int
const ( const (
VerterAIClaude VertexAIModelType = iota + 1 VertexAIClaude VertexAIModelType = iota + 1
VerterAIGemini VertexAIGemini
VertexAIImagen
) )
var modelMapping = map[string]VertexAIModelType{} var modelMapping = map[string]VertexAIModelType{}
@ -23,27 +25,35 @@ var modelList = []string{}
func init() { func init() {
modelList = append(modelList, claude.ModelList...) modelList = append(modelList, claude.ModelList...)
for _, model := range claude.ModelList { for _, model := range claude.ModelList {
modelMapping[model] = VerterAIClaude modelMapping[model] = VertexAIClaude
} }
modelList = append(modelList, gemini.ModelList...) modelList = append(modelList, gemini.ModelList...)
for _, model := range gemini.ModelList { for _, model := range gemini.ModelList {
modelMapping[model] = VerterAIGemini modelMapping[model] = VertexAIGemini
}
modelList = append(modelList, imagen.ModelList...)
for _, model := range imagen.ModelList {
modelMapping[model] = VertexAIImagen
} }
} }
type innerAIAdapter interface { type innerAIAdapter interface {
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
} }
func GetAdaptor(model string) innerAIAdapter { func GetAdaptor(model string) innerAIAdapter {
adaptorType := modelMapping[model] adaptorType := modelMapping[model]
switch adaptorType { switch adaptorType {
case VerterAIClaude: case VertexAIClaude:
return &claude.Adaptor{} return &claude.Adaptor{}
case VerterAIGemini: case VertexAIGemini:
return &gemini.Adaptor{} return &gemini.Adaptor{}
case VertexAIImagen:
return &imagen.Adaptor{}
default: default:
return nil return nil
} }

View File

@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, nil return nil, nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -80,7 +80,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
} }
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }

View File

@ -287,6 +287,9 @@ var ModelRatio = map[string]float64{
"deepl-ja": 25.0 / 1000 * USD, "deepl-ja": 25.0 / 1000 * USD,
// https://console.x.ai/ // https://console.x.ai/
"grok-beta": 5.0 / 1000 * USD, "grok-beta": 5.0 / 1000 * USD,
// vertex imagen3
// https://cloud.google.com/vertex-ai/generative-ai/pricing#imagen-models
"imagen-3.0-generate-001": 0.02 * USD,
// replicate charges based on the number of generated images // replicate charges based on the number of generated images
// https://replicate.com/pricing // https://replicate.com/pricing
"black-forest-labs/flux-1.1-pro": 0.04 * USD, "black-forest-labs/flux-1.1-pro": 0.04 * USD,

View File

@ -18,7 +18,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta" metalib "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
) )
@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
return 1 return 1
} }
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode { func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *metalib.Meta) *relaymodel.ErrorWithStatusCode {
// check prompt length // check prompt length
if imageRequest.Prompt == "" { if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
@ -104,7 +104,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context() ctx := c.Request.Context()
meta := meta.GetByContext(c) meta := metalib.GetByContext(c)
imageRequest, err := getImageRequest(c, meta.Mode) imageRequest, err := getImageRequest(c, meta.Mode)
if err != nil { if err != nil {
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
@ -116,6 +116,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
meta.OriginModelName = imageRequest.Model meta.OriginModelName = imageRequest.Model
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping) imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
meta.ActualModelName = imageRequest.Model meta.ActualModelName = imageRequest.Model
metalib.Set2Context(c, meta)
// model validation // model validation
bizErr := validateImageRequest(imageRequest, meta) bizErr := validateImageRequest(imageRequest, meta)
@ -155,8 +156,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
case channeltype.Zhipu, case channeltype.Zhipu,
channeltype.Ali, channeltype.Ali,
channeltype.Replicate, channeltype.Replicate,
channeltype.VertextAI,
channeltype.Baidu: channeltype.Baidu:
finalRequest, err := adaptor.ConvertImageRequest(imageRequest) finalRequest, err := adaptor.ConvertImageRequest(c, imageRequest)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
} }

View File

@ -34,6 +34,10 @@ type Meta struct {
} }
func GetByContext(c *gin.Context) *Meta { func GetByContext(c *gin.Context) *Meta {
if v, ok := c.Get(ctxkey.Meta); ok {
return v.(*Meta)
}
meta := Meta{ meta := Meta{
Mode: relaymode.GetByPath(c.Request.URL.Path), Mode: relaymode.GetByPath(c.Request.URL.Path),
ChannelType: c.GetInt(ctxkey.Channel), ChannelType: c.GetInt(ctxkey.Channel),
@ -57,5 +61,11 @@ func GetByContext(c *gin.Context) *Meta {
meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType] meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType]
} }
meta.APIType = channeltype.ToAPIType(meta.ChannelType) meta.APIType = channeltype.ToAPIType(meta.ChannelType)
Set2Context(c, &meta)
return &meta return &meta
} }
func Set2Context(c *gin.Context, meta *Meta) {
c.Set(ctxkey.Meta, meta)
}