From d2bc9eb5aec01f450db5817881fe4a3562dfacb3 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Fri, 20 Dec 2024 02:47:16 +0000 Subject: [PATCH] test: enhance OpenaiImageEditRequest test with image and mask creation --- relay/adaptor/replicate/model_test.go | 96 +++++++++++++-------------- 1 file changed, 47 insertions(+), 49 deletions(-) diff --git a/relay/adaptor/replicate/model_test.go b/relay/adaptor/replicate/model_test.go index 6317d5c8..6cde5e94 100644 --- a/relay/adaptor/replicate/model_test.go +++ b/relay/adaptor/replicate/model_test.go @@ -2,7 +2,9 @@ package replicate import ( "bytes" - "encoding/base64" + "image" + "image/draw" + "image/png" "io" "mime/multipart" "net/http" @@ -11,61 +13,57 @@ import ( "github.com/stretchr/testify/require" ) -func TestToFluxRemixRequest(t *testing.T) { - // Prepare input data - imageData := []byte{0x89, 0x50, 0x4E, 0x47} // Simulates PNG magic bytes - maskData := []byte{ - 0, 0, 0, 0, // Transparent pixel - 255, 255, 255, 255, // Opaque white pixel - } - prompt := "Test prompt" - model := "Test model" - responseType := "json" +type nopCloser struct { + io.Reader +} - // convert image and mask to FileHeader - imageFileHeader, err := createFileHeader("image", "image.png", imageData) - require.NoError(t, err) - maskFileHeader, err := createFileHeader("mask", "mask.png", maskData) +func (n nopCloser) Close() error { return nil } + +// Custom FileHeader to override Open method +type customFileHeader struct { + *multipart.FileHeader + openFunc func() (multipart.File, error) +} + +func (c *customFileHeader) Open() (multipart.File, error) { + return c.openFunc() +} + +func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) { + // Create a simple image for testing + img := image.NewRGBA(image.Rect(0, 0, 10, 10)) + draw.Draw(img, img.Bounds(), &image.Uniform{C: image.Black}, image.Point{}, draw.Src) + var imgBuf bytes.Buffer + err := png.Encode(&imgBuf, img) require.NoError(t, err) - request := OpenaiImageEditRequest{ - Image: imageFileHeader, + // Create a simple mask for testing + mask := image.NewRGBA(image.Rect(0, 0, 10, 10)) + draw.Draw(mask, mask.Bounds(), &image.Uniform{C: image.Black}, image.Point{}, draw.Src) + var maskBuf bytes.Buffer + err = png.Encode(&maskBuf, mask) + require.NoError(t, err) + + // Create a multipart.FileHeader from the image and mask bytes + imgFileHeader, err := createFileHeader("image", "test.png", imgBuf.Bytes()) + require.NoError(t, err) + maskFileHeader, err := createFileHeader("mask", "test.png", maskBuf.Bytes()) + require.NoError(t, err) + + req := &OpenaiImageEditRequest{ + Image: imgFileHeader, Mask: maskFileHeader, - Prompt: prompt, - Model: model, - ResponseFormat: responseType, + Prompt: "Test prompt", + Model: "test-model", + ResponseFormat: "b64_json", } - // Call the method under test - fluxRequest, err := request.toFluxRemixRequest() + fluxReq, err := req.toFluxRemixRequest() require.NoError(t, err) - - // Verify FluxInpaintingInput fields - require.NotNil(t, fluxRequest) - require.Equal(t, prompt, fluxRequest.Input.Prompt) - require.Equal(t, 30, fluxRequest.Input.Steps) - require.Equal(t, 3, fluxRequest.Input.Guidance) - require.Equal(t, 5, fluxRequest.Input.SafetyTolerance) - require.False(t, fluxRequest.Input.PromptUnsampling) - - // Check image field (Base64 encoded) - expectedImageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) - require.Equal(t, expectedImageBase64, fluxRequest.Input.Image) - - // Check mask field (Base64 encoded and inverted transparency) - expectedInvertedMask := []byte{ - 255, 255, 255, 255, // Transparent pixel inverted to black - 255, 255, 255, 255, // Opaque white pixel remains the same - } - expectedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(expectedInvertedMask) - require.Equal(t, expectedMaskBase64, fluxRequest.Input.Mask) - - // Verify seed - // Since the seed is generated based on the current time, we validate its presence - require.NotZero(t, fluxRequest.Input.Seed) - require.True(t, fluxRequest.Input.Seed > 0) - - // Additional assertions can be added as necessary + require.NotNil(t, fluxReq) + require.Equal(t, req.Prompt, fluxReq.Input.Prompt) + require.NotEmpty(t, fluxReq.Input.Image) + require.NotEmpty(t, fluxReq.Input.Mask) } // createFileHeader creates a multipart.FileHeader from file bytes