From f72c715e4cb1321daf43596b0be567fc1b48ba1e Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Mon, 16 Dec 2024 09:12:24 +0000 Subject: [PATCH] feat: support image inpainting for flux-fill on replicate --- common/gin.go | 28 +++--- controller/relay.go | 11 +-- go.mod | 2 +- middleware/utils.go | 20 +++-- relay/adaptor/replicate/adaptor.go | 47 +++++++++- relay/adaptor/replicate/image.go | 13 +-- relay/adaptor/replicate/model.go | 118 ++++++++++++++++++++++++++ relay/adaptor/replicate/model_test.go | 107 +++++++++++++++++++++++ relay/controller/image.go | 21 +++-- relay/model/image.go | 2 +- 10 files changed, 330 insertions(+), 39 deletions(-) create mode 100644 relay/adaptor/replicate/model_test.go diff --git a/common/gin.go b/common/gin.go index 40095418..36da8e4c 100644 --- a/common/gin.go +++ b/common/gin.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "io" + "reflect" "strings" "github.com/gin-gonic/gin" @@ -11,18 +12,18 @@ import ( "github.com/songquanpeng/one-api/common/ctxkey" ) -func GetRequestBody(c *gin.Context) ([]byte, error) { - requestBody, _ := c.Get(ctxkey.KeyRequestBody) - if requestBody != nil { - return requestBody.([]byte), nil +func GetRequestBody(c *gin.Context) (requestBody []byte, err error) { + if requestBodyCache, _ := c.Get(ctxkey.KeyRequestBody); requestBodyCache != nil { + return requestBodyCache.([]byte), nil } - requestBody, err := io.ReadAll(c.Request.Body) + requestBody, err = io.ReadAll(c.Request.Body) if err != nil { return nil, errors.Wrap(err, "read request body failed") } _ = c.Request.Body.Close() c.Set(ctxkey.KeyRequestBody, requestBody) - return requestBody.([]byte), nil + + return requestBody, nil } func UnmarshalBodyReusable(c *gin.Context, v any) error { @@ -30,19 +31,26 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { if err != nil { return errors.Wrap(err, "get request body failed") } + + // check v should be a pointer + if v == nil || reflect.TypeOf(v).Kind() != reflect.Ptr { + return errors.Errorf("UnmarshalBodyReusable only accept pointer, got %v", reflect.TypeOf(v)) + } + contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { - err = json.Unmarshal(requestBody, &v) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + err = json.Unmarshal(requestBody, v) } else { c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - err = c.ShouldBind(&v) + err = c.ShouldBind(v) } if err != nil { - return err + return errors.Wrap(err, "unmarshal request body failed") } // Reset request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return nil } diff --git a/controller/relay.go b/controller/relay.go index e70df858..4613aca9 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -58,8 +58,7 @@ func Relay(c *gin.Context) { channelName := c.GetString(ctxkey.ChannelName) group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) - // BUG: bizErr is shared, should not run this function in goroutine to avoid race - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) requestId := c.GetString(helper.RequestIdKey) retryTimes := config.RetryTimes if err := shouldRetry(c, bizErr.StatusCode); err != nil { @@ -86,8 +85,7 @@ func Relay(c *gin.Context) { channelId := c.GetInt(ctxkey.ChannelId) lastFailedChannelId = channelId channelName := c.GetString(ctxkey.ChannelName) - // BUG: bizErr is in race condition - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) } if bizErr != nil { @@ -119,7 +117,10 @@ func shouldRetry(c *gin.Context, statusCode int) error { return nil } -func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { +func processChannelRelayError(ctx context.Context, + userId int, channelId int, channelName string, + // FIX: err should not use a pointer to avoid data race in concurrent situations + err model.ErrorWithStatusCode) { logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { diff --git a/go.mod b/go.mod index 2e346920..91924862 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.24.0 golang.org/x/image v0.18.0 + golang.org/x/sync v0.7.0 google.golang.org/api v0.187.0 gorm.io/driver/mysql v1.5.6 gorm.io/driver/postgres v1.5.7 @@ -113,7 +114,6 @@ require ( golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect golang.org/x/term v0.21.0 // indirect golang.org/x/text v0.16.0 // indirect diff --git a/middleware/utils.go b/middleware/utils.go index 2afcab47..46120f2a 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,10 +1,10 @@ package middleware import ( - "fmt" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" @@ -25,28 +25,30 @@ func getRequestModel(c *gin.Context) (string, error) { var modelRequest ModelRequest err := common.UnmarshalBodyReusable(c, &modelRequest) if err != nil { - return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) + return "", errors.Wrap(err, "common.UnmarshalBodyReusable failed") } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + + switch { + case strings.HasPrefix(c.Request.URL.Path, "/v1/moderations"): if modelRequest.Model == "" { modelRequest.Model = "text-moderation-stable" } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + case strings.HasSuffix(c.Request.URL.Path, "embeddings"): if modelRequest.Model == "" { modelRequest.Model = c.Param("model") } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + case strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations"), + strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits"): if modelRequest.Model == "" { modelRequest.Model = "dall-e-2" } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + case strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions"), + strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations"): if modelRequest.Model == "" { modelRequest.Model = "whisper-1" } } + return modelRequest.Model, nil } diff --git a/relay/adaptor/replicate/adaptor.go b/relay/adaptor/replicate/adaptor.go index 7ab0c59d..34e42de5 100644 --- a/relay/adaptor/replicate/adaptor.go +++ b/relay/adaptor/replicate/adaptor.go @@ -1,6 +1,7 @@ package replicate import ( + "bytes" "fmt" "io" "net/http" @@ -9,6 +10,7 @@ import ( "github.com/gin-gonic/gin" "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -22,7 +24,31 @@ type Adaptor struct { } // ConvertImageRequest implements adaptor.Adaptor. -func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return nil, errors.New("should call replicate.ConvertImageRequest instead") +} + +func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + meta := meta.GetByContext(c) + + if request.ResponseFormat != "b64_json" { + return nil, errors.New("only support b64_json response format") + } + if request.N != 1 && request.N != 0 { + return nil, errors.New("only support N=1") + } + + switch meta.Mode { + case relaymode.ImagesGenerations: + return convertImageCreateRequest(request) + case relaymode.ImagesEdits: + return convertImageRemixRequest(c) + default: + return nil, errors.New("not implemented") + } +} + +func convertImageCreateRequest(request *model.ImageRequest) (any, error) { return DrawImageRequest{ Input: ImageInput{ Steps: 25, @@ -38,6 +64,22 @@ func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { }, nil } +func convertImageRemixRequest(c *gin.Context) (any, error) { + // recover request body + requestBody, err := common.GetRequestBody(c) + if err != nil { + return nil, errors.Wrap(err, "get request body") + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + + rawReq := new(OpenaiImageEditRequest) + if err := c.ShouldBind(rawReq); err != nil { + return nil, errors.Wrap(err, "parse image edit form") + } + + return rawReq.toFluxRemixRequest() +} + func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { return nil, errors.New("not implemented") } @@ -67,7 +109,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { switch meta.Mode { - case relaymode.ImagesGenerations: + case relaymode.ImagesGenerations, + relaymode.ImagesEdits: err, usage = ImageHandler(c, resp) default: err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) diff --git a/relay/adaptor/replicate/image.go b/relay/adaptor/replicate/image.go index 3687249a..5cc093bf 100644 --- a/relay/adaptor/replicate/image.go +++ b/relay/adaptor/replicate/image.go @@ -22,9 +22,9 @@ import ( "golang.org/x/sync/errgroup" ) -// ImagesEditsHandler just copy response body to client -// -// https://replicate.com/black-forest-labs/flux-fill-pro +// // ImagesEditsHandler just copy response body to client +// // +// // https://replicate.com/black-forest-labs/flux-fill-pro // func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { // c.Writer.WriteHeader(resp.StatusCode) // for k, v := range resp.Header { @@ -32,7 +32,7 @@ import ( // } // if _, err := io.Copy(c.Writer, resp.Body); err != nil { -// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil +// return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil // } // defer resp.Body.Close() @@ -41,7 +41,8 @@ import ( var errNextLoop = errors.New("next_loop") -func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { +func ImageHandler(c *gin.Context, resp *http.Response) ( + *model.ErrorWithStatusCode, *model.Usage) { if resp.StatusCode != http.StatusCreated { payload, _ := io.ReadAll(resp.Body) return openai.ErrorWrapper( @@ -95,7 +96,7 @@ func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCo switch taskData.Status { case "succeeded": case "failed", "canceled": - return errors.Errorf("task failed: %s", taskData.Status) + return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) default: time.Sleep(time.Second * 3) return errNextLoop diff --git a/relay/adaptor/replicate/model.go b/relay/adaptor/replicate/model.go index 465eefc5..93bdcda2 100644 --- a/relay/adaptor/replicate/model.go +++ b/relay/adaptor/replicate/model.go @@ -1,11 +1,129 @@ package replicate import ( + "bytes" + "encoding/base64" + "image" + "image/png" + "io" + "mime/multipart" "time" "github.com/pkg/errors" ) +type OpenaiImageEditRequest struct { + Image *multipart.FileHeader `json:"image" form:"image" binding:"required"` + Prompt string `json:"prompt" form:"prompt" binding:"required"` + Mask *multipart.FileHeader `json:"mask" form:"mask" binding:"required"` + Model string `json:"model" form:"model" binding:"required"` + N int `json:"n" form:"n" binding:"min=0,max=10"` + Size string `json:"size" form:"size"` + ResponseFormat string `json:"response_format" form:"response_format"` +} + +// toFluxRemixRequest convert OpenAI's image edit request to Flux's remix request. +// +// Note that the mask formats of OpenAI and Flux are different: +// OpenAI's mask sets the parts to be modified as transparent (0, 0, 0, 0), +// while Flux sets the parts to be modified as black (255, 255, 255, 255), +// so we need to convert the format here. +// +// Both OpenAI's Image and Mask are browser-native ImageData, +// which need to be converted to base64 dataURI format. +func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) { + if r.ResponseFormat != "b64_json" { + return nil, errors.New("response_format must be b64_json for replicate models") + } + + fluxReq := &InpaintingImageByFlusReplicateRequest{ + Input: FluxInpaintingInput{ + Prompt: r.Prompt, + Seed: int(time.Now().UnixNano()), + Steps: 30, + Guidance: 3, + SafetyTolerance: 5, + PromptUnsampling: false, + OutputFormat: "png", + }, + } + + imgFile, err := r.Image.Open() + if err != nil { + return nil, errors.Wrap(err, "open image file") + } + defer imgFile.Close() + imgData, err := io.ReadAll(imgFile) + if err != nil { + return nil, errors.Wrap(err, "read image file") + } + + maskFile, err := r.Mask.Open() + if err != nil { + return nil, errors.Wrap(err, "open mask file") + } + defer maskFile.Close() + + // Convert image to base64 + imageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgData) + fluxReq.Input.Image = imageBase64 + + // Convert mask data to RGBA + maskPNG, err := png.Decode(maskFile) + if err != nil { + return nil, errors.Wrap(err, "decode mask file") + } + + // convert mask to RGBA + var maskRGBA *image.RGBA + switch converted := maskPNG.(type) { + case *image.RGBA: + maskRGBA = converted + default: + // Convert to RGBA + bounds := maskPNG.Bounds() + maskRGBA = image.NewRGBA(bounds) + for y := bounds.Min.Y; y < bounds.Max.Y; y++ { + for x := bounds.Min.X; x < bounds.Max.X; x++ { + maskRGBA.Set(x, y, maskPNG.At(x, y)) + } + } + } + + maskData := maskRGBA.Pix + invertedMask := make([]byte, len(maskData)) + for i := 0; i+4 <= len(maskData); i += 4 { + // If pixel is transparent (alpha = 0), make it black (255) + if maskData[i+3] == 0 { + invertedMask[i] = 255 // R + invertedMask[i+1] = 255 // G + invertedMask[i+2] = 255 // B + invertedMask[i+3] = 255 // A + } else { + // Copy original pixel + copy(invertedMask[i:i+4], maskData[i:i+4]) + } + } + + // Convert inverted mask to base64 encoded png image + invertedMaskRGBA := &image.RGBA{ + Pix: invertedMask, + Stride: maskRGBA.Stride, + Rect: maskRGBA.Rect, + } + + var buf bytes.Buffer + err = png.Encode(&buf, invertedMaskRGBA) + if err != nil { + return nil, errors.Wrap(err, "encode inverted mask to png") + } + + invertedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()) + fluxReq.Input.Mask = invertedMaskBase64 + + return fluxReq, nil +} + // DrawImageRequest draw image by fluxpro // // https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json diff --git a/relay/adaptor/replicate/model_test.go b/relay/adaptor/replicate/model_test.go new file mode 100644 index 00000000..c08bcfc8 --- /dev/null +++ b/relay/adaptor/replicate/model_test.go @@ -0,0 +1,107 @@ +package replicate + +import ( + "bytes" + "encoding/base64" + "io" + "mime/multipart" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToFluxRemixRequest(t *testing.T) { + // Prepare input data + imageData := []byte{0x89, 0x50, 0x4E, 0x47} // Simulates PNG magic bytes + maskData := []byte{ + 0, 0, 0, 0, // Transparent pixel + 255, 255, 255, 255, // Opaque white pixel + } + prompt := "Test prompt" + model := "Test model" + responseType := "json" + + // convert image and mask to FileHeader + imageFileHeader, err := createFileHeader("image", "image.png", imageData) + require.NoError(t, err) + maskFileHeader, err := createFileHeader("mask", "mask.png", maskData) + require.NoError(t, err) + + request := OpenaiImageEditRequest{ + Image: imageFileHeader, + Mask: maskFileHeader, + Prompt: prompt, + Model: model, + ResponseFormat: responseType, + } + + // Call the method under test + fluxRequest := request.toFluxRemixRequest() + + // Verify FluxInpaintingInput fields + require.NotNil(t, fluxRequest) + require.Equal(t, prompt, fluxRequest.Input.Prompt) + require.Equal(t, 30, fluxRequest.Input.Steps) + require.Equal(t, 3, fluxRequest.Input.Guidance) + require.Equal(t, 5, fluxRequest.Input.SafetyTolerance) + require.False(t, fluxRequest.Input.PromptUnsampling) + + // Check image field (Base64 encoded) + expectedImageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) + require.Equal(t, expectedImageBase64, fluxRequest.Input.Image) + + // Check mask field (Base64 encoded and inverted transparency) + expectedInvertedMask := []byte{ + 255, 255, 255, 255, // Transparent pixel inverted to black + 255, 255, 255, 255, // Opaque white pixel remains the same + } + expectedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(expectedInvertedMask) + require.Equal(t, expectedMaskBase64, fluxRequest.Input.Mask) + + // Verify seed + // Since the seed is generated based on the current time, we validate its presence + require.NotZero(t, fluxRequest.Input.Seed) + require.True(t, fluxRequest.Input.Seed > 0) + + // Additional assertions can be added as necessary +} + +// createFileHeader creates a multipart.FileHeader from file bytes +func createFileHeader(fieldname, filename string, fileBytes []byte) (*multipart.FileHeader, error) { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // Create a form file field + part, err := writer.CreateFormFile(fieldname, filename) + if err != nil { + return nil, err + } + + // Write the file bytes to the form file field + _, err = part.Write(fileBytes) + if err != nil { + return nil, err + } + + // Close the writer to finalize the form + err = writer.Close() + if err != nil { + return nil, err + } + + // Parse the multipart form + req := &http.Request{ + Header: http.Header{}, + Body: io.NopCloser(body), + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + err = req.ParseMultipartForm(int64(body.Len())) + if err != nil { + return nil, err + } + + // Retrieve the file header from the parsed form + fileHeader := req.MultipartForm.File[fieldname][0] + return fileHeader, nil +} diff --git a/relay/controller/image.go b/relay/controller/image.go index d02c9552..16988491 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -5,6 +5,10 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" + "strings" + "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" @@ -13,20 +17,18 @@ import ( "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/adaptor/replicate" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" ) func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { imageRequest := &relaymodel.ImageRequest{} err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { - return nil, err + return nil, errors.WithStack(err) } if imageRequest.N == 0 { imageRequest.N = 1 @@ -155,7 +157,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus switch meta.ChannelType { case channeltype.Zhipu, channeltype.Ali, - channeltype.Replicate, channeltype.Baidu: finalRequest, err := adaptor.ConvertImageRequest(imageRequest) if err != nil { @@ -166,6 +167,16 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case channeltype.Replicate: + finalRequest, err := replicate.ConvertImageRequest(c, imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) + } + jsonStr, err := json.Marshal(finalRequest) + if err != nil { + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) } modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType) diff --git a/relay/model/image.go b/relay/model/image.go index 00bd8b79..ec3e7691 100644 --- a/relay/model/image.go +++ b/relay/model/image.go @@ -2,7 +2,7 @@ package model type ImageRequest struct { Model string `json:"model" form:"model"` - Prompt string `json:"prompt" binding:"required" form:"prompt"` + Prompt string `json:"prompt" form:"prompt" binding:"required"` N int `json:"n,omitempty" form:"n"` Size string `json:"size,omitempty" form:"size"` Quality string `json:"quality,omitempty" form:"quality"`