feat: implement image editing request handling and response conversion for Imagen API

Author: Laisky.Cai
Date: 2025-03-16 14:21:38 +00:00
parent fa794e7bd5
commit 580fec6359
9 changed files with 363 additions and 80 deletions
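The diff below wires multipart image-edit requests into the Vertex AI Imagen adaptor and converts Imagen predictions back into OpenAI-style image responses. As a rough illustration of the client call this enables, a request might look like the sketch that follows; the base URL, port, and token are placeholders, and the /v1/images/edits path is the OpenAI-compatible route that one-api exposes rather than something shown in this commit.

// Hypothetical client call against a one-api deployment; endpoint, host, and
// token are placeholders. Per the adaptor below, the request must use
// multipart/form-data and response_format=b64_json.
package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
)

func main() {
	var body bytes.Buffer
	w := multipart.NewWriter(&body)

	// Attach the source image as the "image" form field.
	img, err := os.Open("input.png")
	if err != nil {
		panic(err)
	}
	defer img.Close()
	part, err := w.CreateFormFile("image", "input.png")
	if err != nil {
		panic(err)
	}
	if _, err := io.Copy(part, img); err != nil {
		panic(err)
	}

	// Plain form fields mirroring the OpenAI images/edits parameters.
	_ = w.WriteField("model", "imagen-3.0-capability-001")
	_ = w.WriteField("prompt", "replace the background with a beach at sunset")
	_ = w.WriteField("response_format", "b64_json") // required by this adaptor
	w.Close()

	req, err := http.NewRequest(http.MethodPost, "http://localhost:3000/v1/images/edits", &body)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", w.FormDataContentType())
	req.Header.Set("Authorization", "Bearer sk-your-one-api-token")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	respBody, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(respBody))
}

Note that the adaptor rejects anything other than response_format=b64_json and, for edits, only accepts multipart/form-data bodies.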


@@ -1,6 +1,7 @@
package imagen
import (
"bytes"
"encoding/json"
"io"
"net/http"
@@ -15,16 +16,28 @@ import (
)
var ModelList = []string{
- // create
// -------------------------------------
// generate
// -------------------------------------
"imagen-3.0-generate-001", "imagen-3.0-generate-002",
"imagen-3.0-fast-generate-001",
// -------------------------------------
// edit
// "imagen-3.0-capability-001", // not supported yet
// -------------------------------------
"imagen-3.0-capability-001",
}
type Adaptor struct {
}
func (a *Adaptor) Init(meta *meta.Meta) {
// No initialization needed
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
meta := meta.GetByContext(c)
@@ -32,19 +45,91 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageReques
return nil, errors.New("only support b64_json response format")
}
if request.N <= 0 {
return nil, errors.New("n must be greater than 0")
request.N = 1 // Default to 1 if not specified
}
switch meta.Mode {
case relaymode.ImagesGenerations:
return convertImageCreateRequest(request)
case relaymode.ImagesEdits:
return nil, errors.New("not implemented")
switch c.ContentType() {
// case "application/json":
// return ConvertJsonImageEditRequest(c)
case "multipart/form-data":
return ConvertMultipartImageEditRequest(c)
default:
return nil, errors.New("unsupported content type for image edit")
}
default:
return nil, errors.New("not implemented")
}
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
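// Close the upstream body and replace it with an in-memory copy so the
// mode-specific handlers below can re-read it.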
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewBuffer(respBody))
switch meta.Mode {
case relaymode.ImagesEdits:
return HandleImageEdit(c, resp)
case relaymode.ImagesGenerations:
return nil, handleImageGeneration(c, resp, respBody)
default:
return nil, openai.ErrorWrapper(errors.New("unsupported mode"), "unsupported_mode", http.StatusBadRequest)
}
}
func handleImageGeneration(c *gin.Context, resp *http.Response, respBody []byte) *model.ErrorWithStatusCode {
var imageResponse CreateImageResponse
if resp.StatusCode != http.StatusOK {
return openai.ErrorWrapper(errors.New(string(respBody)), "imagen_api_error", resp.StatusCode)
}
err := json.Unmarshal(respBody, &imageResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
// Convert to OpenAI format
openaiResp := openai.ImageResponse{
Created: time.Now().Unix(),
Data: make([]openai.ImageData, 0, len(imageResponse.Predictions)),
}
for _, prediction := range imageResponse.Predictions {
openaiResp.Data = append(openaiResp.Data, openai.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
respBytes, err := json.Marshal(openaiResp)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(http.StatusOK)
_, err = c.Writer.Write(respBytes)
if err != nil {
return openai.ErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
}
return nil
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "vertex_ai_imagen"
}
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
return CreateImageRequest{
Instances: []createImageInstance{
@@ -57,55 +142,3 @@ func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
},
}, nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, openai.ErrorWrapper(
errors.Wrap(err, "failed to read response body"),
"read_response_body",
http.StatusInternalServerError,
)
}
if resp.StatusCode != http.StatusOK {
return nil, openai.ErrorWrapper(
errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)),
"upstream_response",
http.StatusInternalServerError,
)
}
imagenResp := new(CreateImageResponse)
if err := json.Unmarshal(respBody, imagenResp); err != nil {
return nil, openai.ErrorWrapper(
errors.Wrap(err, "failed to decode response body"),
"unmarshal_upstream_response",
http.StatusInternalServerError,
)
}
if len(imagenResp.Predictions) == 0 {
return nil, openai.ErrorWrapper(
errors.New("empty predictions"),
"empty_predictions",
http.StatusInternalServerError,
)
}
oaiResp := openai.ImageResponse{
Created: time.Now().Unix(),
}
for _, prediction := range imagenResp.Predictions {
oaiResp.Data = append(oaiResp.Data, openai.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
c.JSON(http.StatusOK, oaiResp)
return nil, nil
}
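The edit branch in ConvertImageRequest delegates to ConvertMultipartImageEditRequest and HandleImageEdit, which live in other files of this commit and are not shown in this diff. For orientation only, here is a minimal sketch of what the multipart conversion could look like; the imageEditInstance and imageEditRequest types, their JSON field names, and the form field names are illustrative assumptions, not the schema actually used by the commit or by the Imagen capability model.

// Sketch only: reads the OpenAI-style multipart form (image + prompt) and
// re-encodes it as a JSON predict-style request. The request and instance
// types below are hypothetical placeholders.
package imagen

import (
	"encoding/base64"
	"fmt"
	"io"

	"github.com/gin-gonic/gin"
)

type imageEditInstance struct {
	Prompt string `json:"prompt"`
	Image  struct {
		BytesBase64Encoded string `json:"bytesBase64Encoded"`
	} `json:"image"`
}

type imageEditRequest struct {
	Instances  []imageEditInstance `json:"instances"`
	Parameters struct {
		SampleCount int `json:"sampleCount"`
	} `json:"parameters"`
}

// ConvertMultipartImageEditRequest is a hypothetical reconstruction of the
// helper referenced in ConvertImageRequest above.
func ConvertMultipartImageEditRequest(c *gin.Context) (any, error) {
	fileHeader, err := c.FormFile("image")
	if err != nil {
		return nil, fmt.Errorf("missing image in multipart form: %w", err)
	}
	f, err := fileHeader.Open()
	if err != nil {
		return nil, fmt.Errorf("open uploaded image: %w", err)
	}
	defer f.Close()

	raw, err := io.ReadAll(f)
	if err != nil {
		return nil, fmt.Errorf("read uploaded image: %w", err)
	}

	var instance imageEditInstance
	instance.Prompt = c.PostForm("prompt")
	instance.Image.BytesBase64Encoded = base64.StdEncoding.EncodeToString(raw)

	req := imageEditRequest{Instances: []imageEditInstance{instance}}
	req.Parameters.SampleCount = 1
	return req, nil
}

The real helper in the commit presumably also forwards fields such as n and any mask image; this sketch covers only the prompt and the base image.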