mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	feat: implement image editing request handling and response conversion for Imagen API
This commit is contained in:
		@@ -1,13 +1,13 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"strings"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
 
 | 
			
		||||
@@ -132,12 +132,14 @@ type EmbeddingResponse struct {
 | 
			
		||||
	model.Usage `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ImageData represents an image in the response
 | 
			
		||||
type ImageData struct {
 | 
			
		||||
	Url           string `json:"url,omitempty"`
 | 
			
		||||
	B64Json       string `json:"b64_json,omitempty"`
 | 
			
		||||
	RevisedPrompt string `json:"revised_prompt,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ImageResponse represents the response structure for image generations
 | 
			
		||||
type ImageResponse struct {
 | 
			
		||||
	Created int64       `json:"created"`
 | 
			
		||||
	Data    []ImageData `json:"data"`
 | 
			
		||||
 
 | 
			
		||||
@@ -73,12 +73,12 @@ func convertImageRemixRequest(c *gin.Context) (any, error) {
 | 
			
		||||
	}
 | 
			
		||||
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
 | 
			
		||||
	rawReq := new(OpenaiImageEditRequest)
 | 
			
		||||
	rawReq := new(model.OpenaiImageEditRequest)
 | 
			
		||||
	if err := c.ShouldBind(rawReq); err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "parse image edit form")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return rawReq.toFluxRemixRequest()
 | 
			
		||||
	return Convert2FluxRemixRequest(rawReq)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConvertRequest converts the request to the format that the target API expects.
 | 
			
		||||
 
 | 
			
		||||
@@ -6,22 +6,12 @@ import (
 | 
			
		||||
	"image"
 | 
			
		||||
	"image/png"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mime/multipart"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type OpenaiImageEditRequest struct {
 | 
			
		||||
	Image          *multipart.FileHeader `json:"image" form:"image" binding:"required"`
 | 
			
		||||
	Prompt         string                `json:"prompt" form:"prompt" binding:"required"`
 | 
			
		||||
	Mask           *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
 | 
			
		||||
	Model          string                `json:"model" form:"model" binding:"required"`
 | 
			
		||||
	N              int                   `json:"n" form:"n" binding:"min=0,max=10"`
 | 
			
		||||
	Size           string                `json:"size" form:"size"`
 | 
			
		||||
	ResponseFormat string                `json:"response_format" form:"response_format"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// toFluxRemixRequest convert OpenAI's image edit request to Flux's remix request.
 | 
			
		||||
//
 | 
			
		||||
// Note that the mask formats of OpenAI and Flux are different:
 | 
			
		||||
@@ -31,14 +21,14 @@ type OpenaiImageEditRequest struct {
 | 
			
		||||
//
 | 
			
		||||
// Both OpenAI's Image and Mask are browser-native ImageData,
 | 
			
		||||
// which need to be converted to base64 dataURI format.
 | 
			
		||||
func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) {
 | 
			
		||||
	if r.ResponseFormat != "b64_json" {
 | 
			
		||||
func Convert2FluxRemixRequest(req *model.OpenaiImageEditRequest) (*InpaintingImageByFlusReplicateRequest, error) {
 | 
			
		||||
	if req.ResponseFormat != "b64_json" {
 | 
			
		||||
		return nil, errors.New("response_format must be b64_json for replicate models")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fluxReq := &InpaintingImageByFlusReplicateRequest{
 | 
			
		||||
		Input: FluxInpaintingInput{
 | 
			
		||||
			Prompt:           r.Prompt,
 | 
			
		||||
			Prompt:           req.Prompt,
 | 
			
		||||
			Seed:             int(time.Now().UnixNano()),
 | 
			
		||||
			Steps:            30,
 | 
			
		||||
			Guidance:         3,
 | 
			
		||||
@@ -48,7 +38,7 @@ func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusRep
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	imgFile, err := r.Image.Open()
 | 
			
		||||
	imgFile, err := req.Image.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "open image file")
 | 
			
		||||
	}
 | 
			
		||||
@@ -58,7 +48,7 @@ func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusRep
 | 
			
		||||
		return nil, errors.Wrap(err, "read image file")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	maskFile, err := r.Mask.Open()
 | 
			
		||||
	maskFile, err := req.Mask.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "open mask file")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,7 @@ import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -50,7 +51,7 @@ func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) {
 | 
			
		||||
	maskFileHeader, err := createFileHeader("mask", "test.png", maskBuf.Bytes())
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	req := &OpenaiImageEditRequest{
 | 
			
		||||
	req := &model.OpenaiImageEditRequest{
 | 
			
		||||
		Image:          imgFileHeader,
 | 
			
		||||
		Mask:           maskFileHeader,
 | 
			
		||||
		Prompt:         "Test prompt",
 | 
			
		||||
@@ -58,7 +59,7 @@ func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) {
 | 
			
		||||
		ResponseFormat: "b64_json",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fluxReq, err := req.toFluxRemixRequest()
 | 
			
		||||
	fluxReq, err := Convert2FluxRemixRequest(req)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	require.NotNil(t, fluxReq)
 | 
			
		||||
	require.Equal(t, req.Prompt, fluxReq.Input.Prompt)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package imagen
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -15,16 +16,28 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	// create
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	// generate
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	"imagen-3.0-generate-001", "imagen-3.0-generate-002",
 | 
			
		||||
	"imagen-3.0-fast-generate-001",
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	// edit
 | 
			
		||||
	// "imagen-3.0-capability-001",  // not supported yet
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	"imagen-3.0-capability-001",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(meta *meta.Meta) {
 | 
			
		||||
	// No initialization needed
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 | 
			
		||||
	return nil, errors.New("not implemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	meta := meta.GetByContext(c)
 | 
			
		||||
 | 
			
		||||
@@ -32,19 +45,91 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageReques
 | 
			
		||||
		return nil, errors.New("only support b64_json response format")
 | 
			
		||||
	}
 | 
			
		||||
	if request.N <= 0 {
 | 
			
		||||
		return nil, errors.New("n must be greater than 0")
 | 
			
		||||
		request.N = 1 // Default to 1 if not specified
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case relaymode.ImagesGenerations:
 | 
			
		||||
		return convertImageCreateRequest(request)
 | 
			
		||||
	case relaymode.ImagesEdits:
 | 
			
		||||
		return nil, errors.New("not implemented")
 | 
			
		||||
		switch c.ContentType() {
 | 
			
		||||
		// case "application/json":
 | 
			
		||||
		// 	return ConvertJsonImageEditRequest(c)
 | 
			
		||||
		case "multipart/form-data":
 | 
			
		||||
			return ConvertMultipartImageEditRequest(c)
 | 
			
		||||
		default:
 | 
			
		||||
			return nil, errors.New("unsupported content type for image edit")
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, errors.New("not implemented")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
 | 
			
		||||
	respBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	resp.Body.Close()
 | 
			
		||||
	resp.Body = io.NopCloser(bytes.NewBuffer(respBody))
 | 
			
		||||
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case relaymode.ImagesEdits:
 | 
			
		||||
		return HandleImageEdit(c, resp)
 | 
			
		||||
	case relaymode.ImagesGenerations:
 | 
			
		||||
		return nil, handleImageGeneration(c, resp, respBody)
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, openai.ErrorWrapper(errors.New("unsupported mode"), "unsupported_mode", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleImageGeneration(c *gin.Context, resp *http.Response, respBody []byte) *model.ErrorWithStatusCode {
 | 
			
		||||
	var imageResponse CreateImageResponse
 | 
			
		||||
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New(string(respBody)), "imagen_api_error", resp.StatusCode)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := json.Unmarshal(respBody, &imageResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Convert to OpenAI format
 | 
			
		||||
	openaiResp := openai.ImageResponse{
 | 
			
		||||
		Created: time.Now().Unix(),
 | 
			
		||||
		Data:    make([]openai.ImageData, 0, len(imageResponse.Predictions)),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, prediction := range imageResponse.Predictions {
 | 
			
		||||
		openaiResp.Data = append(openaiResp.Data, openai.ImageData{
 | 
			
		||||
			B64Json: prediction.BytesBase64Encoded,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respBytes, err := json.Marshal(openaiResp)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(http.StatusOK)
 | 
			
		||||
	_, err = c.Writer.Write(respBytes)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetModelList() []string {
 | 
			
		||||
	return ModelList
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetChannelName() string {
 | 
			
		||||
	return "vertex_ai_imagen"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
	return CreateImageRequest{
 | 
			
		||||
		Instances: []createImageInstance{
 | 
			
		||||
@@ -57,55 +142,3 @@ func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
		},
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 | 
			
		||||
	return nil, errors.New("not implemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
 | 
			
		||||
	respBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, openai.ErrorWrapper(
 | 
			
		||||
			errors.Wrap(err, "failed to read response body"),
 | 
			
		||||
			"read_response_body",
 | 
			
		||||
			http.StatusInternalServerError,
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		return nil, openai.ErrorWrapper(
 | 
			
		||||
			errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)),
 | 
			
		||||
			"upstream_response",
 | 
			
		||||
			http.StatusInternalServerError,
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	imagenResp := new(CreateImageResponse)
 | 
			
		||||
	if err := json.Unmarshal(respBody, imagenResp); err != nil {
 | 
			
		||||
		return nil, openai.ErrorWrapper(
 | 
			
		||||
			errors.Wrap(err, "failed to decode response body"),
 | 
			
		||||
			"unmarshal_upstream_response",
 | 
			
		||||
			http.StatusInternalServerError,
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(imagenResp.Predictions) == 0 {
 | 
			
		||||
		return nil, openai.ErrorWrapper(
 | 
			
		||||
			errors.New("empty predictions"),
 | 
			
		||||
			"empty_predictions",
 | 
			
		||||
			http.StatusInternalServerError,
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	oaiResp := openai.ImageResponse{
 | 
			
		||||
		Created: time.Now().Unix(),
 | 
			
		||||
	}
 | 
			
		||||
	for _, prediction := range imagenResp.Predictions {
 | 
			
		||||
		oaiResp.Data = append(oaiResp.Data, openai.ImageData{
 | 
			
		||||
			B64Json: prediction.BytesBase64Encoded,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.JSON(http.StatusOK, oaiResp)
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										168
									
								
								relay/adaptor/vertexai/imagen/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										168
									
								
								relay/adaptor/vertexai/imagen/image.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,168 @@
 | 
			
		||||
package imagen
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ConvertImageEditRequest handles the conversion from multipart form to Imagen request format
 | 
			
		||||
func ConvertMultipartImageEditRequest(c *gin.Context) (*CreateImageRequest, error) {
 | 
			
		||||
	// Recover request body for binding
 | 
			
		||||
	requestBody, err := common.GetRequestBody(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "get request body")
 | 
			
		||||
	}
 | 
			
		||||
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
 | 
			
		||||
	// Parse the form
 | 
			
		||||
	var rawReq model.OpenaiImageEditRequest
 | 
			
		||||
	if err := c.ShouldBind(&rawReq); err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "parse image edit form")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate response format
 | 
			
		||||
	if rawReq.ResponseFormat != "b64_json" {
 | 
			
		||||
		return nil, errors.New("response_format must be b64_json for Imagen models")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Set default N if not provided
 | 
			
		||||
	if rawReq.N <= 0 {
 | 
			
		||||
		rawReq.N = 1
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Set default edit mode if not provided
 | 
			
		||||
	editMode := "EDIT_MODE_INPAINT_INSERTION"
 | 
			
		||||
	if rawReq.EditMode != nil {
 | 
			
		||||
		editMode = *rawReq.EditMode
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Set default mask mode if not provided
 | 
			
		||||
	maskMode := "MASK_MODE_USER_PROVIDED"
 | 
			
		||||
	if rawReq.MaskMode != nil {
 | 
			
		||||
		maskMode = *rawReq.MaskMode
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Process the image file
 | 
			
		||||
	imgFile, err := rawReq.Image.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "open image file")
 | 
			
		||||
	}
 | 
			
		||||
	defer imgFile.Close()
 | 
			
		||||
 | 
			
		||||
	imgData, err := io.ReadAll(imgFile)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "read image file")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Process the mask file
 | 
			
		||||
	maskFile, err := rawReq.Mask.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "open mask file")
 | 
			
		||||
	}
 | 
			
		||||
	defer maskFile.Close()
 | 
			
		||||
 | 
			
		||||
	maskData, err := io.ReadAll(maskFile)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "read mask file")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Convert to base64
 | 
			
		||||
	imgBase64 := base64.StdEncoding.EncodeToString(imgData)
 | 
			
		||||
	maskBase64 := base64.StdEncoding.EncodeToString(maskData)
 | 
			
		||||
 | 
			
		||||
	// Create the request
 | 
			
		||||
	req := &CreateImageRequest{
 | 
			
		||||
		Instances: []createImageInstance{
 | 
			
		||||
			{
 | 
			
		||||
				Prompt: rawReq.Prompt,
 | 
			
		||||
				ReferenceImages: []ReferenceImage{
 | 
			
		||||
					{
 | 
			
		||||
						ReferenceType: "REFERENCE_TYPE_RAW",
 | 
			
		||||
						ReferenceId:   1,
 | 
			
		||||
						ReferenceImage: ReferenceImageData{
 | 
			
		||||
							BytesBase64Encoded: imgBase64,
 | 
			
		||||
						},
 | 
			
		||||
					},
 | 
			
		||||
					{
 | 
			
		||||
						ReferenceType: "REFERENCE_TYPE_MASK",
 | 
			
		||||
						ReferenceId:   2,
 | 
			
		||||
						ReferenceImage: ReferenceImageData{
 | 
			
		||||
							BytesBase64Encoded: maskBase64,
 | 
			
		||||
						},
 | 
			
		||||
						MaskImageConfig: &MaskImageConfig{
 | 
			
		||||
							MaskMode: maskMode,
 | 
			
		||||
						},
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		Parameters: createImageParameters{
 | 
			
		||||
			SampleCount: rawReq.N,
 | 
			
		||||
			EditMode:    &editMode,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return req, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HandleImageEdit processes an image edit response from Imagen API
 | 
			
		||||
func HandleImageEdit(c *gin.Context, resp *http.Response) (*model.Usage, *model.ErrorWithStatusCode) {
 | 
			
		||||
	respBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		return nil, openai.ErrorWrapper(errors.New(string(respBody)), "imagen_api_error", resp.StatusCode)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var imageResponse CreateImageResponse
 | 
			
		||||
	err = json.Unmarshal(respBody, &imageResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Convert to OpenAI format
 | 
			
		||||
	openaiResp := openai.ImageResponse{
 | 
			
		||||
		Created: time.Now().Unix(),
 | 
			
		||||
		Data:    make([]openai.ImageData, 0, len(imageResponse.Predictions)),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, prediction := range imageResponse.Predictions {
 | 
			
		||||
		openaiResp.Data = append(openaiResp.Data, openai.ImageData{
 | 
			
		||||
			B64Json: prediction.BytesBase64Encoded,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respBytes, err := json.Marshal(openaiResp)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, openai.ErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(http.StatusOK)
 | 
			
		||||
	_, err = c.Writer.Write(respBytes)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, openai.ErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create usage data (minimal as this API doesn't provide token counts)
 | 
			
		||||
	usage := &model.Usage{
 | 
			
		||||
		PromptTokens:     0,
 | 
			
		||||
		CompletionTokens: 0,
 | 
			
		||||
		TotalTokens:      0,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return usage, nil
 | 
			
		||||
}
 | 
			
		||||
@@ -1,18 +1,80 @@
 | 
			
		||||
package imagen
 | 
			
		||||
 | 
			
		||||
// CreateImageRequest is the request body for the Imagen API.
 | 
			
		||||
//
 | 
			
		||||
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api
 | 
			
		||||
type CreateImageRequest struct {
 | 
			
		||||
	Instances  []createImageInstance `json:"instances" binding:"required,min=1"`
 | 
			
		||||
	Parameters createImageParameters `json:"parameters" binding:"required"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type createImageInstance struct {
 | 
			
		||||
	Prompt string `json:"prompt"`
 | 
			
		||||
	Prompt          string           `json:"prompt"`
 | 
			
		||||
	ReferenceImages []ReferenceImage `json:"referenceImages,omitempty"`
 | 
			
		||||
	Image           *promptImage     `json:"image,omitempty"` // Keeping for backward compatibility
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReferenceImage represents a reference image for the Imagen edit API
 | 
			
		||||
//
 | 
			
		||||
// https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagen-3.0-capability-001?project=ai-ca-447000
 | 
			
		||||
type ReferenceImage struct {
 | 
			
		||||
	ReferenceType  string             `json:"referenceType" binding:"required,oneof=REFERENCE_TYPE_RAW REFERENCE_TYPE_MASK"`
 | 
			
		||||
	ReferenceId    int                `json:"referenceId"`
 | 
			
		||||
	ReferenceImage ReferenceImageData `json:"referenceImage"`
 | 
			
		||||
	// MaskImageConfig is used when ReferenceType is "REFERENCE_TYPE_MASK",
 | 
			
		||||
	// to provide a mask image for the reference image.
 | 
			
		||||
	MaskImageConfig *MaskImageConfig `json:"maskImageConfig,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReferenceImageData contains the actual image data
 | 
			
		||||
type ReferenceImageData struct {
 | 
			
		||||
	BytesBase64Encoded string  `json:"bytesBase64Encoded,omitempty"`
 | 
			
		||||
	GcsUri             *string `json:"gcsUri,omitempty"`
 | 
			
		||||
	MimeType           *string `json:"mimeType,omitempty" binding:"omitempty,oneof=image/jpeg image/png"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MaskImageConfig specifies how to use the mask image
 | 
			
		||||
type MaskImageConfig struct {
 | 
			
		||||
	// MaskMode is used to mask mode for mask editing.
 | 
			
		||||
	// Set MASK_MODE_USER_PROVIDED for input user provided mask in the B64_MASK_IMAGE,
 | 
			
		||||
	// MASK_MODE_BACKGROUND for automatically mask out background without user provided mask,
 | 
			
		||||
	// MASK_MODE_SEMANTIC for automatically generating semantic object masks by
 | 
			
		||||
	// specifying a list of object class IDs in maskClasses.
 | 
			
		||||
	MaskMode    string   `json:"maskMode" binding:"required,oneof=MASK_MODE_USER_PROVIDED MASK_MODE_BACKGROUND MASK_MODE_SEMANTIC"`
 | 
			
		||||
	MaskClasses []int    `json:"maskClasses,omitempty"` // Object class IDs when maskMode is MASK_MODE_SEMANTIC
 | 
			
		||||
	Dilation    *float64 `json:"dilation,omitempty"`    // Determines the dilation percentage of the mask provided. Min: 0, Max: 1, Default: 0.03
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// promptImage is the image to be used as a prompt for the Imagen API.
 | 
			
		||||
// It can be either a base64 encoded image or a GCS URI.
 | 
			
		||||
type promptImage struct {
 | 
			
		||||
	BytesBase64Encoded *string `json:"bytesBase64Encoded,omitempty"`
 | 
			
		||||
	GcsUri             *string `json:"gcsUri,omitempty"`
 | 
			
		||||
	MimeType           *string `json:"mimeType,omitempty" binding:"omitempty,oneof=image/jpeg image/png"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type createImageParameters struct {
 | 
			
		||||
	SampleCount int `json:"sample_count" binding:"required,min=1"`
 | 
			
		||||
	SampleCount int     `json:"sampleCount" binding:"required,min=1"`
 | 
			
		||||
	Mode        *string `json:"mode,omitempty" binding:"omitempty,oneof=upscaled"`
 | 
			
		||||
	// EditMode set edit mode for mask editing.
 | 
			
		||||
	// Set EDIT_MODE_INPAINT_REMOVAL for inpainting removal,
 | 
			
		||||
	// EDIT_MODE_INPAINT_INSERTION for inpainting insert,
 | 
			
		||||
	// EDIT_MODE_OUTPAINT for outpainting,
 | 
			
		||||
	// EDIT_MODE_BGSWAP for background swap.
 | 
			
		||||
	EditMode      *string        `json:"editMode,omitempty" binding:"omitempty,oneof=EDIT_MODE_INPAINT_REMOVAL EDIT_MODE_INPAINT_INSERTION EDIT_MODE_OUTPAINT EDIT_MODE_BGSWAP"`
 | 
			
		||||
	UpscaleConfig *upscaleConfig `json:"upscaleConfig,omitempty"`
 | 
			
		||||
	Seed          *int64         `json:"seed,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type upscaleConfig struct {
 | 
			
		||||
	// UpscaleFactor is the factor to which the image will be upscaled.
 | 
			
		||||
	// If not specified, the upscale factor will be determined from
 | 
			
		||||
	// the longer side of the input image and sampleImageSize.
 | 
			
		||||
	// Available values: x2 or x4 .
 | 
			
		||||
	UpscaleFactor *string `json:"upscaleFactor,omitempty" binding:"omitempty,oneof=2x 4x"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateImageResponse is the response body for the Imagen API.
 | 
			
		||||
type CreateImageResponse struct {
 | 
			
		||||
	Predictions []createImageResponsePrediction `json:"predictions"`
 | 
			
		||||
}
 | 
			
		||||
@@ -21,3 +83,10 @@ type createImageResponsePrediction struct {
 | 
			
		||||
	MimeType           string `json:"mimeType"`
 | 
			
		||||
	BytesBase64Encoded string `json:"bytesBase64Encoded"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// VQARequest is the response body for the Visual Question Answering API.
 | 
			
		||||
//
 | 
			
		||||
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/visual-question-answering
 | 
			
		||||
type VQAResponse struct {
 | 
			
		||||
	Predictions []string `json:"predictions"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,10 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import "github.com/songquanpeng/one-api/relay/adaptor/openrouter"
 | 
			
		||||
import (
 | 
			
		||||
	"mime/multipart"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openrouter"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ResponseFormat struct {
 | 
			
		||||
	Type       string      `json:"type,omitempty"`
 | 
			
		||||
@@ -142,3 +146,19 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
 | 
			
		||||
	}
 | 
			
		||||
	return input
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OpenaiImageEditRequest is the request body for the OpenAI image edit API.
 | 
			
		||||
type OpenaiImageEditRequest struct {
 | 
			
		||||
	Image          *multipart.FileHeader `json:"image" form:"image" binding:"required"`
 | 
			
		||||
	Prompt         string                `json:"prompt" form:"prompt" binding:"required"`
 | 
			
		||||
	Mask           *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
 | 
			
		||||
	Model          string                `json:"model" form:"model" binding:"required"`
 | 
			
		||||
	N              int                   `json:"n" form:"n" binding:"min=0,max=10"`
 | 
			
		||||
	Size           string                `json:"size" form:"size"`
 | 
			
		||||
	ResponseFormat string                `json:"response_format" form:"response_format"`
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	// Imagen-3
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	EditMode *string `json:"edit_mode" form:"edit_mode"`
 | 
			
		||||
	MaskMode *string `json:"mask_mode" form:"mask_mode"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user