feat: implement image editing request handling and response conversion for Imagen API

This commit is contained in:
Laisky.Cai 2025-03-16 14:21:38 +00:00
parent fa794e7bd5
commit 580fec6359
9 changed files with 363 additions and 80 deletions

View File

@ -1,13 +1,13 @@
package auth
import (
"strings"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/gin-contrib/sessions"

View File

@ -132,12 +132,14 @@ type EmbeddingResponse struct {
model.Usage `json:"usage"`
}
// ImageData represents an image in the response
type ImageData struct {
	// Url is the image URL; populated when the URL response format is used.
	Url string `json:"url,omitempty"`
	// B64Json is the base64-encoded image; populated for the "b64_json" response format.
	B64Json string `json:"b64_json,omitempty"`
	// RevisedPrompt is the prompt as revised by the model, if any — presumably
	// only set by backends that rewrite prompts; verify against callers.
	RevisedPrompt string `json:"revised_prompt,omitempty"`
}
// ImageResponse represents the response structure for image generations
type ImageResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`

View File

@ -73,12 +73,12 @@ func convertImageRemixRequest(c *gin.Context) (any, error) {
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
rawReq := new(OpenaiImageEditRequest)
rawReq := new(model.OpenaiImageEditRequest)
if err := c.ShouldBind(rawReq); err != nil {
return nil, errors.Wrap(err, "parse image edit form")
}
return rawReq.toFluxRemixRequest()
return Convert2FluxRemixRequest(rawReq)
}
// ConvertRequest converts the request to the format that the target API expects.

View File

@ -6,22 +6,12 @@ import (
"image"
"image/png"
"io"
"mime/multipart"
"time"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/model"
)
type OpenaiImageEditRequest struct {
Image *multipart.FileHeader `json:"image" form:"image" binding:"required"`
Prompt string `json:"prompt" form:"prompt" binding:"required"`
Mask *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
Model string `json:"model" form:"model" binding:"required"`
N int `json:"n" form:"n" binding:"min=0,max=10"`
Size string `json:"size" form:"size"`
ResponseFormat string `json:"response_format" form:"response_format"`
}
// toFluxRemixRequest convert OpenAI's image edit request to Flux's remix request.
//
// Note that the mask formats of OpenAI and Flux are different:
@ -31,14 +21,14 @@ type OpenaiImageEditRequest struct {
//
// Both OpenAI's Image and Mask are browser-native ImageData,
// which need to be converted to base64 dataURI format.
func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) {
if r.ResponseFormat != "b64_json" {
func Convert2FluxRemixRequest(req *model.OpenaiImageEditRequest) (*InpaintingImageByFlusReplicateRequest, error) {
if req.ResponseFormat != "b64_json" {
return nil, errors.New("response_format must be b64_json for replicate models")
}
fluxReq := &InpaintingImageByFlusReplicateRequest{
Input: FluxInpaintingInput{
Prompt: r.Prompt,
Prompt: req.Prompt,
Seed: int(time.Now().UnixNano()),
Steps: 30,
Guidance: 3,
@ -48,7 +38,7 @@ func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusRep
},
}
imgFile, err := r.Image.Open()
imgFile, err := req.Image.Open()
if err != nil {
return nil, errors.Wrap(err, "open image file")
}
@ -58,7 +48,7 @@ func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusRep
return nil, errors.Wrap(err, "read image file")
}
maskFile, err := r.Mask.Open()
maskFile, err := req.Mask.Open()
if err != nil {
return nil, errors.Wrap(err, "open mask file")
}

View File

@ -10,6 +10,7 @@ import (
"net/http"
"testing"
"github.com/songquanpeng/one-api/relay/model"
"github.com/stretchr/testify/require"
)
@ -50,7 +51,7 @@ func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) {
maskFileHeader, err := createFileHeader("mask", "test.png", maskBuf.Bytes())
require.NoError(t, err)
req := &OpenaiImageEditRequest{
req := &model.OpenaiImageEditRequest{
Image: imgFileHeader,
Mask: maskFileHeader,
Prompt: "Test prompt",
@ -58,7 +59,7 @@ func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) {
ResponseFormat: "b64_json",
}
fluxReq, err := req.toFluxRemixRequest()
fluxReq, err := Convert2FluxRemixRequest(req)
require.NoError(t, err)
require.NotNil(t, fluxReq)
require.Equal(t, req.Prompt, fluxReq.Input.Prompt)

View File

@ -1,6 +1,7 @@
package imagen
import (
"bytes"
"encoding/json"
"io"
"net/http"
@ -15,16 +16,28 @@ import (
)
// ModelList enumerates the Imagen model names supported by this adaptor.
var ModelList = []string{
	// -------------------------------------
	// generate
	// -------------------------------------
	"imagen-3.0-generate-001", "imagen-3.0-generate-002",
	"imagen-3.0-fast-generate-001",
	// -------------------------------------
	// edit
	// -------------------------------------
	"imagen-3.0-capability-001",
}
// Adaptor implements the relay adaptor for the Vertex AI Imagen API.
type Adaptor struct {
}
// Init implements the adaptor interface; Imagen needs no per-request setup.
func (a *Adaptor) Init(meta *meta.Meta) {
	// No initialization needed
}
// ConvertRequest is not supported by the Imagen adaptor: only image
// requests are handled (see ConvertImageRequest), so chat/completion-style
// requests are rejected.
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
meta := meta.GetByContext(c)
@ -32,19 +45,91 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageReques
return nil, errors.New("only support b64_json response format")
}
if request.N <= 0 {
return nil, errors.New("n must be greater than 0")
request.N = 1 // Default to 1 if not specified
}
switch meta.Mode {
case relaymode.ImagesGenerations:
return convertImageCreateRequest(request)
case relaymode.ImagesEdits:
return nil, errors.New("not implemented")
switch c.ContentType() {
// case "application/json":
// return ConvertJsonImageEditRequest(c)
case "multipart/form-data":
return ConvertMultipartImageEditRequest(c)
default:
return nil, errors.New("unsupported content type for image edit")
}
default:
return nil, errors.New("not implemented")
}
}
// DoResponse dispatches the upstream Imagen response to the handler matching
// the relay mode. The body is fully buffered first and re-attached to resp so
// handlers that read resp.Body themselves (HandleImageEdit) still work.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, openai.ErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
	}
	resp.Body.Close()
	// Restore a fresh reader over the buffered bytes.
	resp.Body = io.NopCloser(bytes.NewBuffer(body))

	if meta.Mode == relaymode.ImagesEdits {
		return HandleImageEdit(c, resp)
	}
	if meta.Mode == relaymode.ImagesGenerations {
		return nil, handleImageGeneration(c, resp, body)
	}
	return nil, openai.ErrorWrapper(errors.New("unsupported mode"), "unsupported_mode", http.StatusBadRequest)
}
// handleImageGeneration converts a successful Imagen generation response into
// the OpenAI image-response format and writes it to the client. A non-200
// upstream status is surfaced verbatim as an error with the same status code.
func handleImageGeneration(c *gin.Context, resp *http.Response, respBody []byte) *model.ErrorWithStatusCode {
	if resp.StatusCode != http.StatusOK {
		return openai.ErrorWrapper(errors.New(string(respBody)), "imagen_api_error", resp.StatusCode)
	}

	var upstream CreateImageResponse
	if err := json.Unmarshal(respBody, &upstream); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
	}

	// Translate each prediction into an OpenAI-style image entry.
	converted := openai.ImageResponse{
		Created: time.Now().Unix(),
		Data:    make([]openai.ImageData, 0, len(upstream.Predictions)),
	}
	for _, pred := range upstream.Predictions {
		converted.Data = append(converted.Data, openai.ImageData{
			B64Json: pred.BytesBase64Encoded,
		})
	}

	payload, err := json.Marshal(converted)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
	}

	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(http.StatusOK)
	if _, err = c.Writer.Write(payload); err != nil {
		return openai.ErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
	}
	return nil
}
// GetModelList returns the Imagen model names this adaptor supports.
func (a *Adaptor) GetModelList() []string {
	return ModelList
}
// GetChannelName returns the identifier used for this channel.
func (a *Adaptor) GetChannelName() string {
	return "vertex_ai_imagen"
}
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
return CreateImageRequest{
Instances: []createImageInstance{
@ -57,55 +142,3 @@ func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
},
}, nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, openai.ErrorWrapper(
errors.Wrap(err, "failed to read response body"),
"read_response_body",
http.StatusInternalServerError,
)
}
if resp.StatusCode != http.StatusOK {
return nil, openai.ErrorWrapper(
errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)),
"upstream_response",
http.StatusInternalServerError,
)
}
imagenResp := new(CreateImageResponse)
if err := json.Unmarshal(respBody, imagenResp); err != nil {
return nil, openai.ErrorWrapper(
errors.Wrap(err, "failed to decode response body"),
"unmarshal_upstream_response",
http.StatusInternalServerError,
)
}
if len(imagenResp.Predictions) == 0 {
return nil, openai.ErrorWrapper(
errors.New("empty predictions"),
"empty_predictions",
http.StatusInternalServerError,
)
}
oaiResp := openai.ImageResponse{
Created: time.Now().Unix(),
}
for _, prediction := range imagenResp.Predictions {
oaiResp.Data = append(oaiResp.Data, openai.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
c.JSON(http.StatusOK, oaiResp)
return nil, nil
}

View File

@ -0,0 +1,168 @@
package imagen
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"io"
	"mime/multipart"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/model"
)
// ConvertMultipartImageEditRequest handles the conversion from an OpenAI-style
// multipart image-edit form to the Imagen CreateImageRequest format.
//
// Validation and defaults applied here:
//   - response_format must be "b64_json" (Imagen only returns base64 data)
//   - n defaults to 1 when not supplied
//   - edit_mode defaults to EDIT_MODE_INPAINT_INSERTION
//   - mask_mode defaults to MASK_MODE_USER_PROVIDED
func ConvertMultipartImageEditRequest(c *gin.Context) (*CreateImageRequest, error) {
	// Recover request body for binding
	requestBody, err := common.GetRequestBody(c)
	if err != nil {
		return nil, errors.Wrap(err, "get request body")
	}
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

	// Parse the form
	var rawReq model.OpenaiImageEditRequest
	if err := c.ShouldBind(&rawReq); err != nil {
		return nil, errors.Wrap(err, "parse image edit form")
	}

	// Validate response format
	if rawReq.ResponseFormat != "b64_json" {
		return nil, errors.New("response_format must be b64_json for Imagen models")
	}

	// Set default N if not provided
	if rawReq.N <= 0 {
		rawReq.N = 1
	}

	// Set default edit mode if not provided
	editMode := "EDIT_MODE_INPAINT_INSERTION"
	if rawReq.EditMode != nil {
		editMode = *rawReq.EditMode
	}

	// Set default mask mode if not provided
	maskMode := "MASK_MODE_USER_PROVIDED"
	if rawReq.MaskMode != nil {
		maskMode = *rawReq.MaskMode
	}

	// Read and base64-encode the uploaded image and mask files.
	imgBase64, err := fileHeaderToBase64(rawReq.Image, "image")
	if err != nil {
		return nil, err
	}
	maskBase64, err := fileHeaderToBase64(rawReq.Mask, "mask")
	if err != nil {
		return nil, err
	}

	// Create the request: reference 1 carries the raw image, reference 2 the mask.
	req := &CreateImageRequest{
		Instances: []createImageInstance{
			{
				Prompt: rawReq.Prompt,
				ReferenceImages: []ReferenceImage{
					{
						ReferenceType: "REFERENCE_TYPE_RAW",
						ReferenceId:   1,
						ReferenceImage: ReferenceImageData{
							BytesBase64Encoded: imgBase64,
						},
					},
					{
						ReferenceType: "REFERENCE_TYPE_MASK",
						ReferenceId:   2,
						ReferenceImage: ReferenceImageData{
							BytesBase64Encoded: maskBase64,
						},
						MaskImageConfig: &MaskImageConfig{
							MaskMode: maskMode,
						},
					},
				},
			},
		},
		Parameters: createImageParameters{
			SampleCount: rawReq.N,
			EditMode:    &editMode,
		},
	}

	return req, nil
}

// fileHeaderToBase64 opens the uploaded multipart file, reads it fully, and
// returns its contents encoded as standard base64. label ("image"/"mask") is
// used only to build error messages matching the original wording.
func fileHeaderToBase64(fh *multipart.FileHeader, label string) (string, error) {
	f, err := fh.Open()
	if err != nil {
		return "", errors.Wrapf(err, "open %s file", label)
	}
	defer f.Close()

	data, err := io.ReadAll(f)
	if err != nil {
		return "", errors.Wrapf(err, "read %s file", label)
	}
	return base64.StdEncoding.EncodeToString(data), nil
}
// HandleImageEdit converts a successful Imagen edit response into the OpenAI
// image-response format, writes it to the client, and returns zeroed usage
// (the Imagen API reports no token counts).
func HandleImageEdit(c *gin.Context, resp *http.Response) (*model.Usage, *model.ErrorWithStatusCode) {
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, openai.ErrorWrapper(errors.New(string(raw)), "imagen_api_error", resp.StatusCode)
	}

	var upstream CreateImageResponse
	if err = json.Unmarshal(raw, &upstream); err != nil {
		return nil, openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
	}

	// Convert to OpenAI format
	out := openai.ImageResponse{
		Created: time.Now().Unix(),
		Data:    make([]openai.ImageData, 0, len(upstream.Predictions)),
	}
	for _, pred := range upstream.Predictions {
		out.Data = append(out.Data, openai.ImageData{
			B64Json: pred.BytesBase64Encoded,
		})
	}

	payload, err := json.Marshal(out)
	if err != nil {
		return nil, openai.ErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
	}

	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(http.StatusOK)
	if _, err = c.Writer.Write(payload); err != nil {
		return nil, openai.ErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
	}

	// Imagen provides no token accounting; report zero usage.
	return &model.Usage{
		PromptTokens:     0,
		CompletionTokens: 0,
		TotalTokens:      0,
	}, nil
}

View File

@ -1,18 +1,80 @@
package imagen
// CreateImageRequest is the request body for the Imagen API.
//
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api
type CreateImageRequest struct {
	// Instances carries the per-prompt payloads; at least one is required.
	Instances []createImageInstance `json:"instances" binding:"required,min=1"`
	// Parameters holds the generation/edit options for the request.
	Parameters createImageParameters `json:"parameters" binding:"required"`
}
type createImageInstance struct {
Prompt string `json:"prompt"`
Prompt string `json:"prompt"`
ReferenceImages []ReferenceImage `json:"referenceImages,omitempty"`
Image *promptImage `json:"image,omitempty"` // Keeping for backward compatibility
}
// ReferenceImage represents a reference image for the Imagen edit API
//
// https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagen-3.0-capability-001?project=ai-ca-447000
type ReferenceImage struct {
	// ReferenceType distinguishes the raw source image from the mask image.
	ReferenceType string `json:"referenceType" binding:"required,oneof=REFERENCE_TYPE_RAW REFERENCE_TYPE_MASK"`
	// ReferenceId identifies this reference within the request.
	ReferenceId int `json:"referenceId"`
	// ReferenceImage holds the actual image payload.
	ReferenceImage ReferenceImageData `json:"referenceImage"`
	// MaskImageConfig is used when ReferenceType is "REFERENCE_TYPE_MASK",
	// to provide a mask image for the reference image.
	MaskImageConfig *MaskImageConfig `json:"maskImageConfig,omitempty"`
}
// ReferenceImageData contains the actual image data
type ReferenceImageData struct {
	// BytesBase64Encoded is the inline image content as standard base64.
	BytesBase64Encoded string `json:"bytesBase64Encoded,omitempty"`
	// GcsUri optionally references an image stored in Google Cloud Storage.
	GcsUri *string `json:"gcsUri,omitempty"`
	// MimeType, when set, must be image/jpeg or image/png.
	MimeType *string `json:"mimeType,omitempty" binding:"omitempty,oneof=image/jpeg image/png"`
}
// MaskImageConfig specifies how to use the mask image
type MaskImageConfig struct {
	// MaskMode selects the mask mode for mask editing.
	// Set MASK_MODE_USER_PROVIDED for input user provided mask in the B64_MASK_IMAGE,
	// MASK_MODE_BACKGROUND for automatically mask out background without user provided mask,
	// MASK_MODE_SEMANTIC for automatically generating semantic object masks by
	// specifying a list of object class IDs in maskClasses.
	MaskMode string `json:"maskMode" binding:"required,oneof=MASK_MODE_USER_PROVIDED MASK_MODE_BACKGROUND MASK_MODE_SEMANTIC"`
	MaskClasses []int `json:"maskClasses,omitempty"` // Object class IDs when maskMode is MASK_MODE_SEMANTIC
	Dilation *float64 `json:"dilation,omitempty"` // Determines the dilation percentage of the mask provided. Min: 0, Max: 1, Default: 0.03
}
// promptImage is the image to be used as a prompt for the Imagen API.
// It can be either a base64 encoded image or a GCS URI.
type promptImage struct {
	// BytesBase64Encoded is the inline image content as standard base64.
	BytesBase64Encoded *string `json:"bytesBase64Encoded,omitempty"`
	// GcsUri references an image stored in Google Cloud Storage.
	GcsUri *string `json:"gcsUri,omitempty"`
	// MimeType, when set, must be image/jpeg or image/png.
	MimeType *string `json:"mimeType,omitempty" binding:"omitempty,oneof=image/jpeg image/png"`
}
type createImageParameters struct {
SampleCount int `json:"sample_count" binding:"required,min=1"`
SampleCount int `json:"sampleCount" binding:"required,min=1"`
Mode *string `json:"mode,omitempty" binding:"omitempty,oneof=upscaled"`
// EditMode set edit mode for mask editing.
// Set EDIT_MODE_INPAINT_REMOVAL for inpainting removal,
// EDIT_MODE_INPAINT_INSERTION for inpainting insert,
// EDIT_MODE_OUTPAINT for outpainting,
// EDIT_MODE_BGSWAP for background swap.
EditMode *string `json:"editMode,omitempty" binding:"omitempty,oneof=EDIT_MODE_INPAINT_REMOVAL EDIT_MODE_INPAINT_INSERTION EDIT_MODE_OUTPAINT EDIT_MODE_BGSWAP"`
UpscaleConfig *upscaleConfig `json:"upscaleConfig,omitempty"`
Seed *int64 `json:"seed,omitempty"`
}
type upscaleConfig struct {
// UpscaleFactor is the factor to which the image will be upscaled.
// If not specified, the upscale factor will be determined from
// the longer side of the input image and sampleImageSize.
// Available values: x2 or x4 .
UpscaleFactor *string `json:"upscaleFactor,omitempty" binding:"omitempty,oneof=2x 4x"`
}
// CreateImageResponse is the response body for the Imagen API.
type CreateImageResponse struct {
	// Predictions holds one entry per generated image.
	Predictions []createImageResponsePrediction `json:"predictions"`
}
@ -21,3 +83,10 @@ type createImageResponsePrediction struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
}
// VQAResponse is the response body for the Visual Question Answering API.
//
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/visual-question-answering
type VQAResponse struct {
	// Predictions holds the model's textual answers.
	Predictions []string `json:"predictions"`
}

View File

@ -1,6 +1,10 @@
package model
import "github.com/songquanpeng/one-api/relay/adaptor/openrouter"
import (
"mime/multipart"
"github.com/songquanpeng/one-api/relay/adaptor/openrouter"
)
type ResponseFormat struct {
Type string `json:"type,omitempty"`
@ -142,3 +146,19 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
}
return input
}
// OpenaiImageEditRequest is the request body for the OpenAI image edit API.
type OpenaiImageEditRequest struct {
	// Image is the image to edit; required multipart file upload.
	Image *multipart.FileHeader `json:"image" form:"image" binding:"required"`
	// Prompt describes the desired edit; required.
	Prompt string `json:"prompt" form:"prompt" binding:"required"`
	// Mask marks the region of Image to edit; required multipart file upload.
	Mask *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
	// Model is the model to use for the edit; required.
	Model string `json:"model" form:"model" binding:"required"`
	// N is the requested number of output images (0-10; adaptors may treat 0 as 1).
	N int `json:"n" form:"n" binding:"min=0,max=10"`
	// Size is the requested output image size.
	Size string `json:"size" form:"size"`
	// ResponseFormat selects the output encoding (e.g. "b64_json").
	ResponseFormat string `json:"response_format" form:"response_format"`
	// -------------------------------------
	// Imagen-3
	// -------------------------------------
	// EditMode optionally selects the Imagen edit mode (EDIT_MODE_*).
	EditMode *string `json:"edit_mode" form:"edit_mode"`
	// MaskMode optionally selects the Imagen mask mode (MASK_MODE_*).
	MaskMode *string `json:"mask_mode" form:"mask_mode"`
}