diff --git a/controller/auth/oidc.go b/controller/auth/oidc.go
index b0a4aa4f..3eea99a2 100644
--- a/controller/auth/oidc.go
+++ b/controller/auth/oidc.go
@@ -1,13 +1,13 @@
 package auth
 
 import (
-	"strings"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"net/http"
 	"net/url"
 	"strconv"
+	"strings"
 	"time"
 
 	"github.com/gin-contrib/sessions"
diff --git a/relay/adaptor/openai/model.go b/relay/adaptor/openai/model.go
index 39e87262..50a8e1da 100644
--- a/relay/adaptor/openai/model.go
+++ b/relay/adaptor/openai/model.go
@@ -132,12 +132,14 @@ type EmbeddingResponse struct {
 	model.Usage `json:"usage"`
 }
 
+// ImageData represents an image in the response
 type ImageData struct {
 	Url           string `json:"url,omitempty"`
 	B64Json       string `json:"b64_json,omitempty"`
 	RevisedPrompt string `json:"revised_prompt,omitempty"`
 }
 
+// ImageResponse represents the response structure for image generations
 type ImageResponse struct {
 	Created int64       `json:"created"`
 	Data    []ImageData `json:"data"`
diff --git a/relay/adaptor/replicate/adaptor.go b/relay/adaptor/replicate/adaptor.go
index 296a9aaa..a37d98c9 100644
--- a/relay/adaptor/replicate/adaptor.go
+++ b/relay/adaptor/replicate/adaptor.go
@@ -73,12 +73,12 @@ func convertImageRemixRequest(c *gin.Context) (any, error) {
 	}
 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 
-	rawReq := new(OpenaiImageEditRequest)
+	rawReq := new(model.OpenaiImageEditRequest)
 	if err := c.ShouldBind(rawReq); err != nil {
 		return nil, errors.Wrap(err, "parse image edit form")
 	}
 
-	return rawReq.toFluxRemixRequest()
+	return Convert2FluxRemixRequest(rawReq)
 }
 
 // ConvertRequest converts the request to the format that the target API expects.
diff --git a/relay/adaptor/replicate/model.go b/relay/adaptor/replicate/model.go
index 8a1d97aa..79bf4b16 100644
--- a/relay/adaptor/replicate/model.go
+++ b/relay/adaptor/replicate/model.go
@@ -6,22 +6,12 @@ import (
 	"image"
 	"image/png"
 	"io"
-	"mime/multipart"
 	"time"
 
 	"github.com/pkg/errors"
+	"github.com/songquanpeng/one-api/relay/model"
 )
 
-type OpenaiImageEditRequest struct {
-	Image          *multipart.FileHeader `json:"image" form:"image" binding:"required"`
-	Prompt         string                `json:"prompt" form:"prompt" binding:"required"`
-	Mask           *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
-	Model          string                `json:"model" form:"model" binding:"required"`
-	N              int                   `json:"n" form:"n" binding:"min=0,max=10"`
-	Size           string                `json:"size" form:"size"`
-	ResponseFormat string                `json:"response_format" form:"response_format"`
-}
-
-// toFluxRemixRequest convert OpenAI's image edit request to Flux's remix request.
+// Convert2FluxRemixRequest converts OpenAI's image edit request to Flux's remix request.
 //
 // Note that the mask formats of OpenAI and Flux are different:
@@ -31,14 +21,14 @@ type OpenaiImageEditRequest struct {
 //
 // Both OpenAI's Image and Mask are browser-native ImageData,
 // which need to be converted to base64 dataURI format.
-func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) {
-	if r.ResponseFormat != "b64_json" {
+func Convert2FluxRemixRequest(req *model.OpenaiImageEditRequest) (*InpaintingImageByFlusReplicateRequest, error) {
+	if req.ResponseFormat != "b64_json" {
 		return nil, errors.New("response_format must be b64_json for replicate models")
 	}
 
 	fluxReq := &InpaintingImageByFlusReplicateRequest{
 		Input: FluxInpaintingInput{
-			Prompt:           r.Prompt,
+			Prompt:           req.Prompt,
 			Seed:             int(time.Now().UnixNano()),
 			Steps:            30,
 			Guidance:         3,
@@ -48,7 +38,7 @@ func (r *OpenaiImageEditRequest) toFluxRemixRep
 		},
 	}
 
-	imgFile, err := r.Image.Open()
+	imgFile, err := req.Image.Open()
 	if err != nil {
 		return nil, errors.Wrap(err, "open image file")
 	}
@@ -58,7 +48,7 @@ func (r *OpenaiImageEditRequest) toFluxRemixRep
 		return nil, errors.Wrap(err, "read image file")
 	}
 
-	maskFile, err := r.Mask.Open()
+	maskFile, err := req.Mask.Open()
 	if err != nil {
 		return nil, errors.Wrap(err, "open mask file")
 	}
diff --git a/relay/adaptor/replicate/model_test.go b/relay/adaptor/replicate/model_test.go
index 6cde5e94..6042e06f 100644
--- a/relay/adaptor/replicate/model_test.go
+++ b/relay/adaptor/replicate/model_test.go
@@ -10,6 +10,7 @@ import (
 	"net/http"
 	"testing"
 
+	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/stretchr/testify/require"
 )
 
@@ -50,7 +51,7 @@ func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) {
 	maskFileHeader, err := createFileHeader("mask", "test.png", maskBuf.Bytes())
 	require.NoError(t, err)
 
-	req := &OpenaiImageEditRequest{
+	req := &model.OpenaiImageEditRequest{
 		Image:          imgFileHeader,
 		Mask:           maskFileHeader,
 		Prompt:         "Test prompt",
@@ -58,7 +59,7 @@ func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) {
 		ResponseFormat: "b64_json",
 	}
 
-	fluxReq, err := Convert2FluxRemixRequest(req)
+	fluxReq, err := Convert2FluxRemixRequest(req)
 	require.NoError(t, err)
 	require.NotNil(t, fluxReq)
 	require.Equal(t, req.Prompt, fluxReq.Input.Prompt)
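Reviewer note: with OpenaiImageEditRequest moved into relay/model, the replicate path still binds the same multipart form as before. A minimal client-side sketch of that form follows, assuming the usual OpenAI-compatible /v1/images/edits route, a local relay on port 3000, and placeholder model name and token (none of these are confirmed by the diff):

package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
)

func main() {
	var body bytes.Buffer
	w := multipart.NewWriter(&body)

	// Attach the image and mask files the struct binds as *multipart.FileHeader.
	for field, path := range map[string]string{"image": "image.png", "mask": "mask.png"} {
		f, err := os.Open(path)
		if err != nil {
			panic(err)
		}
		part, _ := w.CreateFormFile(field, path)
		if _, err := io.Copy(part, f); err != nil {
			panic(err)
		}
		f.Close()
	}
	_ = w.WriteField("prompt", "replace the sky with a sunset")
	_ = w.WriteField("model", "flux-fill-pro") // hypothetical model name
	_ = w.WriteField("response_format", "b64_json")
	_ = w.WriteField("n", "1")
	w.Close()

	req, _ := http.NewRequest(http.MethodPost, "http://localhost:3000/v1/images/edits", &body)
	req.Header.Set("Content-Type", w.FormDataContentType())
	req.Header.Set("Authorization", "Bearer sk-...")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}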
diff --git a/relay/adaptor/vertexai/imagen/adaptor.go b/relay/adaptor/vertexai/imagen/adaptor.go
index 27c82a14..3a7c868a 100644
--- a/relay/adaptor/vertexai/imagen/adaptor.go
+++ b/relay/adaptor/vertexai/imagen/adaptor.go
@@ -1,6 +1,7 @@
 package imagen
 
 import (
+	"bytes"
 	"encoding/json"
 	"io"
 	"net/http"
@@ -15,16 +16,28 @@ import (
 )
 
 var ModelList = []string{
-	// create
+	// -------------------------------------
+	// generate
+	// -------------------------------------
 	"imagen-3.0-generate-001",
 	"imagen-3.0-generate-002",
 	"imagen-3.0-fast-generate-001",
+	// -------------------------------------
 	// edit
-	// "imagen-3.0-capability-001", // not supported yet
+	// -------------------------------------
+	"imagen-3.0-capability-001",
 }
 
 type Adaptor struct {
 }
 
+func (a *Adaptor) Init(meta *meta.Meta) {
+	// No initialization needed
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
 	meta := meta.GetByContext(c)
@@ -32,19 +45,91 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageReques
 		return nil, errors.New("only support b64_json response format")
 	}
 	if request.N <= 0 {
-		return nil, errors.New("n must be greater than 0")
+		request.N = 1 // Default to 1 if not specified
 	}
 
 	switch meta.Mode {
 	case relaymode.ImagesGenerations:
 		return convertImageCreateRequest(request)
 	case relaymode.ImagesEdits:
-		return nil, errors.New("not implemented")
+		switch c.ContentType() {
+		// case "application/json":
+		// 	return ConvertJsonImageEditRequest(c)
+		case "multipart/form-data":
+			return ConvertMultipartImageEditRequest(c)
+		default:
+			return nil, errors.New("unsupported content type for image edit")
+		}
 	default:
 		return nil, errors.New("not implemented")
 	}
 }
 
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	resp.Body.Close()
+	resp.Body = io.NopCloser(bytes.NewBuffer(respBody))
+
+	switch meta.Mode {
+	case relaymode.ImagesEdits:
+		return HandleImageEdit(c, resp)
+	case relaymode.ImagesGenerations:
+		return nil, handleImageGeneration(c, resp, respBody)
+	default:
+		return nil, openai.ErrorWrapper(errors.New("unsupported mode"), "unsupported_mode", http.StatusBadRequest)
+	}
+}
+
+func handleImageGeneration(c *gin.Context, resp *http.Response, respBody []byte) *model.ErrorWithStatusCode {
+	var imageResponse CreateImageResponse
+
+	if resp.StatusCode != http.StatusOK {
+		return openai.ErrorWrapper(errors.New(string(respBody)), "imagen_api_error", resp.StatusCode)
+	}
+
+	err := json.Unmarshal(respBody, &imageResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	// Convert to OpenAI format
+	openaiResp := openai.ImageResponse{
+		Created: time.Now().Unix(),
+		Data:    make([]openai.ImageData, 0, len(imageResponse.Predictions)),
+	}
+
+	for _, prediction := range imageResponse.Predictions {
+		openaiResp.Data = append(openaiResp.Data, openai.ImageData{
+			B64Json: prediction.BytesBase64Encoded,
+		})
+	}
+
+	respBytes, err := json.Marshal(openaiResp)
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(http.StatusOK)
+	_, err = c.Writer.Write(respBytes)
+	if err != nil {
+		return openai.ErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
+	}
+
+	return nil
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return "vertex_ai_imagen"
+}
+
 func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
 	return CreateImageRequest{
 		Instances: []createImageInstance{
@@ -57,55 +142,3 @@ func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
 		},
 	}, nil
 }
-
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
-	return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
-	respBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, openai.ErrorWrapper(
-			errors.Wrap(err, "failed to read response body"),
-			"read_response_body",
-			http.StatusInternalServerError,
-		)
-	}
-
-	if resp.StatusCode != http.StatusOK {
-		return nil, openai.ErrorWrapper(
-			errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)),
-			"upstream_response",
-			http.StatusInternalServerError,
-		)
-	}
-
-	imagenResp := new(CreateImageResponse)
-	if err := json.Unmarshal(respBody, imagenResp); err != nil {
-		return nil, openai.ErrorWrapper(
-			errors.Wrap(err, "failed to decode response body"),
-			"unmarshal_upstream_response",
-			http.StatusInternalServerError,
-		)
-	}
-
-	if len(imagenResp.Predictions) == 0 {
-		return nil, openai.ErrorWrapper(
-			errors.New("empty predictions"),
-			"empty_predictions",
-			http.StatusInternalServerError,
-		)
-	}
-
-	oaiResp := openai.ImageResponse{
-		Created: time.Now().Unix(),
-	}
-	for _, prediction := range imagenResp.Predictions {
-		oaiResp.Data = append(oaiResp.Data, openai.ImageData{
-			B64Json: prediction.BytesBase64Encoded,
-		})
-	}
-
-	c.JSON(http.StatusOK, oaiResp)
-	return nil, nil
-}
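Reviewer note: handleImageGeneration now writes the converted body itself instead of calling c.JSON. A self-contained sketch of the Imagen-predictions to OpenAI-images conversion it performs, using simplified local copies of the types (field names match the diff; the payload is dummy data):

package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// Simplified stand-ins for imagen.CreateImageResponse and openai.ImageResponse.
type prediction struct {
	MimeType           string `json:"mimeType"`
	BytesBase64Encoded string `json:"bytesBase64Encoded"`
}

type imagenResponse struct {
	Predictions []prediction `json:"predictions"`
}

type imageData struct {
	B64Json string `json:"b64_json,omitempty"`
}

type imageResponse struct {
	Created int64       `json:"created"`
	Data    []imageData `json:"data"`
}

func main() {
	upstream := []byte(`{"predictions":[{"mimeType":"image/png","bytesBase64Encoded":"iVBORw0..."}]}`)

	var in imagenResponse
	if err := json.Unmarshal(upstream, &in); err != nil {
		panic(err)
	}

	// Same shape-for-shape copy the adaptor does: one prediction -> one b64_json entry.
	out := imageResponse{Created: time.Now().Unix(), Data: make([]imageData, 0, len(in.Predictions))}
	for _, p := range in.Predictions {
		out.Data = append(out.Data, imageData{B64Json: p.BytesBase64Encoded})
	}

	b, _ := json.Marshal(out)
	fmt.Println(string(b)) // {"created":...,"data":[{"b64_json":"iVBORw0..."}]}
}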
diff --git a/relay/adaptor/vertexai/imagen/image.go b/relay/adaptor/vertexai/imagen/image.go
new file mode 100644
index 00000000..905552f6
--- /dev/null
+++ b/relay/adaptor/vertexai/imagen/image.go
@@ -0,0 +1,168 @@
+package imagen
+
+import (
+	"bytes"
+	"encoding/base64"
+	"encoding/json"
+	"io"
+	"net/http"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
+	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
+	"github.com/songquanpeng/one-api/relay/model"
+)
+
+// ConvertMultipartImageEditRequest handles the conversion from a multipart form to the Imagen request format
+func ConvertMultipartImageEditRequest(c *gin.Context) (*CreateImageRequest, error) {
+	// Restore the request body for binding
+	requestBody, err := common.GetRequestBody(c)
+	if err != nil {
+		return nil, errors.Wrap(err, "get request body")
+	}
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+
+	// Parse the form
+	var rawReq model.OpenaiImageEditRequest
+	if err := c.ShouldBind(&rawReq); err != nil {
+		return nil, errors.Wrap(err, "parse image edit form")
+	}
+
+	// Validate response format
+	if rawReq.ResponseFormat != "b64_json" {
+		return nil, errors.New("response_format must be b64_json for Imagen models")
+	}
+
+	// Set default N if not provided
+	if rawReq.N <= 0 {
+		rawReq.N = 1
+	}
+
+	// Set default edit mode if not provided
+	editMode := "EDIT_MODE_INPAINT_INSERTION"
+	if rawReq.EditMode != nil {
+		editMode = *rawReq.EditMode
+	}
+
+	// Set default mask mode if not provided
+	maskMode := "MASK_MODE_USER_PROVIDED"
+	if rawReq.MaskMode != nil {
+		maskMode = *rawReq.MaskMode
+	}
+
+	// Process the image file
+	imgFile, err := rawReq.Image.Open()
+	if err != nil {
+		return nil, errors.Wrap(err, "open image file")
+	}
+	defer imgFile.Close()
+
+	imgData, err := io.ReadAll(imgFile)
+	if err != nil {
+		return nil, errors.Wrap(err, "read image file")
+	}
+
+	// Process the mask file
+	maskFile, err := rawReq.Mask.Open()
+	if err != nil {
+		return nil, errors.Wrap(err, "open mask file")
+	}
+	defer maskFile.Close()
+
+	maskData, err := io.ReadAll(maskFile)
+	if err != nil {
+		return nil, errors.Wrap(err, "read mask file")
+	}
+
+	// Convert to base64
+	imgBase64 := base64.StdEncoding.EncodeToString(imgData)
+	maskBase64 := base64.StdEncoding.EncodeToString(maskData)
+
+	// Create the request
+	req := &CreateImageRequest{
+		Instances: []createImageInstance{
+			{
+				Prompt: rawReq.Prompt,
+				ReferenceImages: []ReferenceImage{
+					{
+						ReferenceType: "REFERENCE_TYPE_RAW",
+						ReferenceId:   1,
+						ReferenceImage: ReferenceImageData{
+							BytesBase64Encoded: imgBase64,
+						},
+					},
+					{
+						ReferenceType: "REFERENCE_TYPE_MASK",
+						ReferenceId:   2,
+						ReferenceImage: ReferenceImageData{
+							BytesBase64Encoded: maskBase64,
+						},
+						MaskImageConfig: &MaskImageConfig{
+							MaskMode: maskMode,
+						},
+					},
+				},
+			},
+		},
+		Parameters: createImageParameters{
+			SampleCount: rawReq.N,
+			EditMode:    &editMode,
+		},
+	}
+
+	return req, nil
+}
+
+// HandleImageEdit processes an image edit response from the Imagen API
+func HandleImageEdit(c *gin.Context, resp *http.Response) (*model.Usage, *model.ErrorWithStatusCode) {
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, openai.ErrorWrapper(errors.New(string(respBody)), "imagen_api_error", resp.StatusCode)
+	}
+
+	var imageResponse CreateImageResponse
+	err = json.Unmarshal(respBody, &imageResponse)
+	if err != nil {
+		return nil, openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	// Convert to OpenAI format
+	openaiResp := openai.ImageResponse{
+		Created: time.Now().Unix(),
+		Data:    make([]openai.ImageData, 0, len(imageResponse.Predictions)),
+	}
+
+	for _, prediction := range imageResponse.Predictions {
+		openaiResp.Data = append(openaiResp.Data, openai.ImageData{
+			B64Json: prediction.BytesBase64Encoded,
+		})
+	}
+
+	respBytes, err := json.Marshal(openaiResp)
+	if err != nil {
+		return nil, openai.ErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(http.StatusOK)
+	_, err = c.Writer.Write(respBytes)
+	if err != nil {
+		return nil, openai.ErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
+	}
+
+	// Create usage data (minimal as this API doesn't provide token counts)
+	usage := &model.Usage{
+		PromptTokens:     0,
+		CompletionTokens: 0,
+		TotalTokens:      0,
+	}
+
+	return usage, nil
+}
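Reviewer note: for reference, this is the wire shape ConvertMultipartImageEditRequest assembles (a raw reference image plus a user-provided mask, then edit parameters). A standalone sketch that prints the same JSON with stand-in base64 strings; the prompt and edit mode are placeholder values:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	type refImage struct {
		BytesBase64Encoded string `json:"bytesBase64Encoded"`
	}
	payload := map[string]any{
		"instances": []map[string]any{{
			"prompt": "add a red balloon",
			"referenceImages": []map[string]any{
				{
					"referenceType":  "REFERENCE_TYPE_RAW",
					"referenceId":    1,
					"referenceImage": refImage{BytesBase64Encoded: "<image-b64>"},
				},
				{
					"referenceType":  "REFERENCE_TYPE_MASK",
					"referenceId":    2,
					"referenceImage": refImage{BytesBase64Encoded: "<mask-b64>"},
					"maskImageConfig": map[string]any{
						"maskMode": "MASK_MODE_USER_PROVIDED",
					},
				},
			},
		}},
		"parameters": map[string]any{
			"sampleCount": 1,
			"editMode":    "EDIT_MODE_INPAINT_INSERTION",
		},
	}

	b, err := json.MarshalIndent(payload, "", "  ")
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
}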
diff --git a/relay/adaptor/vertexai/imagen/model.go b/relay/adaptor/vertexai/imagen/model.go
index b890d30d..1d6839b5 100644
--- a/relay/adaptor/vertexai/imagen/model.go
+++ b/relay/adaptor/vertexai/imagen/model.go
@@ -1,18 +1,80 @@
 package imagen
 
+// CreateImageRequest is the request body for the Imagen API.
+//
+// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api
 type CreateImageRequest struct {
 	Instances  []createImageInstance `json:"instances" binding:"required,min=1"`
 	Parameters createImageParameters `json:"parameters" binding:"required"`
 }
 
 type createImageInstance struct {
-	Prompt string `json:"prompt"`
+	Prompt          string           `json:"prompt"`
+	ReferenceImages []ReferenceImage `json:"referenceImages,omitempty"`
+	Image           *promptImage     `json:"image,omitempty"` // Keeping for backward compatibility
+}
+
+// ReferenceImage represents a reference image for the Imagen edit API
+//
+// https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagen-3.0-capability-001?project=ai-ca-447000
+type ReferenceImage struct {
+	ReferenceType  string             `json:"referenceType" binding:"required,oneof=REFERENCE_TYPE_RAW REFERENCE_TYPE_MASK"`
+	ReferenceId    int                `json:"referenceId"`
+	ReferenceImage ReferenceImageData `json:"referenceImage"`
+	// MaskImageConfig is used when ReferenceType is "REFERENCE_TYPE_MASK",
+	// to provide a mask image for the reference image.
+	MaskImageConfig *MaskImageConfig `json:"maskImageConfig,omitempty"`
+}
+
+// ReferenceImageData contains the actual image data
+type ReferenceImageData struct {
+	BytesBase64Encoded string  `json:"bytesBase64Encoded,omitempty"`
+	GcsUri             *string `json:"gcsUri,omitempty"`
+	MimeType           *string `json:"mimeType,omitempty" binding:"omitempty,oneof=image/jpeg image/png"`
+}
+
+// MaskImageConfig specifies how to use the mask image
+type MaskImageConfig struct {
+	// MaskMode sets the mask mode for mask editing.
+	// Set MASK_MODE_USER_PROVIDED to use the user-provided base64 mask image,
+	// MASK_MODE_BACKGROUND to automatically mask out the background without a
+	// user-provided mask, or MASK_MODE_SEMANTIC to automatically generate
+	// semantic object masks from the list of object class IDs in maskClasses.
+	MaskMode    string   `json:"maskMode" binding:"required,oneof=MASK_MODE_USER_PROVIDED MASK_MODE_BACKGROUND MASK_MODE_SEMANTIC"`
+	MaskClasses []int    `json:"maskClasses,omitempty"` // Object class IDs when maskMode is MASK_MODE_SEMANTIC
+	Dilation    *float64 `json:"dilation,omitempty"`    // Determines the dilation percentage of the mask provided. Min: 0, Max: 1, Default: 0.03
+}
+
+// promptImage is the image to be used as a prompt for the Imagen API.
+// It can be either a base64 encoded image or a GCS URI.
+type promptImage struct {
+	BytesBase64Encoded *string `json:"bytesBase64Encoded,omitempty"`
+	GcsUri             *string `json:"gcsUri,omitempty"`
+	MimeType           *string `json:"mimeType,omitempty" binding:"omitempty,oneof=image/jpeg image/png"`
 }
 
 type createImageParameters struct {
-	SampleCount int `json:"sample_count" binding:"required,min=1"`
+	SampleCount int     `json:"sampleCount" binding:"required,min=1"`
+	Mode        *string `json:"mode,omitempty" binding:"omitempty,oneof=upscaled"`
+	// EditMode sets the edit mode for mask editing.
+	// Set EDIT_MODE_INPAINT_REMOVAL for inpainting removal,
+	// EDIT_MODE_INPAINT_INSERTION for inpainting insert,
+	// EDIT_MODE_OUTPAINT for outpainting,
+	// EDIT_MODE_BGSWAP for background swap.
+	EditMode      *string        `json:"editMode,omitempty" binding:"omitempty,oneof=EDIT_MODE_INPAINT_REMOVAL EDIT_MODE_INPAINT_INSERTION EDIT_MODE_OUTPAINT EDIT_MODE_BGSWAP"`
+	UpscaleConfig *upscaleConfig `json:"upscaleConfig,omitempty"`
+	Seed          *int64         `json:"seed,omitempty"`
 }
 
+type upscaleConfig struct {
+	// UpscaleFactor is the factor to which the image will be upscaled.
+	// If not specified, the upscale factor will be determined from
+	// the longer side of the input image and sampleImageSize.
+	// Available values: x2 or x4.
+	UpscaleFactor *string `json:"upscaleFactor,omitempty" binding:"omitempty,oneof=x2 x4"`
+}
+
+// CreateImageResponse is the response body for the Imagen API.
 type CreateImageResponse struct {
 	Predictions []createImageResponsePrediction `json:"predictions"`
 }
@@ -21,3 +83,10 @@
 	MimeType           string `json:"mimeType"`
 	BytesBase64Encoded string `json:"bytesBase64Encoded"`
 }
+
+// VQAResponse is the response body for the Visual Question Answering API.
+//
+// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/visual-question-answering
+type VQAResponse struct {
+	Predictions []string `json:"predictions"`
+}
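Reviewer note: createImageParameters also gained upscale fields that no request path in this PR sets yet. A sketch of how they would marshal, using local copies of the unexported structs; whether the relay ever sends mode=upscaled is outside this diff, so this only illustrates the JSON shape:

package main

import (
	"encoding/json"
	"fmt"
)

// Local stand-ins for the unexported imagen parameter structs.
type upscaleConfig struct {
	UpscaleFactor *string `json:"upscaleFactor,omitempty"`
}

type imageParameters struct {
	SampleCount   int            `json:"sampleCount"`
	Mode          *string        `json:"mode,omitempty"`
	UpscaleConfig *upscaleConfig `json:"upscaleConfig,omitempty"`
	Seed          *int64         `json:"seed,omitempty"`
}

func main() {
	mode, factor := "upscaled", "x2"
	p := imageParameters{
		SampleCount:   1,
		Mode:          &mode,
		UpscaleConfig: &upscaleConfig{UpscaleFactor: &factor},
	}
	b, _ := json.Marshal(p)
	fmt.Println(string(b)) // {"sampleCount":1,"mode":"upscaled","upscaleConfig":{"upscaleFactor":"x2"}}
}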
diff --git a/relay/model/general.go b/relay/model/general.go
index fd4e5641..48ea8461 100644
--- a/relay/model/general.go
+++ b/relay/model/general.go
@@ -1,6 +1,10 @@
 package model
 
-import "github.com/songquanpeng/one-api/relay/adaptor/openrouter"
+import (
+	"mime/multipart"
+
+	"github.com/songquanpeng/one-api/relay/adaptor/openrouter"
+)
 
 type ResponseFormat struct {
 	Type string `json:"type,omitempty"`
@@ -142,3 +146,19 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
 	}
 	return input
 }
+
+// OpenaiImageEditRequest is the request body for the OpenAI image edit API.
+type OpenaiImageEditRequest struct {
+	Image          *multipart.FileHeader `json:"image" form:"image" binding:"required"`
+	Prompt         string                `json:"prompt" form:"prompt" binding:"required"`
+	Mask           *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
+	Model          string                `json:"model" form:"model" binding:"required"`
+	N              int                   `json:"n" form:"n" binding:"min=0,max=10"`
+	Size           string                `json:"size" form:"size"`
+	ResponseFormat string                `json:"response_format" form:"response_format"`
+	// -------------------------------------
+	// Imagen-3
+	// -------------------------------------
+	EditMode *string `json:"edit_mode" form:"edit_mode"`
+	MaskMode *string `json:"mask_mode" form:"mask_mode"`
+}
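Reviewer note: a minimal gin handler showing how the shared struct binds, mirroring how the replicate and imagen adaptors consume it. The route, port, and defaulting are illustrative, not part of this PR:

package main

import (
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/model"
)

func main() {
	r := gin.Default()
	r.POST("/v1/images/edits", func(c *gin.Context) {
		var req model.OpenaiImageEditRequest
		// ShouldBind selects the multipart/form-data binding from the
		// Content-Type and enforces the binding tags (required, min=0,max=10, ...).
		if err := c.ShouldBind(&req); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		if req.N <= 0 {
			req.N = 1 // same defaulting the imagen adaptor applies
		}
		c.JSON(http.StatusOK, gin.H{"model": req.Model, "n": req.N})
	})
	_ = r.Run(":8080")
}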