mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-19 01:56:37 +08:00
Merge branch 'feature/imagen3'
This commit is contained in:
commit
c9bc075b04
@ -1,13 +1,13 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-contrib/sessions"
|
"github.com/gin-contrib/sessions"
|
||||||
|
@ -132,12 +132,14 @@ type EmbeddingResponse struct {
|
|||||||
model.Usage `json:"usage"`
|
model.Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImageData represents an image in the response
|
||||||
type ImageData struct {
|
type ImageData struct {
|
||||||
Url string `json:"url,omitempty"`
|
Url string `json:"url,omitempty"`
|
||||||
B64Json string `json:"b64_json,omitempty"`
|
B64Json string `json:"b64_json,omitempty"`
|
||||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImageResponse represents the response structure for image generations
|
||||||
type ImageResponse struct {
|
type ImageResponse struct {
|
||||||
Created int64 `json:"created"`
|
Created int64 `json:"created"`
|
||||||
Data []ImageData `json:"data"`
|
Data []ImageData `json:"data"`
|
||||||
|
@ -73,12 +73,12 @@ func convertImageRemixRequest(c *gin.Context) (any, error) {
|
|||||||
}
|
}
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|
||||||
rawReq := new(OpenaiImageEditRequest)
|
rawReq := new(model.OpenaiImageEditRequest)
|
||||||
if err := c.ShouldBind(rawReq); err != nil {
|
if err := c.ShouldBind(rawReq); err != nil {
|
||||||
return nil, errors.Wrap(err, "parse image edit form")
|
return nil, errors.Wrap(err, "parse image edit form")
|
||||||
}
|
}
|
||||||
|
|
||||||
return rawReq.toFluxRemixRequest()
|
return Convert2FluxRemixRequest(rawReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertRequest converts the request to the format that the target API expects.
|
// ConvertRequest converts the request to the format that the target API expects.
|
||||||
|
@ -6,22 +6,12 @@ import (
|
|||||||
"image"
|
"image"
|
||||||
"image/png"
|
"image/png"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpenaiImageEditRequest struct {
|
|
||||||
Image *multipart.FileHeader `json:"image" form:"image" binding:"required"`
|
|
||||||
Prompt string `json:"prompt" form:"prompt" binding:"required"`
|
|
||||||
Mask *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
|
|
||||||
Model string `json:"model" form:"model" binding:"required"`
|
|
||||||
N int `json:"n" form:"n" binding:"min=0,max=10"`
|
|
||||||
Size string `json:"size" form:"size"`
|
|
||||||
ResponseFormat string `json:"response_format" form:"response_format"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// toFluxRemixRequest convert OpenAI's image edit request to Flux's remix request.
|
// toFluxRemixRequest convert OpenAI's image edit request to Flux's remix request.
|
||||||
//
|
//
|
||||||
// Note that the mask formats of OpenAI and Flux are different:
|
// Note that the mask formats of OpenAI and Flux are different:
|
||||||
@ -31,14 +21,14 @@ type OpenaiImageEditRequest struct {
|
|||||||
//
|
//
|
||||||
// Both OpenAI's Image and Mask are browser-native ImageData,
|
// Both OpenAI's Image and Mask are browser-native ImageData,
|
||||||
// which need to be converted to base64 dataURI format.
|
// which need to be converted to base64 dataURI format.
|
||||||
func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) {
|
func Convert2FluxRemixRequest(req *model.OpenaiImageEditRequest) (*InpaintingImageByFlusReplicateRequest, error) {
|
||||||
if r.ResponseFormat != "b64_json" {
|
if req.ResponseFormat != "b64_json" {
|
||||||
return nil, errors.New("response_format must be b64_json for replicate models")
|
return nil, errors.New("response_format must be b64_json for replicate models")
|
||||||
}
|
}
|
||||||
|
|
||||||
fluxReq := &InpaintingImageByFlusReplicateRequest{
|
fluxReq := &InpaintingImageByFlusReplicateRequest{
|
||||||
Input: FluxInpaintingInput{
|
Input: FluxInpaintingInput{
|
||||||
Prompt: r.Prompt,
|
Prompt: req.Prompt,
|
||||||
Seed: int(time.Now().UnixNano()),
|
Seed: int(time.Now().UnixNano()),
|
||||||
Steps: 30,
|
Steps: 30,
|
||||||
Guidance: 3,
|
Guidance: 3,
|
||||||
@ -48,7 +38,7 @@ func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusRep
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
imgFile, err := r.Image.Open()
|
imgFile, err := req.Image.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "open image file")
|
return nil, errors.Wrap(err, "open image file")
|
||||||
}
|
}
|
||||||
@ -58,7 +48,7 @@ func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusRep
|
|||||||
return nil, errors.Wrap(err, "read image file")
|
return nil, errors.Wrap(err, "read image file")
|
||||||
}
|
}
|
||||||
|
|
||||||
maskFile, err := r.Mask.Open()
|
maskFile, err := req.Mask.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "open mask file")
|
return nil, errors.Wrap(err, "open mask file")
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -50,7 +51,7 @@ func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) {
|
|||||||
maskFileHeader, err := createFileHeader("mask", "test.png", maskBuf.Bytes())
|
maskFileHeader, err := createFileHeader("mask", "test.png", maskBuf.Bytes())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := &OpenaiImageEditRequest{
|
req := &model.OpenaiImageEditRequest{
|
||||||
Image: imgFileHeader,
|
Image: imgFileHeader,
|
||||||
Mask: maskFileHeader,
|
Mask: maskFileHeader,
|
||||||
Prompt: "Test prompt",
|
Prompt: "Test prompt",
|
||||||
@ -58,7 +59,7 @@ func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) {
|
|||||||
ResponseFormat: "b64_json",
|
ResponseFormat: "b64_json",
|
||||||
}
|
}
|
||||||
|
|
||||||
fluxReq, err := req.toFluxRemixRequest()
|
fluxReq, err := Convert2FluxRemixRequest(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, fluxReq)
|
require.NotNil(t, fluxReq)
|
||||||
require.Equal(t, req.Prompt, fluxReq.Input.Prompt)
|
require.Equal(t, req.Prompt, fluxReq.Input.Prompt)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package imagen
|
package imagen
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -15,16 +16,28 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
// create
|
// -------------------------------------
|
||||||
|
// generate
|
||||||
|
// -------------------------------------
|
||||||
"imagen-3.0-generate-001", "imagen-3.0-generate-002",
|
"imagen-3.0-generate-001", "imagen-3.0-generate-002",
|
||||||
"imagen-3.0-fast-generate-001",
|
"imagen-3.0-fast-generate-001",
|
||||||
|
// -------------------------------------
|
||||||
// edit
|
// edit
|
||||||
// "imagen-3.0-capability-001", // not supported yet
|
// -------------------------------------
|
||||||
|
"imagen-3.0-capability-001",
|
||||||
}
|
}
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||||
|
// No initialization needed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
meta := meta.GetByContext(c)
|
meta := meta.GetByContext(c)
|
||||||
|
|
||||||
@ -32,19 +45,91 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageReques
|
|||||||
return nil, errors.New("only support b64_json response format")
|
return nil, errors.New("only support b64_json response format")
|
||||||
}
|
}
|
||||||
if request.N <= 0 {
|
if request.N <= 0 {
|
||||||
return nil, errors.New("n must be greater than 0")
|
request.N = 1 // Default to 1 if not specified
|
||||||
}
|
}
|
||||||
|
|
||||||
switch meta.Mode {
|
switch meta.Mode {
|
||||||
case relaymode.ImagesGenerations:
|
case relaymode.ImagesGenerations:
|
||||||
return convertImageCreateRequest(request)
|
return convertImageCreateRequest(request)
|
||||||
case relaymode.ImagesEdits:
|
case relaymode.ImagesEdits:
|
||||||
return nil, errors.New("not implemented")
|
switch c.ContentType() {
|
||||||
|
// case "application/json":
|
||||||
|
// return ConvertJsonImageEditRequest(c)
|
||||||
|
case "multipart/form-data":
|
||||||
|
return ConvertMultipartImageEditRequest(c)
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unsupported content type for image edit")
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(respBody))
|
||||||
|
|
||||||
|
switch meta.Mode {
|
||||||
|
case relaymode.ImagesEdits:
|
||||||
|
return HandleImageEdit(c, resp)
|
||||||
|
case relaymode.ImagesGenerations:
|
||||||
|
return nil, handleImageGeneration(c, resp, respBody)
|
||||||
|
default:
|
||||||
|
return nil, openai.ErrorWrapper(errors.New("unsupported mode"), "unsupported_mode", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleImageGeneration(c *gin.Context, resp *http.Response, respBody []byte) *model.ErrorWithStatusCode {
|
||||||
|
var imageResponse CreateImageResponse
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return openai.ErrorWrapper(errors.New(string(respBody)), "imagen_api_error", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := json.Unmarshal(respBody, &imageResponse)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to OpenAI format
|
||||||
|
openaiResp := openai.ImageResponse{
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Data: make([]openai.ImageData, 0, len(imageResponse.Predictions)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prediction := range imageResponse.Predictions {
|
||||||
|
openaiResp.Data = append(openaiResp.Data, openai.ImageData{
|
||||||
|
B64Json: prediction.BytesBase64Encoded,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
respBytes, err := json.Marshal(openaiResp)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
_, err = c.Writer.Write(respBytes)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetModelList() []string {
|
||||||
|
return ModelList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetChannelName() string {
|
||||||
|
return "vertex_ai_imagen"
|
||||||
|
}
|
||||||
|
|
||||||
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
|
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
|
||||||
return CreateImageRequest{
|
return CreateImageRequest{
|
||||||
Instances: []createImageInstance{
|
Instances: []createImageInstance{
|
||||||
@ -57,55 +142,3 @@ func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
|
|||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
|
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, openai.ErrorWrapper(
|
|
||||||
errors.Wrap(err, "failed to read response body"),
|
|
||||||
"read_response_body",
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, openai.ErrorWrapper(
|
|
||||||
errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)),
|
|
||||||
"upstream_response",
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
imagenResp := new(CreateImageResponse)
|
|
||||||
if err := json.Unmarshal(respBody, imagenResp); err != nil {
|
|
||||||
return nil, openai.ErrorWrapper(
|
|
||||||
errors.Wrap(err, "failed to decode response body"),
|
|
||||||
"unmarshal_upstream_response",
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(imagenResp.Predictions) == 0 {
|
|
||||||
return nil, openai.ErrorWrapper(
|
|
||||||
errors.New("empty predictions"),
|
|
||||||
"empty_predictions",
|
|
||||||
http.StatusInternalServerError,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
oaiResp := openai.ImageResponse{
|
|
||||||
Created: time.Now().Unix(),
|
|
||||||
}
|
|
||||||
for _, prediction := range imagenResp.Predictions {
|
|
||||||
oaiResp.Data = append(oaiResp.Data, openai.ImageData{
|
|
||||||
B64Json: prediction.BytesBase64Encoded,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, oaiResp)
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
168
relay/adaptor/vertexai/imagen/image.go
Normal file
168
relay/adaptor/vertexai/imagen/image.go
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
package imagen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertImageEditRequest handles the conversion from multipart form to Imagen request format
|
||||||
|
func ConvertMultipartImageEditRequest(c *gin.Context) (*CreateImageRequest, error) {
|
||||||
|
// Recover request body for binding
|
||||||
|
requestBody, err := common.GetRequestBody(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "get request body")
|
||||||
|
}
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|
||||||
|
// Parse the form
|
||||||
|
var rawReq model.OpenaiImageEditRequest
|
||||||
|
if err := c.ShouldBind(&rawReq); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "parse image edit form")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate response format
|
||||||
|
if rawReq.ResponseFormat != "b64_json" {
|
||||||
|
return nil, errors.New("response_format must be b64_json for Imagen models")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default N if not provided
|
||||||
|
if rawReq.N <= 0 {
|
||||||
|
rawReq.N = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default edit mode if not provided
|
||||||
|
editMode := "EDIT_MODE_INPAINT_INSERTION"
|
||||||
|
if rawReq.EditMode != nil {
|
||||||
|
editMode = *rawReq.EditMode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default mask mode if not provided
|
||||||
|
maskMode := "MASK_MODE_USER_PROVIDED"
|
||||||
|
if rawReq.MaskMode != nil {
|
||||||
|
maskMode = *rawReq.MaskMode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the image file
|
||||||
|
imgFile, err := rawReq.Image.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "open image file")
|
||||||
|
}
|
||||||
|
defer imgFile.Close()
|
||||||
|
|
||||||
|
imgData, err := io.ReadAll(imgFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "read image file")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the mask file
|
||||||
|
maskFile, err := rawReq.Mask.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "open mask file")
|
||||||
|
}
|
||||||
|
defer maskFile.Close()
|
||||||
|
|
||||||
|
maskData, err := io.ReadAll(maskFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "read mask file")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to base64
|
||||||
|
imgBase64 := base64.StdEncoding.EncodeToString(imgData)
|
||||||
|
maskBase64 := base64.StdEncoding.EncodeToString(maskData)
|
||||||
|
|
||||||
|
// Create the request
|
||||||
|
req := &CreateImageRequest{
|
||||||
|
Instances: []createImageInstance{
|
||||||
|
{
|
||||||
|
Prompt: rawReq.Prompt,
|
||||||
|
ReferenceImages: []ReferenceImage{
|
||||||
|
{
|
||||||
|
ReferenceType: "REFERENCE_TYPE_RAW",
|
||||||
|
ReferenceId: 1,
|
||||||
|
ReferenceImage: ReferenceImageData{
|
||||||
|
BytesBase64Encoded: imgBase64,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ReferenceType: "REFERENCE_TYPE_MASK",
|
||||||
|
ReferenceId: 2,
|
||||||
|
ReferenceImage: ReferenceImageData{
|
||||||
|
BytesBase64Encoded: maskBase64,
|
||||||
|
},
|
||||||
|
MaskImageConfig: &MaskImageConfig{
|
||||||
|
MaskMode: maskMode,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Parameters: createImageParameters{
|
||||||
|
SampleCount: rawReq.N,
|
||||||
|
EditMode: &editMode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleImageEdit processes an image edit response from Imagen API
|
||||||
|
func HandleImageEdit(c *gin.Context, resp *http.Response) (*model.Usage, *model.ErrorWithStatusCode) {
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, openai.ErrorWrapper(errors.New(string(respBody)), "imagen_api_error", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var imageResponse CreateImageResponse
|
||||||
|
err = json.Unmarshal(respBody, &imageResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to OpenAI format
|
||||||
|
openaiResp := openai.ImageResponse{
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Data: make([]openai.ImageData, 0, len(imageResponse.Predictions)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prediction := range imageResponse.Predictions {
|
||||||
|
openaiResp.Data = append(openaiResp.Data, openai.ImageData{
|
||||||
|
B64Json: prediction.BytesBase64Encoded,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
respBytes, err := json.Marshal(openaiResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, openai.ErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
_, err = c.Writer.Write(respBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, openai.ErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create usage data (minimal as this API doesn't provide token counts)
|
||||||
|
usage := &model.Usage{
|
||||||
|
PromptTokens: 0,
|
||||||
|
CompletionTokens: 0,
|
||||||
|
TotalTokens: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
return usage, nil
|
||||||
|
}
|
@ -1,18 +1,80 @@
|
|||||||
package imagen
|
package imagen
|
||||||
|
|
||||||
|
// CreateImageRequest is the request body for the Imagen API.
|
||||||
|
//
|
||||||
|
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api
|
||||||
type CreateImageRequest struct {
|
type CreateImageRequest struct {
|
||||||
Instances []createImageInstance `json:"instances" binding:"required,min=1"`
|
Instances []createImageInstance `json:"instances" binding:"required,min=1"`
|
||||||
Parameters createImageParameters `json:"parameters" binding:"required"`
|
Parameters createImageParameters `json:"parameters" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type createImageInstance struct {
|
type createImageInstance struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
ReferenceImages []ReferenceImage `json:"referenceImages,omitempty"`
|
||||||
|
Image *promptImage `json:"image,omitempty"` // Keeping for backward compatibility
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReferenceImage represents a reference image for the Imagen edit API
|
||||||
|
//
|
||||||
|
// https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagen-3.0-capability-001?project=ai-ca-447000
|
||||||
|
type ReferenceImage struct {
|
||||||
|
ReferenceType string `json:"referenceType" binding:"required,oneof=REFERENCE_TYPE_RAW REFERENCE_TYPE_MASK"`
|
||||||
|
ReferenceId int `json:"referenceId"`
|
||||||
|
ReferenceImage ReferenceImageData `json:"referenceImage"`
|
||||||
|
// MaskImageConfig is used when ReferenceType is "REFERENCE_TYPE_MASK",
|
||||||
|
// to provide a mask image for the reference image.
|
||||||
|
MaskImageConfig *MaskImageConfig `json:"maskImageConfig,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReferenceImageData contains the actual image data
|
||||||
|
type ReferenceImageData struct {
|
||||||
|
BytesBase64Encoded string `json:"bytesBase64Encoded,omitempty"`
|
||||||
|
GcsUri *string `json:"gcsUri,omitempty"`
|
||||||
|
MimeType *string `json:"mimeType,omitempty" binding:"omitempty,oneof=image/jpeg image/png"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaskImageConfig specifies how to use the mask image
|
||||||
|
type MaskImageConfig struct {
|
||||||
|
// MaskMode is used to mask mode for mask editing.
|
||||||
|
// Set MASK_MODE_USER_PROVIDED for input user provided mask in the B64_MASK_IMAGE,
|
||||||
|
// MASK_MODE_BACKGROUND for automatically mask out background without user provided mask,
|
||||||
|
// MASK_MODE_SEMANTIC for automatically generating semantic object masks by
|
||||||
|
// specifying a list of object class IDs in maskClasses.
|
||||||
|
MaskMode string `json:"maskMode" binding:"required,oneof=MASK_MODE_USER_PROVIDED MASK_MODE_BACKGROUND MASK_MODE_SEMANTIC"`
|
||||||
|
MaskClasses []int `json:"maskClasses,omitempty"` // Object class IDs when maskMode is MASK_MODE_SEMANTIC
|
||||||
|
Dilation *float64 `json:"dilation,omitempty"` // Determines the dilation percentage of the mask provided. Min: 0, Max: 1, Default: 0.03
|
||||||
|
}
|
||||||
|
|
||||||
|
// promptImage is the image to be used as a prompt for the Imagen API.
|
||||||
|
// It can be either a base64 encoded image or a GCS URI.
|
||||||
|
type promptImage struct {
|
||||||
|
BytesBase64Encoded *string `json:"bytesBase64Encoded,omitempty"`
|
||||||
|
GcsUri *string `json:"gcsUri,omitempty"`
|
||||||
|
MimeType *string `json:"mimeType,omitempty" binding:"omitempty,oneof=image/jpeg image/png"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type createImageParameters struct {
|
type createImageParameters struct {
|
||||||
SampleCount int `json:"sample_count" binding:"required,min=1"`
|
SampleCount int `json:"sampleCount" binding:"required,min=1"`
|
||||||
|
Mode *string `json:"mode,omitempty" binding:"omitempty,oneof=upscaled"`
|
||||||
|
// EditMode set edit mode for mask editing.
|
||||||
|
// Set EDIT_MODE_INPAINT_REMOVAL for inpainting removal,
|
||||||
|
// EDIT_MODE_INPAINT_INSERTION for inpainting insert,
|
||||||
|
// EDIT_MODE_OUTPAINT for outpainting,
|
||||||
|
// EDIT_MODE_BGSWAP for background swap.
|
||||||
|
EditMode *string `json:"editMode,omitempty" binding:"omitempty,oneof=EDIT_MODE_INPAINT_REMOVAL EDIT_MODE_INPAINT_INSERTION EDIT_MODE_OUTPAINT EDIT_MODE_BGSWAP"`
|
||||||
|
UpscaleConfig *upscaleConfig `json:"upscaleConfig,omitempty"`
|
||||||
|
Seed *int64 `json:"seed,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type upscaleConfig struct {
|
||||||
|
// UpscaleFactor is the factor to which the image will be upscaled.
|
||||||
|
// If not specified, the upscale factor will be determined from
|
||||||
|
// the longer side of the input image and sampleImageSize.
|
||||||
|
// Available values: x2 or x4 .
|
||||||
|
UpscaleFactor *string `json:"upscaleFactor,omitempty" binding:"omitempty,oneof=2x 4x"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateImageResponse is the response body for the Imagen API.
|
||||||
type CreateImageResponse struct {
|
type CreateImageResponse struct {
|
||||||
Predictions []createImageResponsePrediction `json:"predictions"`
|
Predictions []createImageResponsePrediction `json:"predictions"`
|
||||||
}
|
}
|
||||||
@ -21,3 +83,10 @@ type createImageResponsePrediction struct {
|
|||||||
MimeType string `json:"mimeType"`
|
MimeType string `json:"mimeType"`
|
||||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VQARequest is the response body for the Visual Question Answering API.
|
||||||
|
//
|
||||||
|
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/visual-question-answering
|
||||||
|
type VQAResponse struct {
|
||||||
|
Predictions []string `json:"predictions"`
|
||||||
|
}
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import "github.com/songquanpeng/one-api/relay/adaptor/openrouter"
|
import (
|
||||||
|
"mime/multipart"
|
||||||
|
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/openrouter"
|
||||||
|
)
|
||||||
|
|
||||||
type ResponseFormat struct {
|
type ResponseFormat struct {
|
||||||
Type string `json:"type,omitempty"`
|
Type string `json:"type,omitempty"`
|
||||||
@ -142,3 +146,19 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
|
|||||||
}
|
}
|
||||||
return input
|
return input
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenaiImageEditRequest is the request body for the OpenAI image edit API.
|
||||||
|
type OpenaiImageEditRequest struct {
|
||||||
|
Image *multipart.FileHeader `json:"image" form:"image" binding:"required"`
|
||||||
|
Prompt string `json:"prompt" form:"prompt" binding:"required"`
|
||||||
|
Mask *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
|
||||||
|
Model string `json:"model" form:"model" binding:"required"`
|
||||||
|
N int `json:"n" form:"n" binding:"min=0,max=10"`
|
||||||
|
Size string `json:"size" form:"size"`
|
||||||
|
ResponseFormat string `json:"response_format" form:"response_format"`
|
||||||
|
// -------------------------------------
|
||||||
|
// Imagen-3
|
||||||
|
// -------------------------------------
|
||||||
|
EditMode *string `json:"edit_mode" form:"edit_mode"`
|
||||||
|
MaskMode *string `json:"mask_mode" form:"mask_mode"`
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user