From ab69bca2d1cab0edd011e7bb456c3070b35a9bfb Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Mon, 16 Dec 2024 09:12:24 +0000 Subject: [PATCH] feat: support image inpainting for flux-fill on replicate --- common/gin.go | 29 ++++-- controller/relay.go | 12 +-- go.mod | 2 +- middleware/utils.go | 20 ++-- monitor/manage.go | 2 +- relay/adaptor/ollama/main.go | 6 +- relay/adaptor/replicate/adaptor.go | 47 ++++++++- relay/adaptor/replicate/image.go | 13 +-- relay/adaptor/replicate/model.go | 118 +++++++++++++++++++++++ relay/adaptor/replicate/model_test.go | 108 +++++++++++++++++++++ relay/adaptor/vertexai/gemini/adapter.go | 2 +- relay/controller/image.go | 16 ++- relay/model/image.go | 2 +- 13 files changed, 334 insertions(+), 43 deletions(-) create mode 100644 relay/adaptor/replicate/model_test.go diff --git a/common/gin.go b/common/gin.go index 10fd7c4e..c8254bfd 100644 --- a/common/gin.go +++ b/common/gin.go @@ -4,24 +4,26 @@ import ( "bytes" "encoding/json" "io" + "reflect" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/ctxkey" ) -func GetRequestBody(c *gin.Context) ([]byte, error) { - requestBody, _ := c.Get(ctxkey.KeyRequestBody) - if requestBody != nil { - return requestBody.([]byte), nil +func GetRequestBody(c *gin.Context) (requestBody []byte, err error) { + if requestBodyCache, _ := c.Get(ctxkey.KeyRequestBody); requestBodyCache != nil { + return requestBodyCache.([]byte), nil } - requestBody, err := io.ReadAll(c.Request.Body) + requestBody, err = io.ReadAll(c.Request.Body) if err != nil { return nil, err } _ = c.Request.Body.Close() c.Set(ctxkey.KeyRequestBody, requestBody) - return requestBody.([]byte), nil + + return requestBody, nil } func UnmarshalBodyReusable(c *gin.Context, v any) error { @@ -29,18 +31,25 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { if err != nil { return err } + + // check v should be a pointer + if v == nil || reflect.TypeOf(v).Kind() != reflect.Ptr { + return errors.Errorf("UnmarshalBodyReusable only accept pointer, got %v", reflect.TypeOf(v)) + } + contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { - err = json.Unmarshal(requestBody, &v) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + err = json.Unmarshal(requestBody, v) } else { c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - err = c.ShouldBind(&v) + err = c.ShouldBind(v) } if err != nil { - return err + return errors.Wrap(err, "unmarshal request body failed") } // Reset request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return nil } diff --git a/controller/relay.go b/controller/relay.go index 7de84e3f..f792b258 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -57,9 +57,7 @@ func Relay(c *gin.Context) { channelName := c.GetString(ctxkey.ChannelName) group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) - - // BUG: bizErr is shared, should not run this function in goroutine to avoid race - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) requestId := c.GetString(helper.RequestIdKey) retryTimes := config.RetryTimes if !shouldRetry(c, bizErr.StatusCode) { @@ -86,8 +84,7 @@ func Relay(c *gin.Context) { channelId := c.GetInt(ctxkey.ChannelId) lastFailedChannelId = channelId channelName := c.GetString(ctxkey.ChannelName) - // BUG: bizErr is in race condition - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) } if bizErr != nil { @@ -122,7 +119,10 @@ func shouldRetry(c *gin.Context, statusCode int) bool { return true } -func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { +func processChannelRelayError(ctx context.Context, + userId int, channelId int, channelName string, + // FIX: err should not use a pointer to avoid data race in concurrent situations + err model.ErrorWithStatusCode) { logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { diff --git a/go.mod b/go.mod index ada53bc3..cfc8bcad 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.24.0 golang.org/x/image v0.18.0 + golang.org/x/sync v0.7.0 google.golang.org/api v0.187.0 gorm.io/driver/mysql v1.5.6 gorm.io/driver/postgres v1.5.7 @@ -99,7 +100,6 @@ require ( golang.org/x/arch v0.8.0 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect diff --git a/middleware/utils.go b/middleware/utils.go index 2afcab47..46120f2a 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,10 +1,10 @@ package middleware import ( - "fmt" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" @@ -25,28 +25,30 @@ func getRequestModel(c *gin.Context) (string, error) { var modelRequest ModelRequest err := common.UnmarshalBodyReusable(c, &modelRequest) if err != nil { - return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) + return "", errors.Wrap(err, "common.UnmarshalBodyReusable failed") } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + + switch { + case strings.HasPrefix(c.Request.URL.Path, "/v1/moderations"): if modelRequest.Model == "" { modelRequest.Model = "text-moderation-stable" } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + case strings.HasSuffix(c.Request.URL.Path, "embeddings"): if modelRequest.Model == "" { modelRequest.Model = c.Param("model") } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + case strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations"), + strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits"): if modelRequest.Model == "" { modelRequest.Model = "dall-e-2" } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + case strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions"), + strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations"): if modelRequest.Model == "" { modelRequest.Model = "whisper-1" } } + return modelRequest.Model, nil } diff --git a/monitor/manage.go b/monitor/manage.go index 44c13612..268d3924 100644 --- a/monitor/manage.go +++ b/monitor/manage.go @@ -34,7 +34,7 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool { strings.Contains(lowerMessage, "credit") || strings.Contains(lowerMessage, "balance") || strings.Contains(lowerMessage, "permission denied") || - strings.Contains(lowerMessage, "organization has been restricted") || // groq + strings.Contains(lowerMessage, "organization has been restricted") || // groq strings.Contains(lowerMessage, "已欠费") { return true } diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go index 43317ff6..fa1b05f0 100644 --- a/relay/adaptor/ollama/main.go +++ b/relay/adaptor/ollama/main.go @@ -31,8 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { TopP: request.TopP, FrequencyPenalty: request.FrequencyPenalty, PresencePenalty: request.PresencePenalty, - NumPredict: request.MaxTokens, - NumCtx: request.NumCtx, + NumPredict: request.MaxTokens, + NumCtx: request.NumCtx, }, Stream: request.Stream, } @@ -122,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC for scanner.Scan() { data := scanner.Text() if strings.HasPrefix(data, "}") { - data = strings.TrimPrefix(data, "}") + "}" + data = strings.TrimPrefix(data, "}") + "}" } var ollamaResponse ChatResponse diff --git a/relay/adaptor/replicate/adaptor.go b/relay/adaptor/replicate/adaptor.go index 7ab0c59d..34e42de5 100644 --- a/relay/adaptor/replicate/adaptor.go +++ b/relay/adaptor/replicate/adaptor.go @@ -1,6 +1,7 @@ package replicate import ( + "bytes" "fmt" "io" "net/http" @@ -9,6 +10,7 @@ import ( "github.com/gin-gonic/gin" "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -22,7 +24,31 @@ type Adaptor struct { } // ConvertImageRequest implements adaptor.Adaptor. -func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return nil, errors.New("should call replicate.ConvertImageRequest instead") +} + +func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + meta := meta.GetByContext(c) + + if request.ResponseFormat != "b64_json" { + return nil, errors.New("only support b64_json response format") + } + if request.N != 1 && request.N != 0 { + return nil, errors.New("only support N=1") + } + + switch meta.Mode { + case relaymode.ImagesGenerations: + return convertImageCreateRequest(request) + case relaymode.ImagesEdits: + return convertImageRemixRequest(c) + default: + return nil, errors.New("not implemented") + } +} + +func convertImageCreateRequest(request *model.ImageRequest) (any, error) { return DrawImageRequest{ Input: ImageInput{ Steps: 25, @@ -38,6 +64,22 @@ func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { }, nil } +func convertImageRemixRequest(c *gin.Context) (any, error) { + // recover request body + requestBody, err := common.GetRequestBody(c) + if err != nil { + return nil, errors.Wrap(err, "get request body") + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + + rawReq := new(OpenaiImageEditRequest) + if err := c.ShouldBind(rawReq); err != nil { + return nil, errors.Wrap(err, "parse image edit form") + } + + return rawReq.toFluxRemixRequest() +} + func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { return nil, errors.New("not implemented") } @@ -67,7 +109,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { switch meta.Mode { - case relaymode.ImagesGenerations: + case relaymode.ImagesGenerations, + relaymode.ImagesEdits: err, usage = ImageHandler(c, resp) default: err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) diff --git a/relay/adaptor/replicate/image.go b/relay/adaptor/replicate/image.go index 3687249a..5cc093bf 100644 --- a/relay/adaptor/replicate/image.go +++ b/relay/adaptor/replicate/image.go @@ -22,9 +22,9 @@ import ( "golang.org/x/sync/errgroup" ) -// ImagesEditsHandler just copy response body to client -// -// https://replicate.com/black-forest-labs/flux-fill-pro +// // ImagesEditsHandler just copy response body to client +// // +// // https://replicate.com/black-forest-labs/flux-fill-pro // func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { // c.Writer.WriteHeader(resp.StatusCode) // for k, v := range resp.Header { @@ -32,7 +32,7 @@ import ( // } // if _, err := io.Copy(c.Writer, resp.Body); err != nil { -// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil +// return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil // } // defer resp.Body.Close() @@ -41,7 +41,8 @@ import ( var errNextLoop = errors.New("next_loop") -func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { +func ImageHandler(c *gin.Context, resp *http.Response) ( + *model.ErrorWithStatusCode, *model.Usage) { if resp.StatusCode != http.StatusCreated { payload, _ := io.ReadAll(resp.Body) return openai.ErrorWrapper( @@ -95,7 +96,7 @@ func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCo switch taskData.Status { case "succeeded": case "failed", "canceled": - return errors.Errorf("task failed: %s", taskData.Status) + return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) default: time.Sleep(time.Second * 3) return errNextLoop diff --git a/relay/adaptor/replicate/model.go b/relay/adaptor/replicate/model.go index 465eefc5..93bdcda2 100644 --- a/relay/adaptor/replicate/model.go +++ b/relay/adaptor/replicate/model.go @@ -1,11 +1,129 @@ package replicate import ( + "bytes" + "encoding/base64" + "image" + "image/png" + "io" + "mime/multipart" "time" "github.com/pkg/errors" ) +type OpenaiImageEditRequest struct { + Image *multipart.FileHeader `json:"image" form:"image" binding:"required"` + Prompt string `json:"prompt" form:"prompt" binding:"required"` + Mask *multipart.FileHeader `json:"mask" form:"mask" binding:"required"` + Model string `json:"model" form:"model" binding:"required"` + N int `json:"n" form:"n" binding:"min=0,max=10"` + Size string `json:"size" form:"size"` + ResponseFormat string `json:"response_format" form:"response_format"` +} + +// toFluxRemixRequest convert OpenAI's image edit request to Flux's remix request. +// +// Note that the mask formats of OpenAI and Flux are different: +// OpenAI's mask sets the parts to be modified as transparent (0, 0, 0, 0), +// while Flux sets the parts to be modified as black (255, 255, 255, 255), +// so we need to convert the format here. +// +// Both OpenAI's Image and Mask are browser-native ImageData, +// which need to be converted to base64 dataURI format. +func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) { + if r.ResponseFormat != "b64_json" { + return nil, errors.New("response_format must be b64_json for replicate models") + } + + fluxReq := &InpaintingImageByFlusReplicateRequest{ + Input: FluxInpaintingInput{ + Prompt: r.Prompt, + Seed: int(time.Now().UnixNano()), + Steps: 30, + Guidance: 3, + SafetyTolerance: 5, + PromptUnsampling: false, + OutputFormat: "png", + }, + } + + imgFile, err := r.Image.Open() + if err != nil { + return nil, errors.Wrap(err, "open image file") + } + defer imgFile.Close() + imgData, err := io.ReadAll(imgFile) + if err != nil { + return nil, errors.Wrap(err, "read image file") + } + + maskFile, err := r.Mask.Open() + if err != nil { + return nil, errors.Wrap(err, "open mask file") + } + defer maskFile.Close() + + // Convert image to base64 + imageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgData) + fluxReq.Input.Image = imageBase64 + + // Convert mask data to RGBA + maskPNG, err := png.Decode(maskFile) + if err != nil { + return nil, errors.Wrap(err, "decode mask file") + } + + // convert mask to RGBA + var maskRGBA *image.RGBA + switch converted := maskPNG.(type) { + case *image.RGBA: + maskRGBA = converted + default: + // Convert to RGBA + bounds := maskPNG.Bounds() + maskRGBA = image.NewRGBA(bounds) + for y := bounds.Min.Y; y < bounds.Max.Y; y++ { + for x := bounds.Min.X; x < bounds.Max.X; x++ { + maskRGBA.Set(x, y, maskPNG.At(x, y)) + } + } + } + + maskData := maskRGBA.Pix + invertedMask := make([]byte, len(maskData)) + for i := 0; i+4 <= len(maskData); i += 4 { + // If pixel is transparent (alpha = 0), make it black (255) + if maskData[i+3] == 0 { + invertedMask[i] = 255 // R + invertedMask[i+1] = 255 // G + invertedMask[i+2] = 255 // B + invertedMask[i+3] = 255 // A + } else { + // Copy original pixel + copy(invertedMask[i:i+4], maskData[i:i+4]) + } + } + + // Convert inverted mask to base64 encoded png image + invertedMaskRGBA := &image.RGBA{ + Pix: invertedMask, + Stride: maskRGBA.Stride, + Rect: maskRGBA.Rect, + } + + var buf bytes.Buffer + err = png.Encode(&buf, invertedMaskRGBA) + if err != nil { + return nil, errors.Wrap(err, "encode inverted mask to png") + } + + invertedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()) + fluxReq.Input.Mask = invertedMaskBase64 + + return fluxReq, nil +} + // DrawImageRequest draw image by fluxpro // // https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json diff --git a/relay/adaptor/replicate/model_test.go b/relay/adaptor/replicate/model_test.go new file mode 100644 index 00000000..6317d5c8 --- /dev/null +++ b/relay/adaptor/replicate/model_test.go @@ -0,0 +1,108 @@ +package replicate + +import ( + "bytes" + "encoding/base64" + "io" + "mime/multipart" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToFluxRemixRequest(t *testing.T) { + // Prepare input data + imageData := []byte{0x89, 0x50, 0x4E, 0x47} // Simulates PNG magic bytes + maskData := []byte{ + 0, 0, 0, 0, // Transparent pixel + 255, 255, 255, 255, // Opaque white pixel + } + prompt := "Test prompt" + model := "Test model" + responseType := "json" + + // convert image and mask to FileHeader + imageFileHeader, err := createFileHeader("image", "image.png", imageData) + require.NoError(t, err) + maskFileHeader, err := createFileHeader("mask", "mask.png", maskData) + require.NoError(t, err) + + request := OpenaiImageEditRequest{ + Image: imageFileHeader, + Mask: maskFileHeader, + Prompt: prompt, + Model: model, + ResponseFormat: responseType, + } + + // Call the method under test + fluxRequest, err := request.toFluxRemixRequest() + require.NoError(t, err) + + // Verify FluxInpaintingInput fields + require.NotNil(t, fluxRequest) + require.Equal(t, prompt, fluxRequest.Input.Prompt) + require.Equal(t, 30, fluxRequest.Input.Steps) + require.Equal(t, 3, fluxRequest.Input.Guidance) + require.Equal(t, 5, fluxRequest.Input.SafetyTolerance) + require.False(t, fluxRequest.Input.PromptUnsampling) + + // Check image field (Base64 encoded) + expectedImageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) + require.Equal(t, expectedImageBase64, fluxRequest.Input.Image) + + // Check mask field (Base64 encoded and inverted transparency) + expectedInvertedMask := []byte{ + 255, 255, 255, 255, // Transparent pixel inverted to black + 255, 255, 255, 255, // Opaque white pixel remains the same + } + expectedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(expectedInvertedMask) + require.Equal(t, expectedMaskBase64, fluxRequest.Input.Mask) + + // Verify seed + // Since the seed is generated based on the current time, we validate its presence + require.NotZero(t, fluxRequest.Input.Seed) + require.True(t, fluxRequest.Input.Seed > 0) + + // Additional assertions can be added as necessary +} + +// createFileHeader creates a multipart.FileHeader from file bytes +func createFileHeader(fieldname, filename string, fileBytes []byte) (*multipart.FileHeader, error) { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // Create a form file field + part, err := writer.CreateFormFile(fieldname, filename) + if err != nil { + return nil, err + } + + // Write the file bytes to the form file field + _, err = part.Write(fileBytes) + if err != nil { + return nil, err + } + + // Close the writer to finalize the form + err = writer.Close() + if err != nil { + return nil, err + } + + // Parse the multipart form + req := &http.Request{ + Header: http.Header{}, + Body: io.NopCloser(body), + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + err = req.ParseMultipartForm(int64(body.Len())) + if err != nil { + return nil, err + } + + // Retrieve the file header from the parsed form + fileHeader := req.MultipartForm.File[fieldname][0] + return fileHeader, nil +} diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go index ceff1ed2..f86baee0 100644 --- a/relay/adaptor/vertexai/gemini/adapter.go +++ b/relay/adaptor/vertexai/gemini/adapter.go @@ -15,7 +15,7 @@ import ( ) var ModelList = []string{ - "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002", + "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002", } type Adaptor struct { diff --git a/relay/controller/image.go b/relay/controller/image.go index 741b9df4..6154c74c 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -4,19 +4,20 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "net/http" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/adaptor/replicate" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" @@ -27,7 +28,7 @@ func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, e imageRequest := &relaymodel.ImageRequest{} err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { - return nil, err + return nil, errors.WithStack(err) } if imageRequest.N == 0 { imageRequest.N = 1 @@ -156,7 +157,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus switch meta.ChannelType { case channeltype.Zhipu, channeltype.Ali, - channeltype.Replicate, channeltype.Baidu: finalRequest, err := adaptor.ConvertImageRequest(imageRequest) if err != nil { @@ -167,6 +167,16 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case channeltype.Replicate: + finalRequest, err := replicate.ConvertImageRequest(c, imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) + } + jsonStr, err := json.Marshal(finalRequest) + if err != nil { + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) } modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType) diff --git a/relay/model/image.go b/relay/model/image.go index 00bd8b79..ec3e7691 100644 --- a/relay/model/image.go +++ b/relay/model/image.go @@ -2,7 +2,7 @@ package model type ImageRequest struct { Model string `json:"model" form:"model"` - Prompt string `json:"prompt" binding:"required" form:"prompt"` + Prompt string `json:"prompt" form:"prompt" binding:"required"` N int `json:"n,omitempty" form:"n"` Size string `json:"size,omitempty" form:"size"` Quality string `json:"quality,omitempty" form:"quality"`