mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	feat: support image inpainting for flux-fill on replicate
This commit is contained in:
		@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"io"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
@@ -11,18 +12,18 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
 | 
			
		||||
	requestBody, _ := c.Get(ctxkey.KeyRequestBody)
 | 
			
		||||
	if requestBody != nil {
 | 
			
		||||
		return requestBody.([]byte), nil
 | 
			
		||||
func GetRequestBody(c *gin.Context) (requestBody []byte, err error) {
 | 
			
		||||
	if requestBodyCache, _ := c.Get(ctxkey.KeyRequestBody); requestBodyCache != nil {
 | 
			
		||||
		return requestBodyCache.([]byte), nil
 | 
			
		||||
	}
 | 
			
		||||
	requestBody, err := io.ReadAll(c.Request.Body)
 | 
			
		||||
	requestBody, err = io.ReadAll(c.Request.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "read request body failed")
 | 
			
		||||
	}
 | 
			
		||||
	_ = c.Request.Body.Close()
 | 
			
		||||
	c.Set(ctxkey.KeyRequestBody, requestBody)
 | 
			
		||||
	return requestBody.([]byte), nil
 | 
			
		||||
 | 
			
		||||
	return requestBody, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
 | 
			
		||||
@@ -30,19 +31,26 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.Wrap(err, "get request body failed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// check v should be a pointer
 | 
			
		||||
	if v == nil || reflect.TypeOf(v).Kind() != reflect.Ptr {
 | 
			
		||||
		return errors.Errorf("UnmarshalBodyReusable only accept pointer, got %v", reflect.TypeOf(v))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	contentType := c.Request.Header.Get("Content-Type")
 | 
			
		||||
	if strings.HasPrefix(contentType, "application/json") {
 | 
			
		||||
		err = json.Unmarshal(requestBody, &v)
 | 
			
		||||
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
		err = json.Unmarshal(requestBody, v)
 | 
			
		||||
	} else {
 | 
			
		||||
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
		err = c.ShouldBind(&v)
 | 
			
		||||
		err = c.ShouldBind(v)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errors.Wrap(err, "unmarshal request body failed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Reset request body
 | 
			
		||||
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -58,8 +58,7 @@ func Relay(c *gin.Context) {
 | 
			
		||||
	channelName := c.GetString(ctxkey.ChannelName)
 | 
			
		||||
	group := c.GetString(ctxkey.Group)
 | 
			
		||||
	originalModel := c.GetString(ctxkey.OriginalModel)
 | 
			
		||||
	// BUG: bizErr is shared, should not run this function in goroutine to avoid race
 | 
			
		||||
	go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
 | 
			
		||||
	go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
 | 
			
		||||
	requestId := c.GetString(helper.RequestIdKey)
 | 
			
		||||
	retryTimes := config.RetryTimes
 | 
			
		||||
	if err := shouldRetry(c, bizErr.StatusCode); err != nil {
 | 
			
		||||
@@ -86,8 +85,7 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		channelId := c.GetInt(ctxkey.ChannelId)
 | 
			
		||||
		lastFailedChannelId = channelId
 | 
			
		||||
		channelName := c.GetString(ctxkey.ChannelName)
 | 
			
		||||
		// BUG: bizErr is in race condition
 | 
			
		||||
		go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
 | 
			
		||||
		go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if bizErr != nil {
 | 
			
		||||
@@ -119,7 +117,10 @@ func shouldRetry(c *gin.Context, statusCode int) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
 | 
			
		||||
func processChannelRelayError(ctx context.Context,
 | 
			
		||||
	userId int, channelId int, channelName string,
 | 
			
		||||
	// FIX: err should not use a pointer to avoid data race in concurrent situations
 | 
			
		||||
	err model.ErrorWithStatusCode) {
 | 
			
		||||
	logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
 | 
			
		||||
	// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
			
		||||
	if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@@ -31,6 +31,7 @@ require (
 | 
			
		||||
	github.com/stretchr/testify v1.9.0
 | 
			
		||||
	golang.org/x/crypto v0.24.0
 | 
			
		||||
	golang.org/x/image v0.18.0
 | 
			
		||||
	golang.org/x/sync v0.7.0
 | 
			
		||||
	google.golang.org/api v0.187.0
 | 
			
		||||
	gorm.io/driver/mysql v1.5.6
 | 
			
		||||
	gorm.io/driver/postgres v1.5.7
 | 
			
		||||
@@ -113,7 +114,6 @@ require (
 | 
			
		||||
	golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect
 | 
			
		||||
	golang.org/x/net v0.26.0 // indirect
 | 
			
		||||
	golang.org/x/oauth2 v0.21.0 // indirect
 | 
			
		||||
	golang.org/x/sync v0.7.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.21.0 // indirect
 | 
			
		||||
	golang.org/x/term v0.21.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.16.0 // indirect
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,10 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
@@ -25,28 +25,30 @@ func getRequestModel(c *gin.Context) (string, error) {
 | 
			
		||||
	var modelRequest ModelRequest
 | 
			
		||||
	err := common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
 | 
			
		||||
		return "", errors.Wrap(err, "common.UnmarshalBodyReusable failed")
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
			
		||||
 | 
			
		||||
	switch {
 | 
			
		||||
	case strings.HasPrefix(c.Request.URL.Path, "/v1/moderations"):
 | 
			
		||||
		if modelRequest.Model == "" {
 | 
			
		||||
			modelRequest.Model = "text-moderation-stable"
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
 | 
			
		||||
	case strings.HasSuffix(c.Request.URL.Path, "embeddings"):
 | 
			
		||||
		if modelRequest.Model == "" {
 | 
			
		||||
			modelRequest.Model = c.Param("model")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 | 
			
		||||
	case strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations"),
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits"):
 | 
			
		||||
		if modelRequest.Model == "" {
 | 
			
		||||
			modelRequest.Model = "dall-e-2"
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
 | 
			
		||||
	case strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions"),
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations"):
 | 
			
		||||
		if modelRequest.Model == "" {
 | 
			
		||||
			modelRequest.Model = "whisper-1"
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return modelRequest.Model, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -9,6 +10,7 @@ import (
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
@@ -22,7 +24,31 @@ type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConvertImageRequest implements adaptor.Adaptor.
 | 
			
		||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
	return nil, errors.New("should call replicate.ConvertImageRequest instead")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
 | 
			
		||||
	meta := meta.GetByContext(c)
 | 
			
		||||
 | 
			
		||||
	if request.ResponseFormat != "b64_json" {
 | 
			
		||||
		return nil, errors.New("only support b64_json response format")
 | 
			
		||||
	}
 | 
			
		||||
	if request.N != 1 && request.N != 0 {
 | 
			
		||||
		return nil, errors.New("only support N=1")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case relaymode.ImagesGenerations:
 | 
			
		||||
		return convertImageCreateRequest(request)
 | 
			
		||||
	case relaymode.ImagesEdits:
 | 
			
		||||
		return convertImageRemixRequest(c)
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, errors.New("not implemented")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
	return DrawImageRequest{
 | 
			
		||||
		Input: ImageInput{
 | 
			
		||||
			Steps:           25,
 | 
			
		||||
@@ -38,6 +64,22 @@ func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func convertImageRemixRequest(c *gin.Context) (any, error) {
 | 
			
		||||
	// recover request body
 | 
			
		||||
	requestBody, err := common.GetRequestBody(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "get request body")
 | 
			
		||||
	}
 | 
			
		||||
	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
 | 
			
		||||
	rawReq := new(OpenaiImageEditRequest)
 | 
			
		||||
	if err := c.ShouldBind(rawReq); err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "parse image edit form")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return rawReq.toFluxRemixRequest()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 | 
			
		||||
	return nil, errors.New("not implemented")
 | 
			
		||||
}
 | 
			
		||||
@@ -67,7 +109,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case relaymode.ImagesGenerations:
 | 
			
		||||
	case relaymode.ImagesGenerations,
 | 
			
		||||
		relaymode.ImagesEdits:
 | 
			
		||||
		err, usage = ImageHandler(c, resp)
 | 
			
		||||
	default:
 | 
			
		||||
		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
 | 
			
		||||
 
 | 
			
		||||
@@ -22,9 +22,9 @@ import (
 | 
			
		||||
	"golang.org/x/sync/errgroup"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ImagesEditsHandler just copy response body to client
 | 
			
		||||
//
 | 
			
		||||
// https://replicate.com/black-forest-labs/flux-fill-pro
 | 
			
		||||
// // ImagesEditsHandler just copy response body to client
 | 
			
		||||
// //
 | 
			
		||||
// // https://replicate.com/black-forest-labs/flux-fill-pro
 | 
			
		||||
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
// 	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
// 	for k, v := range resp.Header {
 | 
			
		||||
@@ -32,7 +32,7 @@ import (
 | 
			
		||||
// 	}
 | 
			
		||||
 | 
			
		||||
// 	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
 | 
			
		||||
// 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
// 		return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
// 	}
 | 
			
		||||
// 	defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
@@ -41,7 +41,8 @@ import (
 | 
			
		||||
 | 
			
		||||
var errNextLoop = errors.New("next_loop")
 | 
			
		||||
 | 
			
		||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
func ImageHandler(c *gin.Context, resp *http.Response) (
 | 
			
		||||
	*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	if resp.StatusCode != http.StatusCreated {
 | 
			
		||||
		payload, _ := io.ReadAll(resp.Body)
 | 
			
		||||
		return openai.ErrorWrapper(
 | 
			
		||||
@@ -95,7 +96,7 @@ func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCo
 | 
			
		||||
			switch taskData.Status {
 | 
			
		||||
			case "succeeded":
 | 
			
		||||
			case "failed", "canceled":
 | 
			
		||||
				return errors.Errorf("task failed: %s", taskData.Status)
 | 
			
		||||
				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
 | 
			
		||||
			default:
 | 
			
		||||
				time.Sleep(time.Second * 3)
 | 
			
		||||
				return errNextLoop
 | 
			
		||||
 
 | 
			
		||||
@@ -1,11 +1,129 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"image"
 | 
			
		||||
	"image/png"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mime/multipart"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type OpenaiImageEditRequest struct {
 | 
			
		||||
	Image          *multipart.FileHeader `json:"image" form:"image" binding:"required"`
 | 
			
		||||
	Prompt         string                `json:"prompt" form:"prompt" binding:"required"`
 | 
			
		||||
	Mask           *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
 | 
			
		||||
	Model          string                `json:"model" form:"model" binding:"required"`
 | 
			
		||||
	N              int                   `json:"n" form:"n" binding:"min=0,max=10"`
 | 
			
		||||
	Size           string                `json:"size" form:"size"`
 | 
			
		||||
	ResponseFormat string                `json:"response_format" form:"response_format"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// toFluxRemixRequest converts OpenAI's image edit request to Flux's remix request.
 | 
			
		||||
//
 | 
			
		||||
// Note that the mask formats of OpenAI and Flux are different:
 | 
			
		||||
// OpenAI's mask sets the parts to be modified as transparent (0, 0, 0, 0),
 | 
			
		||||
// while Flux sets the parts to be modified as opaque white (255, 255, 255, 255),
 | 
			
		||||
// so we need to convert the format here.
 | 
			
		||||
//
 | 
			
		||||
// Both OpenAI's Image and Mask are browser-native ImageData,
 | 
			
		||||
// which need to be converted to base64 dataURI format.
 | 
			
		||||
func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) {
 | 
			
		||||
	if r.ResponseFormat != "b64_json" {
 | 
			
		||||
		return nil, errors.New("response_format must be b64_json for replicate models")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fluxReq := &InpaintingImageByFlusReplicateRequest{
 | 
			
		||||
		Input: FluxInpaintingInput{
 | 
			
		||||
			Prompt:           r.Prompt,
 | 
			
		||||
			Seed:             int(time.Now().UnixNano()),
 | 
			
		||||
			Steps:            30,
 | 
			
		||||
			Guidance:         3,
 | 
			
		||||
			SafetyTolerance:  5,
 | 
			
		||||
			PromptUnsampling: false,
 | 
			
		||||
			OutputFormat:     "png",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	imgFile, err := r.Image.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "open image file")
 | 
			
		||||
	}
 | 
			
		||||
	defer imgFile.Close()
 | 
			
		||||
	imgData, err := io.ReadAll(imgFile)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "read image file")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	maskFile, err := r.Mask.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "open mask file")
 | 
			
		||||
	}
 | 
			
		||||
	defer maskFile.Close()
 | 
			
		||||
 | 
			
		||||
	// Convert image to base64
 | 
			
		||||
	imageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgData)
 | 
			
		||||
	fluxReq.Input.Image = imageBase64
 | 
			
		||||
 | 
			
		||||
	// Decode the mask file as a PNG image
 | 
			
		||||
	maskPNG, err := png.Decode(maskFile)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "decode mask file")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// convert mask to RGBA
 | 
			
		||||
	var maskRGBA *image.RGBA
 | 
			
		||||
	switch converted := maskPNG.(type) {
 | 
			
		||||
	case *image.RGBA:
 | 
			
		||||
		maskRGBA = converted
 | 
			
		||||
	default:
 | 
			
		||||
		// Convert to RGBA
 | 
			
		||||
		bounds := maskPNG.Bounds()
 | 
			
		||||
		maskRGBA = image.NewRGBA(bounds)
 | 
			
		||||
		for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
 | 
			
		||||
			for x := bounds.Min.X; x < bounds.Max.X; x++ {
 | 
			
		||||
				maskRGBA.Set(x, y, maskPNG.At(x, y))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	maskData := maskRGBA.Pix
 | 
			
		||||
	invertedMask := make([]byte, len(maskData))
 | 
			
		||||
	for i := 0; i+4 <= len(maskData); i += 4 {
 | 
			
		||||
		// If pixel is transparent (alpha = 0), make it opaque white (255)
 | 
			
		||||
		if maskData[i+3] == 0 {
 | 
			
		||||
			invertedMask[i] = 255   // R
 | 
			
		||||
			invertedMask[i+1] = 255 // G
 | 
			
		||||
			invertedMask[i+2] = 255 // B
 | 
			
		||||
			invertedMask[i+3] = 255 // A
 | 
			
		||||
		} else {
 | 
			
		||||
			// Copy original pixel
 | 
			
		||||
			copy(invertedMask[i:i+4], maskData[i:i+4])
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Convert inverted mask to base64 encoded png image
 | 
			
		||||
	invertedMaskRGBA := &image.RGBA{
 | 
			
		||||
		Pix:    invertedMask,
 | 
			
		||||
		Stride: maskRGBA.Stride,
 | 
			
		||||
		Rect:   maskRGBA.Rect,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var buf bytes.Buffer
 | 
			
		||||
	err = png.Encode(&buf, invertedMaskRGBA)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "encode inverted mask to png")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	invertedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes())
 | 
			
		||||
	fluxReq.Input.Mask = invertedMaskBase64
 | 
			
		||||
 | 
			
		||||
	return fluxReq, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DrawImageRequest draws an image with flux-pro.
 | 
			
		||||
//
 | 
			
		||||
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										107
									
								
								relay/adaptor/replicate/model_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										107
									
								
								relay/adaptor/replicate/model_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,107 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mime/multipart"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestToFluxRemixRequest(t *testing.T) {
 | 
			
		||||
	// Prepare input data
 | 
			
		||||
	imageData := []byte{0x89, 0x50, 0x4E, 0x47} // Simulates PNG magic bytes
 | 
			
		||||
	maskData := []byte{
 | 
			
		||||
		0, 0, 0, 0, // Transparent pixel
 | 
			
		||||
		255, 255, 255, 255, // Opaque white pixel
 | 
			
		||||
	}
 | 
			
		||||
	prompt := "Test prompt"
 | 
			
		||||
	model := "Test model"
 | 
			
		||||
	responseType := "json"
 | 
			
		||||
 | 
			
		||||
	// convert image and mask to FileHeader
 | 
			
		||||
	imageFileHeader, err := createFileHeader("image", "image.png", imageData)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	maskFileHeader, err := createFileHeader("mask", "mask.png", maskData)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	request := OpenaiImageEditRequest{
 | 
			
		||||
		Image:          imageFileHeader,
 | 
			
		||||
		Mask:           maskFileHeader,
 | 
			
		||||
		Prompt:         prompt,
 | 
			
		||||
		Model:          model,
 | 
			
		||||
		ResponseFormat: responseType,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Call the method under test
 | 
			
		||||
	fluxRequest := request.toFluxRemixRequest()
 | 
			
		||||
 | 
			
		||||
	// Verify FluxInpaintingInput fields
 | 
			
		||||
	require.NotNil(t, fluxRequest)
 | 
			
		||||
	require.Equal(t, prompt, fluxRequest.Input.Prompt)
 | 
			
		||||
	require.Equal(t, 30, fluxRequest.Input.Steps)
 | 
			
		||||
	require.Equal(t, 3, fluxRequest.Input.Guidance)
 | 
			
		||||
	require.Equal(t, 5, fluxRequest.Input.SafetyTolerance)
 | 
			
		||||
	require.False(t, fluxRequest.Input.PromptUnsampling)
 | 
			
		||||
 | 
			
		||||
	// Check image field (Base64 encoded)
 | 
			
		||||
	expectedImageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
 | 
			
		||||
	require.Equal(t, expectedImageBase64, fluxRequest.Input.Image)
 | 
			
		||||
 | 
			
		||||
	// Check mask field (Base64 encoded and inverted transparency)
 | 
			
		||||
	expectedInvertedMask := []byte{
 | 
			
		||||
		255, 255, 255, 255, // Transparent pixel inverted to black
 | 
			
		||||
		255, 255, 255, 255, // Opaque white pixel remains the same
 | 
			
		||||
	}
 | 
			
		||||
	expectedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(expectedInvertedMask)
 | 
			
		||||
	require.Equal(t, expectedMaskBase64, fluxRequest.Input.Mask)
 | 
			
		||||
 | 
			
		||||
	// Verify seed
 | 
			
		||||
	// Since the seed is generated based on the current time, we validate its presence
 | 
			
		||||
	require.NotZero(t, fluxRequest.Input.Seed)
 | 
			
		||||
	require.True(t, fluxRequest.Input.Seed > 0)
 | 
			
		||||
 | 
			
		||||
	// Additional assertions can be added as necessary
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// createFileHeader creates a multipart.FileHeader from file bytes
 | 
			
		||||
func createFileHeader(fieldname, filename string, fileBytes []byte) (*multipart.FileHeader, error) {
 | 
			
		||||
	body := &bytes.Buffer{}
 | 
			
		||||
	writer := multipart.NewWriter(body)
 | 
			
		||||
 | 
			
		||||
	// Create a form file field
 | 
			
		||||
	part, err := writer.CreateFormFile(fieldname, filename)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Write the file bytes to the form file field
 | 
			
		||||
	_, err = part.Write(fileBytes)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Close the writer to finalize the form
 | 
			
		||||
	err = writer.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Parse the multipart form
 | 
			
		||||
	req := &http.Request{
 | 
			
		||||
		Header: http.Header{},
 | 
			
		||||
		Body:   io.NopCloser(body),
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Content-Type", writer.FormDataContentType())
 | 
			
		||||
	err = req.ParseMultipartForm(int64(body.Len()))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Retrieve the file header from the parsed form
 | 
			
		||||
	fileHeader := req.MultipartForm.File[fieldname][0]
 | 
			
		||||
	return fileHeader, nil
 | 
			
		||||
}
 | 
			
		||||
@@ -5,6 +5,10 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/Laisky/errors/v2"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
@@ -13,20 +17,18 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/replicate"
 | 
			
		||||
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channeltype"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	relaymodel "github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
 | 
			
		||||
	imageRequest := &relaymodel.ImageRequest{}
 | 
			
		||||
	err := common.UnmarshalBodyReusable(c, imageRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
		return nil, errors.WithStack(err)
 | 
			
		||||
	}
 | 
			
		||||
	if imageRequest.N == 0 {
 | 
			
		||||
		imageRequest.N = 1
 | 
			
		||||
@@ -155,7 +157,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
	switch meta.ChannelType {
 | 
			
		||||
	case channeltype.Zhipu,
 | 
			
		||||
		channeltype.Ali,
 | 
			
		||||
		channeltype.Replicate,
 | 
			
		||||
		channeltype.Baidu:
 | 
			
		||||
		finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -166,6 +167,16 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
			return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case channeltype.Replicate:
 | 
			
		||||
		finalRequest, err := replicate.ConvertImageRequest(c, imageRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		jsonStr, err := json.Marshal(finalRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,7 @@ package model
 | 
			
		||||
 | 
			
		||||
type ImageRequest struct {
 | 
			
		||||
	Model          string `json:"model" form:"model"`
 | 
			
		||||
	Prompt         string `json:"prompt" binding:"required" form:"prompt"`
 | 
			
		||||
	Prompt         string `json:"prompt" form:"prompt" binding:"required"`
 | 
			
		||||
	N              int    `json:"n,omitempty" form:"n"`
 | 
			
		||||
	Size           string `json:"size,omitempty" form:"size"`
 | 
			
		||||
	Quality        string `json:"quality,omitempty" form:"quality"`
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user