mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 17:16:38 +08:00
test: enhance OpenaiImageEditRequest test with image and mask creation
This commit is contained in:
parent
ab69bca2d1
commit
d2bc9eb5ae
@ -2,7 +2,9 @@ package replicate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"image"
|
||||
"image/draw"
|
||||
"image/png"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
@ -11,61 +13,57 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToFluxRemixRequest(t *testing.T) {
|
||||
// Prepare input data
|
||||
imageData := []byte{0x89, 0x50, 0x4E, 0x47} // Simulates PNG magic bytes
|
||||
maskData := []byte{
|
||||
0, 0, 0, 0, // Transparent pixel
|
||||
255, 255, 255, 255, // Opaque white pixel
|
||||
}
|
||||
prompt := "Test prompt"
|
||||
model := "Test model"
|
||||
responseType := "json"
|
||||
type nopCloser struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
// convert image and mask to FileHeader
|
||||
imageFileHeader, err := createFileHeader("image", "image.png", imageData)
|
||||
require.NoError(t, err)
|
||||
maskFileHeader, err := createFileHeader("mask", "mask.png", maskData)
|
||||
func (n nopCloser) Close() error { return nil }
|
||||
|
||||
// Custom FileHeader to override Open method
|
||||
type customFileHeader struct {
|
||||
*multipart.FileHeader
|
||||
openFunc func() (multipart.File, error)
|
||||
}
|
||||
|
||||
func (c *customFileHeader) Open() (multipart.File, error) {
|
||||
return c.openFunc()
|
||||
}
|
||||
|
||||
func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) {
|
||||
// Create a simple image for testing
|
||||
img := image.NewRGBA(image.Rect(0, 0, 10, 10))
|
||||
draw.Draw(img, img.Bounds(), &image.Uniform{C: image.Black}, image.Point{}, draw.Src)
|
||||
var imgBuf bytes.Buffer
|
||||
err := png.Encode(&imgBuf, img)
|
||||
require.NoError(t, err)
|
||||
|
||||
request := OpenaiImageEditRequest{
|
||||
Image: imageFileHeader,
|
||||
// Create a simple mask for testing
|
||||
mask := image.NewRGBA(image.Rect(0, 0, 10, 10))
|
||||
draw.Draw(mask, mask.Bounds(), &image.Uniform{C: image.Black}, image.Point{}, draw.Src)
|
||||
var maskBuf bytes.Buffer
|
||||
err = png.Encode(&maskBuf, mask)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a multipart.FileHeader from the image and mask bytes
|
||||
imgFileHeader, err := createFileHeader("image", "test.png", imgBuf.Bytes())
|
||||
require.NoError(t, err)
|
||||
maskFileHeader, err := createFileHeader("mask", "test.png", maskBuf.Bytes())
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &OpenaiImageEditRequest{
|
||||
Image: imgFileHeader,
|
||||
Mask: maskFileHeader,
|
||||
Prompt: prompt,
|
||||
Model: model,
|
||||
ResponseFormat: responseType,
|
||||
Prompt: "Test prompt",
|
||||
Model: "test-model",
|
||||
ResponseFormat: "b64_json",
|
||||
}
|
||||
|
||||
// Call the method under test
|
||||
fluxRequest, err := request.toFluxRemixRequest()
|
||||
fluxReq, err := req.toFluxRemixRequest()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify FluxInpaintingInput fields
|
||||
require.NotNil(t, fluxRequest)
|
||||
require.Equal(t, prompt, fluxRequest.Input.Prompt)
|
||||
require.Equal(t, 30, fluxRequest.Input.Steps)
|
||||
require.Equal(t, 3, fluxRequest.Input.Guidance)
|
||||
require.Equal(t, 5, fluxRequest.Input.SafetyTolerance)
|
||||
require.False(t, fluxRequest.Input.PromptUnsampling)
|
||||
|
||||
// Check image field (Base64 encoded)
|
||||
expectedImageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||
require.Equal(t, expectedImageBase64, fluxRequest.Input.Image)
|
||||
|
||||
// Check mask field (Base64 encoded and inverted transparency)
|
||||
expectedInvertedMask := []byte{
|
||||
255, 255, 255, 255, // Transparent pixel inverted to black
|
||||
255, 255, 255, 255, // Opaque white pixel remains the same
|
||||
}
|
||||
expectedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(expectedInvertedMask)
|
||||
require.Equal(t, expectedMaskBase64, fluxRequest.Input.Mask)
|
||||
|
||||
// Verify seed
|
||||
// Since the seed is generated based on the current time, we validate its presence
|
||||
require.NotZero(t, fluxRequest.Input.Seed)
|
||||
require.True(t, fluxRequest.Input.Seed > 0)
|
||||
|
||||
// Additional assertions can be added as necessary
|
||||
require.NotNil(t, fluxReq)
|
||||
require.Equal(t, req.Prompt, fluxReq.Input.Prompt)
|
||||
require.NotEmpty(t, fluxReq.Input.Image)
|
||||
require.NotEmpty(t, fluxReq.Input.Mask)
|
||||
}
|
||||
|
||||
// createFileHeader creates a multipart.FileHeader from file bytes
|
||||
|
Loading…
Reference in New Issue
Block a user