diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 115558a5..75c6da51 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -21,4 +21,5 @@ const ( AvailableModels = "available_models" KeyRequestBody = "key_request_body" SystemPrompt = "system_prompt" + Meta = "meta" ) diff --git a/relay/adaptor/aiproxy/adaptor.go b/relay/adaptor/aiproxy/adaptor.go index 42d49c0a..d85f42d1 100644 --- a/relay/adaptor/aiproxy/adaptor.go +++ b/relay/adaptor/aiproxy/adaptor.go @@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return aiProxyLibraryRequest, nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go index 4aa8a11a..79b51ac3 100644 --- a/relay/adaptor/ali/adaptor.go +++ b/relay/adaptor/ali/adaptor.go @@ -67,7 +67,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go index bd0949be..a21e9ece 100644 --- a/relay/adaptor/anthropic/adaptor.go +++ b/relay/adaptor/anthropic/adaptor.go @@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/aws/adaptor.go b/relay/adaptor/aws/adaptor.go index 62221346..45bfbdf6 100644 --- a/relay/adaptor/aws/adaptor.go +++ b/relay/adaptor/aws/adaptor.go @@ -72,7 +72,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me return nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/aws/utils/adaptor.go b/relay/adaptor/aws/utils/adaptor.go index 4cb880f2..f5fc0038 100644 --- a/relay/adaptor/aws/utils/adaptor.go +++ b/relay/adaptor/aws/utils/adaptor.go @@ -39,7 +39,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me return nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/baidu/adaptor.go b/relay/adaptor/baidu/adaptor.go index 15306b95..664e0e77 100644 --- a/relay/adaptor/baidu/adaptor.go +++ b/relay/adaptor/baidu/adaptor.go @@ -109,7 +109,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/cloudflare/adaptor.go b/relay/adaptor/cloudflare/adaptor.go index 97e3dbb2..8958466d 100644 --- a/relay/adaptor/cloudflare/adaptor.go +++ b/relay/adaptor/cloudflare/adaptor.go @@ -19,7 +19,7 @@ type Adaptor struct { } // ConvertImageRequest implements adaptor.Adaptor. -func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { return nil, errors.New("not implemented") } diff --git a/relay/adaptor/cohere/adaptor.go b/relay/adaptor/cohere/adaptor.go index 6fdb1b04..dd90bd7b 100644 --- a/relay/adaptor/cohere/adaptor.go +++ b/relay/adaptor/cohere/adaptor.go @@ -15,7 +15,7 @@ import ( type Adaptor struct{} // ConvertImageRequest implements adaptor.Adaptor. -func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { return nil, errors.New("not implemented") } diff --git a/relay/adaptor/coze/adaptor.go b/relay/adaptor/coze/adaptor.go index 44f560e8..21d91e76 100644 --- a/relay/adaptor/coze/adaptor.go +++ b/relay/adaptor/coze/adaptor.go @@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/deepl/adaptor.go b/relay/adaptor/deepl/adaptor.go index d018a096..5a03c261 100644 --- a/relay/adaptor/deepl/adaptor.go +++ b/relay/adaptor/deepl/adaptor.go @@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return convertedRequest, nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go index a86fde40..931bd54d 100644 --- a/relay/adaptor/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -65,7 +65,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/interface.go b/relay/adaptor/interface.go index 01b2e2cb..88667561 100644 --- a/relay/adaptor/interface.go +++ b/relay/adaptor/interface.go @@ -13,7 +13,7 @@ type Adaptor interface { GetRequestURL(meta *meta.Meta) (string, error) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) - ConvertImageRequest(request *model.ImageRequest) (any, error) + ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) GetModelList() []string diff --git a/relay/adaptor/ollama/adaptor.go b/relay/adaptor/ollama/adaptor.go index ad1f8983..9305340d 100644 --- a/relay/adaptor/ollama/adaptor.go +++ b/relay/adaptor/ollama/adaptor.go @@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 6946e402..7b0a74b3 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -85,7 +85,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return request, nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/palm/adaptor.go b/relay/adaptor/palm/adaptor.go index 98aa3e18..9b51562d 100644 --- a/relay/adaptor/palm/adaptor.go +++ b/relay/adaptor/palm/adaptor.go @@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/proxy/adaptor.go b/relay/adaptor/proxy/adaptor.go index 670c7628..32984fc7 100644 --- a/relay/adaptor/proxy/adaptor.go +++ b/relay/adaptor/proxy/adaptor.go @@ -80,7 +80,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me return nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { return nil, errors.Errorf("not implement") } diff --git a/relay/adaptor/replicate/adaptor.go b/relay/adaptor/replicate/adaptor.go index a60a7de3..9c18a19f 100644 --- a/relay/adaptor/replicate/adaptor.go +++ b/relay/adaptor/replicate/adaptor.go @@ -23,7 +23,29 @@ type Adaptor struct { } // ConvertImageRequest implements adaptor.Adaptor. -func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { + return nil, errors.New("should call replicate.ConvertImageRequest instead") +} + +func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + meta := meta.GetByContext(c) + + if request.ResponseFormat != "b64_json" { + return nil, errors.New("only support b64_json response format") + } + if request.N != 1 && request.N != 0 { + return nil, errors.New("only support N=1") + } + + switch meta.Mode { + case relaymode.ImagesGenerations: + return convertImageCreateRequest(request) + default: + return nil, errors.New("not implemented") + } +} + +func convertImageCreateRequest(request *model.ImageRequest) (any, error) { return DrawImageRequest{ Input: ImageInput{ Steps: 25, diff --git a/relay/adaptor/tencent/adaptor.go b/relay/adaptor/tencent/adaptor.go index 0de92d4a..de427305 100644 --- a/relay/adaptor/tencent/adaptor.go +++ b/relay/adaptor/tencent/adaptor.go @@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return tencentRequest, nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/vertexai/adaptor.go b/relay/adaptor/vertexai/adaptor.go index 3fab4a45..131506eb 100644 --- a/relay/adaptor/vertexai/adaptor.go +++ b/relay/adaptor/vertexai/adaptor.go @@ -1,18 +1,20 @@ package vertexai import ( - "errors" "fmt" "io" "net/http" + "slices" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/relay/adaptor" channelhelper "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - relaymodel "github.com/songquanpeng/one-api/relay/model" + relayModel "github.com/songquanpeng/one-api/relay/model" ) var _ adaptor.Adaptor = new(Adaptor) @@ -24,14 +26,27 @@ type Adaptor struct{} func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") +func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + meta := meta.GetByContext(c) + + if request.ResponseFormat != "b64_json" { + return nil, errors.New("only support b64_json response format") } - adaptor := GetAdaptor(request.Model) + adaptor := GetAdaptor(meta.ActualModelName) if adaptor == nil { - return nil, errors.New("adaptor not found") + return nil, errors.Errorf("cannot found vertex image adaptor for model %s", meta.ActualModelName) + } + + return adaptor.ConvertImageRequest(c, request) +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + meta := meta.GetByContext(c) + + adaptor := GetAdaptor(meta.ActualModelName) + if adaptor == nil { + return nil, errors.Errorf("cannot found vertex chat adaptor for model %s", meta.ActualModelName) } return adaptor.ConvertRequest(c, relayMode, request) @@ -40,9 +55,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { adaptor := GetAdaptor(meta.ActualModelName) if adaptor == nil { - return nil, &relaymodel.ErrorWithStatusCode{ + return nil, &relayModel.ErrorWithStatusCode{ StatusCode: http.StatusInternalServerError, - Error: relaymodel.Error{ + Error: relayModel.Error{ Message: "adaptor not found", }, } @@ -60,14 +75,19 @@ func (a *Adaptor) GetChannelName() string { } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { - suffix := "" - if strings.HasPrefix(meta.ActualModelName, "gemini") { + var suffix string + switch { + case strings.HasPrefix(meta.ActualModelName, "gemini"): if meta.IsStream { suffix = "streamGenerateContent?alt=sse" } else { suffix = "generateContent" } - } else { + case slices.Contains(imagen.ModelList, meta.ActualModelName): + return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/imagen-3.0-generate-001:predict", + meta.Config.Region, meta.Config.VertexAIProjectID, meta.Config.Region, + ), nil + default: if meta.IsStream { suffix = "streamRawPredict?alt=sse" } else { @@ -85,6 +105,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { suffix, ), nil } + return fmt.Sprintf( "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", meta.Config.Region, @@ -105,13 +126,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me return nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") - } - return request, nil -} - func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channelhelper.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/vertexai/claude/adapter.go b/relay/adaptor/vertexai/claude/adapter.go index cb911cfe..3080b54d 100644 --- a/relay/adaptor/vertexai/claude/adapter.go +++ b/relay/adaptor/vertexai/claude/adapter.go @@ -50,6 +50,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return req, nil } +func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + return nil, errors.New("not support image request") +} + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = anthropic.StreamHandler(c, resp) diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go index b5377875..871a616f 100644 --- a/relay/adaptor/vertexai/gemini/adapter.go +++ b/relay/adaptor/vertexai/gemini/adapter.go @@ -35,6 +35,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return geminiRequest, nil } +func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + return nil, errors.New("not support image request") +} + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string diff --git a/relay/adaptor/vertexai/imagen/adaptor.go b/relay/adaptor/vertexai/imagen/adaptor.go new file mode 100644 index 00000000..50333d25 --- /dev/null +++ b/relay/adaptor/vertexai/imagen/adaptor.go @@ -0,0 +1,105 @@ +package imagen + +import ( + "encoding/json" + "io" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +var ModelList = []string{ + "imagen-3.0-generate-001", +} + +type Adaptor struct { +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + meta := meta.GetByContext(c) + + if request.ResponseFormat != "b64_json" { + return nil, errors.New("only support b64_json response format") + } + if request.N <= 0 { + return nil, errors.New("n must be greater than 0") + } + + switch meta.Mode { + case relaymode.ImagesGenerations: + return convertImageCreateRequest(request) + default: + return nil, errors.New("not implemented") + } +} + +func convertImageCreateRequest(request *model.ImageRequest) (any, error) { + return CreateImageRequest{ + Instances: []createImageInstance{ + { + Prompt: request.Prompt, + }, + }, + Parameters: createImageParameters{ + SampleCount: request.N, + }, + }, nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) { + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, openai.ErrorWrapper( + errors.Wrap(err, "failed to read response body"), + "read_response_body", + http.StatusInternalServerError, + ) + } + + if resp.StatusCode != http.StatusOK { + return nil, openai.ErrorWrapper( + errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)), + "upstream_response", + http.StatusInternalServerError, + ) + } + + imagenResp := new(CreateImageResponse) + if err := json.Unmarshal(respBody, imagenResp); err != nil { + return nil, openai.ErrorWrapper( + errors.Wrap(err, "failed to decode response body"), + "unmarshal_upstream_response", + http.StatusInternalServerError, + ) + } + + if len(imagenResp.Predictions) == 0 { + return nil, openai.ErrorWrapper( + errors.New("empty predictions"), + "empty_predictions", + http.StatusInternalServerError, + ) + } + + oaiResp := openai.ImageResponse{ + Created: time.Now().Unix(), + } + for _, prediction := range imagenResp.Predictions { + oaiResp.Data = append(oaiResp.Data, openai.ImageData{ + B64Json: prediction.BytesBase64Encoded, + }) + } + + c.JSON(http.StatusOK, oaiResp) + return nil, nil +} diff --git a/relay/adaptor/vertexai/imagen/model.go b/relay/adaptor/vertexai/imagen/model.go new file mode 100644 index 00000000..b890d30d --- /dev/null +++ b/relay/adaptor/vertexai/imagen/model.go @@ -0,0 +1,23 @@ +package imagen + +type CreateImageRequest struct { + Instances []createImageInstance `json:"instances" binding:"required,min=1"` + Parameters createImageParameters `json:"parameters" binding:"required"` +} + +type createImageInstance struct { + Prompt string `json:"prompt"` +} + +type createImageParameters struct { + SampleCount int `json:"sample_count" binding:"required,min=1"` +} + +type CreateImageResponse struct { + Predictions []createImageResponsePrediction `json:"predictions"` +} + +type createImageResponsePrediction struct { + MimeType string `json:"mimeType"` + BytesBase64Encoded string `json:"bytesBase64Encoded"` +} diff --git a/relay/adaptor/vertexai/registry.go b/relay/adaptor/vertexai/registry.go index 41099f02..37dd06ef 100644 --- a/relay/adaptor/vertexai/registry.go +++ b/relay/adaptor/vertexai/registry.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude" gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini" + "github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" ) @@ -13,8 +14,9 @@ import ( type VertexAIModelType int const ( - VerterAIClaude VertexAIModelType = iota + 1 - VerterAIGemini + VertexAIClaude VertexAIModelType = iota + 1 + VertexAIGemini + VertexAIImagen ) var modelMapping = map[string]VertexAIModelType{} @@ -23,27 +25,35 @@ var modelList = []string{} func init() { modelList = append(modelList, claude.ModelList...) for _, model := range claude.ModelList { - modelMapping[model] = VerterAIClaude + modelMapping[model] = VertexAIClaude } modelList = append(modelList, gemini.ModelList...) for _, model := range gemini.ModelList { - modelMapping[model] = VerterAIGemini + modelMapping[model] = VertexAIGemini + } + + modelList = append(modelList, imagen.ModelList...) + for _, model := range imagen.ModelList { + modelMapping[model] = VertexAIImagen } } type innerAIAdapter interface { ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) } func GetAdaptor(model string) innerAIAdapter { adaptorType := modelMapping[model] switch adaptorType { - case VerterAIClaude: + case VertexAIClaude: return &claude.Adaptor{} - case VerterAIGemini: + case VertexAIGemini: return &gemini.Adaptor{} + case VertexAIImagen: + return &imagen.Adaptor{} default: return nil } diff --git a/relay/adaptor/xunfei/adaptor.go b/relay/adaptor/xunfei/adaptor.go index b5967f26..404ec767 100644 --- a/relay/adaptor/xunfei/adaptor.go +++ b/relay/adaptor/xunfei/adaptor.go @@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go index 660bd379..1ae9e3f7 100644 --- a/relay/adaptor/zhipu/adaptor.go +++ b/relay/adaptor/zhipu/adaptor.go @@ -80,7 +80,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index f83aa70c..1318a1a1 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -287,6 +287,9 @@ var ModelRatio = map[string]float64{ "deepl-ja": 25.0 / 1000 * USD, // https://console.x.ai/ "grok-beta": 5.0 / 1000 * USD, + // vertex imagen3 + // https://cloud.google.com/vertex-ai/generative-ai/pricing#imagen-models + "imagen-3.0-generate-001": 0.02 * USD, // replicate charges based on the number of generated images // https://replicate.com/pricing "black-forest-labs/flux-1.1-pro": 0.04 * USD, diff --git a/relay/controller/image.go b/relay/controller/image.go index 1b69d97d..442a5c78 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -18,7 +18,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" - "github.com/songquanpeng/one-api/relay/meta" + metalib "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" ) @@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 { return 1 } -func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode { +func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *metalib.Meta) *relaymodel.ErrorWithStatusCode { // check prompt length if imageRequest.Prompt == "" { return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) @@ -104,7 +104,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() - meta := meta.GetByContext(c) + meta := metalib.GetByContext(c) imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) @@ -116,6 +116,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus meta.OriginModelName = imageRequest.Model imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping) meta.ActualModelName = imageRequest.Model + metalib.Set2Context(c, meta) // model validation bizErr := validateImageRequest(imageRequest, meta) @@ -155,8 +156,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus case channeltype.Zhipu, channeltype.Ali, channeltype.Replicate, + channeltype.VertextAI, channeltype.Baidu: - finalRequest, err := adaptor.ConvertImageRequest(imageRequest) + finalRequest, err := adaptor.ConvertImageRequest(c, imageRequest) if err != nil { return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) } diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index bcbe1045..db477622 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -34,6 +34,10 @@ type Meta struct { } func GetByContext(c *gin.Context) *Meta { + if v, ok := c.Get(ctxkey.Meta); ok { + return v.(*Meta) + } + meta := Meta{ Mode: relaymode.GetByPath(c.Request.URL.Path), ChannelType: c.GetInt(ctxkey.Channel), @@ -57,5 +61,11 @@ func GetByContext(c *gin.Context) *Meta { meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType] } meta.APIType = channeltype.ToAPIType(meta.ChannelType) + + Set2Context(c, &meta) return &meta } + +func Set2Context(c *gin.Context, meta *Meta) { + c.Set(ctxkey.Meta, meta) +}