mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-18 01:26:37 +08:00
Merge branch 'feature/replicate-remix'
This commit is contained in:
commit
79bd053a0a
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -11,18 +12,18 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
func GetRequestBody(c *gin.Context) (requestBody []byte, err error) {
|
||||||
requestBody, _ := c.Get(ctxkey.KeyRequestBody)
|
if requestBodyCache, _ := c.Get(ctxkey.KeyRequestBody); requestBodyCache != nil {
|
||||||
if requestBody != nil {
|
return requestBodyCache.([]byte), nil
|
||||||
return requestBody.([]byte), nil
|
|
||||||
}
|
}
|
||||||
requestBody, err := io.ReadAll(c.Request.Body)
|
requestBody, err = io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "read request body failed")
|
return nil, errors.Wrap(err, "read request body failed")
|
||||||
}
|
}
|
||||||
_ = c.Request.Body.Close()
|
_ = c.Request.Body.Close()
|
||||||
c.Set(ctxkey.KeyRequestBody, requestBody)
|
c.Set(ctxkey.KeyRequestBody, requestBody)
|
||||||
return requestBody.([]byte), nil
|
|
||||||
|
return requestBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||||
@ -30,19 +31,26 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "get request body failed")
|
return errors.Wrap(err, "get request body failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check v should be a pointer
|
||||||
|
if v == nil || reflect.TypeOf(v).Kind() != reflect.Ptr {
|
||||||
|
return errors.Errorf("UnmarshalBodyReusable only accept pointer, got %v", reflect.TypeOf(v))
|
||||||
|
}
|
||||||
|
|
||||||
contentType := c.Request.Header.Get("Content-Type")
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
if strings.HasPrefix(contentType, "application/json") {
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
err = json.Unmarshal(requestBody, &v)
|
err = json.Unmarshal(requestBody, v)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
} else {
|
} else {
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
err = c.ShouldBind(&v)
|
err = c.ShouldBind(v)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.Wrap(err, "unmarshal request body failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset request body
|
// Reset request body
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
2
go.mod
2
go.mod
@ -31,6 +31,7 @@ require (
|
|||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
golang.org/x/crypto v0.24.0
|
golang.org/x/crypto v0.24.0
|
||||||
golang.org/x/image v0.18.0
|
golang.org/x/image v0.18.0
|
||||||
|
golang.org/x/sync v0.7.0
|
||||||
google.golang.org/api v0.187.0
|
google.golang.org/api v0.187.0
|
||||||
gorm.io/driver/mysql v1.5.6
|
gorm.io/driver/mysql v1.5.6
|
||||||
gorm.io/driver/postgres v1.5.7
|
gorm.io/driver/postgres v1.5.7
|
||||||
@ -113,7 +114,6 @@ require (
|
|||||||
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect
|
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect
|
||||||
golang.org/x/net v0.26.0 // indirect
|
golang.org/x/net v0.26.0 // indirect
|
||||||
golang.org/x/oauth2 v0.21.0 // indirect
|
golang.org/x/oauth2 v0.21.0 // indirect
|
||||||
golang.org/x/sync v0.7.0 // indirect
|
|
||||||
golang.org/x/sys v0.21.0 // indirect
|
golang.org/x/sys v0.21.0 // indirect
|
||||||
golang.org/x/term v0.21.0 // indirect
|
golang.org/x/term v0.21.0 // indirect
|
||||||
golang.org/x/text v0.16.0 // indirect
|
golang.org/x/text v0.16.0 // indirect
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
@ -25,28 +25,30 @@ func getRequestModel(c *gin.Context) (string, error) {
|
|||||||
var modelRequest ModelRequest
|
var modelRequest ModelRequest
|
||||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
|
return "", errors.Wrap(err, "common.UnmarshalBodyReusable failed")
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(c.Request.URL.Path, "/v1/moderations"):
|
||||||
if modelRequest.Model == "" {
|
if modelRequest.Model == "" {
|
||||||
modelRequest.Model = "text-moderation-stable"
|
modelRequest.Model = "text-moderation-stable"
|
||||||
}
|
}
|
||||||
}
|
case strings.HasSuffix(c.Request.URL.Path, "embeddings"):
|
||||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
|
||||||
if modelRequest.Model == "" {
|
if modelRequest.Model == "" {
|
||||||
modelRequest.Model = c.Param("model")
|
modelRequest.Model = c.Param("model")
|
||||||
}
|
}
|
||||||
}
|
case strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations"),
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits"):
|
||||||
if modelRequest.Model == "" {
|
if modelRequest.Model == "" {
|
||||||
modelRequest.Model = "dall-e-2"
|
modelRequest.Model = "dall-e-2"
|
||||||
}
|
}
|
||||||
}
|
case strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions"),
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations"):
|
||||||
if modelRequest.Model == "" {
|
if modelRequest.Model == "" {
|
||||||
modelRequest.Model = "whisper-1"
|
modelRequest.Model = "whisper-1"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return modelRequest.Model, nil
|
return modelRequest.Model, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package replicate
|
package replicate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -9,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
@ -22,7 +24,31 @@ type Adaptor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ConvertImageRequest implements adaptor.Adaptor.
|
// ConvertImageRequest implements adaptor.Adaptor.
|
||||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||||
|
return nil, errors.New("should call replicate.ConvertImageRequest instead")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
|
meta := meta.GetByContext(c)
|
||||||
|
|
||||||
|
if request.ResponseFormat != "b64_json" {
|
||||||
|
return nil, errors.New("only support b64_json response format")
|
||||||
|
}
|
||||||
|
if request.N != 1 && request.N != 0 {
|
||||||
|
return nil, errors.New("only support N=1")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch meta.Mode {
|
||||||
|
case relaymode.ImagesGenerations:
|
||||||
|
return convertImageCreateRequest(request)
|
||||||
|
case relaymode.ImagesEdits:
|
||||||
|
return convertImageRemixRequest(c)
|
||||||
|
default:
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
|
||||||
return DrawImageRequest{
|
return DrawImageRequest{
|
||||||
Input: ImageInput{
|
Input: ImageInput{
|
||||||
Steps: 25,
|
Steps: 25,
|
||||||
@ -38,6 +64,22 @@ func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertImageRemixRequest(c *gin.Context) (any, error) {
|
||||||
|
// recover request body
|
||||||
|
requestBody, err := common.GetRequestBody(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "get request body")
|
||||||
|
}
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|
||||||
|
rawReq := new(OpenaiImageEditRequest)
|
||||||
|
if err := c.ShouldBind(rawReq); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "parse image edit form")
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawReq.toFluxRemixRequest()
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
@ -67,7 +109,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||||
switch meta.Mode {
|
switch meta.Mode {
|
||||||
case relaymode.ImagesGenerations:
|
case relaymode.ImagesGenerations,
|
||||||
|
relaymode.ImagesEdits:
|
||||||
err, usage = ImageHandler(c, resp)
|
err, usage = ImageHandler(c, resp)
|
||||||
default:
|
default:
|
||||||
err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
|
err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
|
||||||
|
@ -22,9 +22,9 @@ import (
|
|||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ImagesEditsHandler just copy response body to client
|
// // ImagesEditsHandler just copy response body to client
|
||||||
//
|
// //
|
||||||
// https://replicate.com/black-forest-labs/flux-fill-pro
|
// // https://replicate.com/black-forest-labs/flux-fill-pro
|
||||||
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
// c.Writer.WriteHeader(resp.StatusCode)
|
// c.Writer.WriteHeader(resp.StatusCode)
|
||||||
// for k, v := range resp.Header {
|
// for k, v := range resp.Header {
|
||||||
@ -32,7 +32,7 @@ import (
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
// if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
// if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||||
// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
// return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||||
// }
|
// }
|
||||||
// defer resp.Body.Close()
|
// defer resp.Body.Close()
|
||||||
|
|
||||||
@ -41,7 +41,8 @@ import (
|
|||||||
|
|
||||||
var errNextLoop = errors.New("next_loop")
|
var errNextLoop = errors.New("next_loop")
|
||||||
|
|
||||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
func ImageHandler(c *gin.Context, resp *http.Response) (
|
||||||
|
*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
if resp.StatusCode != http.StatusCreated {
|
if resp.StatusCode != http.StatusCreated {
|
||||||
payload, _ := io.ReadAll(resp.Body)
|
payload, _ := io.ReadAll(resp.Body)
|
||||||
return openai.ErrorWrapper(
|
return openai.ErrorWrapper(
|
||||||
@ -95,7 +96,7 @@ func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCo
|
|||||||
switch taskData.Status {
|
switch taskData.Status {
|
||||||
case "succeeded":
|
case "succeeded":
|
||||||
case "failed", "canceled":
|
case "failed", "canceled":
|
||||||
return errors.Errorf("task failed: %s", taskData.Status)
|
return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
|
||||||
default:
|
default:
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
return errNextLoop
|
return errNextLoop
|
||||||
|
@ -1,11 +1,129 @@
|
|||||||
package replicate
|
package replicate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"image"
|
||||||
|
"image/png"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OpenaiImageEditRequest struct {
|
||||||
|
Image *multipart.FileHeader `json:"image" form:"image" binding:"required"`
|
||||||
|
Prompt string `json:"prompt" form:"prompt" binding:"required"`
|
||||||
|
Mask *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
|
||||||
|
Model string `json:"model" form:"model" binding:"required"`
|
||||||
|
N int `json:"n" form:"n" binding:"min=0,max=10"`
|
||||||
|
Size string `json:"size" form:"size"`
|
||||||
|
ResponseFormat string `json:"response_format" form:"response_format"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// toFluxRemixRequest convert OpenAI's image edit request to Flux's remix request.
|
||||||
|
//
|
||||||
|
// Note that the mask formats of OpenAI and Flux are different:
|
||||||
|
// OpenAI's mask sets the parts to be modified as transparent (0, 0, 0, 0),
|
||||||
|
// while Flux sets the parts to be modified as black (255, 255, 255, 255),
|
||||||
|
// so we need to convert the format here.
|
||||||
|
//
|
||||||
|
// Both OpenAI's Image and Mask are browser-native ImageData,
|
||||||
|
// which need to be converted to base64 dataURI format.
|
||||||
|
func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) {
|
||||||
|
if r.ResponseFormat != "b64_json" {
|
||||||
|
return nil, errors.New("response_format must be b64_json for replicate models")
|
||||||
|
}
|
||||||
|
|
||||||
|
fluxReq := &InpaintingImageByFlusReplicateRequest{
|
||||||
|
Input: FluxInpaintingInput{
|
||||||
|
Prompt: r.Prompt,
|
||||||
|
Seed: int(time.Now().UnixNano()),
|
||||||
|
Steps: 30,
|
||||||
|
Guidance: 3,
|
||||||
|
SafetyTolerance: 5,
|
||||||
|
PromptUnsampling: false,
|
||||||
|
OutputFormat: "png",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
imgFile, err := r.Image.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "open image file")
|
||||||
|
}
|
||||||
|
defer imgFile.Close()
|
||||||
|
imgData, err := io.ReadAll(imgFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "read image file")
|
||||||
|
}
|
||||||
|
|
||||||
|
maskFile, err := r.Mask.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "open mask file")
|
||||||
|
}
|
||||||
|
defer maskFile.Close()
|
||||||
|
|
||||||
|
// Convert image to base64
|
||||||
|
imageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgData)
|
||||||
|
fluxReq.Input.Image = imageBase64
|
||||||
|
|
||||||
|
// Convert mask data to RGBA
|
||||||
|
maskPNG, err := png.Decode(maskFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "decode mask file")
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert mask to RGBA
|
||||||
|
var maskRGBA *image.RGBA
|
||||||
|
switch converted := maskPNG.(type) {
|
||||||
|
case *image.RGBA:
|
||||||
|
maskRGBA = converted
|
||||||
|
default:
|
||||||
|
// Convert to RGBA
|
||||||
|
bounds := maskPNG.Bounds()
|
||||||
|
maskRGBA = image.NewRGBA(bounds)
|
||||||
|
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||||
|
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||||
|
maskRGBA.Set(x, y, maskPNG.At(x, y))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
maskData := maskRGBA.Pix
|
||||||
|
invertedMask := make([]byte, len(maskData))
|
||||||
|
for i := 0; i+4 <= len(maskData); i += 4 {
|
||||||
|
// If pixel is transparent (alpha = 0), make it black (255)
|
||||||
|
if maskData[i+3] == 0 {
|
||||||
|
invertedMask[i] = 255 // R
|
||||||
|
invertedMask[i+1] = 255 // G
|
||||||
|
invertedMask[i+2] = 255 // B
|
||||||
|
invertedMask[i+3] = 255 // A
|
||||||
|
} else {
|
||||||
|
// Copy original pixel
|
||||||
|
copy(invertedMask[i:i+4], maskData[i:i+4])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert inverted mask to base64 encoded png image
|
||||||
|
invertedMaskRGBA := &image.RGBA{
|
||||||
|
Pix: invertedMask,
|
||||||
|
Stride: maskRGBA.Stride,
|
||||||
|
Rect: maskRGBA.Rect,
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err = png.Encode(&buf, invertedMaskRGBA)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "encode inverted mask to png")
|
||||||
|
}
|
||||||
|
|
||||||
|
invertedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes())
|
||||||
|
fluxReq.Input.Mask = invertedMaskBase64
|
||||||
|
|
||||||
|
return fluxReq, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DrawImageRequest draw image by fluxpro
|
// DrawImageRequest draw image by fluxpro
|
||||||
//
|
//
|
||||||
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
|
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
|
||||||
|
108
relay/adaptor/replicate/model_test.go
Normal file
108
relay/adaptor/replicate/model_test.go
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestToFluxRemixRequest(t *testing.T) {
|
||||||
|
// Prepare input data
|
||||||
|
imageData := []byte{0x89, 0x50, 0x4E, 0x47} // Simulates PNG magic bytes
|
||||||
|
maskData := []byte{
|
||||||
|
0, 0, 0, 0, // Transparent pixel
|
||||||
|
255, 255, 255, 255, // Opaque white pixel
|
||||||
|
}
|
||||||
|
prompt := "Test prompt"
|
||||||
|
model := "Test model"
|
||||||
|
responseType := "json"
|
||||||
|
|
||||||
|
// convert image and mask to FileHeader
|
||||||
|
imageFileHeader, err := createFileHeader("image", "image.png", imageData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
maskFileHeader, err := createFileHeader("mask", "mask.png", maskData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
request := OpenaiImageEditRequest{
|
||||||
|
Image: imageFileHeader,
|
||||||
|
Mask: maskFileHeader,
|
||||||
|
Prompt: prompt,
|
||||||
|
Model: model,
|
||||||
|
ResponseFormat: responseType,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the method under test
|
||||||
|
fluxRequest, err := request.toFluxRemixRequest()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify FluxInpaintingInput fields
|
||||||
|
require.NotNil(t, fluxRequest)
|
||||||
|
require.Equal(t, prompt, fluxRequest.Input.Prompt)
|
||||||
|
require.Equal(t, 30, fluxRequest.Input.Steps)
|
||||||
|
require.Equal(t, 3, fluxRequest.Input.Guidance)
|
||||||
|
require.Equal(t, 5, fluxRequest.Input.SafetyTolerance)
|
||||||
|
require.False(t, fluxRequest.Input.PromptUnsampling)
|
||||||
|
|
||||||
|
// Check image field (Base64 encoded)
|
||||||
|
expectedImageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
|
||||||
|
require.Equal(t, expectedImageBase64, fluxRequest.Input.Image)
|
||||||
|
|
||||||
|
// Check mask field (Base64 encoded and inverted transparency)
|
||||||
|
expectedInvertedMask := []byte{
|
||||||
|
255, 255, 255, 255, // Transparent pixel inverted to black
|
||||||
|
255, 255, 255, 255, // Opaque white pixel remains the same
|
||||||
|
}
|
||||||
|
expectedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(expectedInvertedMask)
|
||||||
|
require.Equal(t, expectedMaskBase64, fluxRequest.Input.Mask)
|
||||||
|
|
||||||
|
// Verify seed
|
||||||
|
// Since the seed is generated based on the current time, we validate its presence
|
||||||
|
require.NotZero(t, fluxRequest.Input.Seed)
|
||||||
|
require.True(t, fluxRequest.Input.Seed > 0)
|
||||||
|
|
||||||
|
// Additional assertions can be added as necessary
|
||||||
|
}
|
||||||
|
|
||||||
|
// createFileHeader creates a multipart.FileHeader from file bytes
|
||||||
|
func createFileHeader(fieldname, filename string, fileBytes []byte) (*multipart.FileHeader, error) {
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
|
||||||
|
// Create a form file field
|
||||||
|
part, err := writer.CreateFormFile(fieldname, filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the file bytes to the form file field
|
||||||
|
_, err = part.Write(fileBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the writer to finalize the form
|
||||||
|
err = writer.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the multipart form
|
||||||
|
req := &http.Request{
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: io.NopCloser(body),
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
err = req.ParseMultipartForm(int64(body.Len()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve the file header from the parsed form
|
||||||
|
fileHeader := req.MultipartForm.File[fieldname][0]
|
||||||
|
return fileHeader, nil
|
||||||
|
}
|
@ -5,6 +5,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Laisky/errors/v2"
|
"github.com/Laisky/errors/v2"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
@ -13,20 +17,18 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/model"
|
"github.com/songquanpeng/one-api/model"
|
||||||
"github.com/songquanpeng/one-api/relay"
|
"github.com/songquanpeng/one-api/relay"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/replicate"
|
||||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||||
"github.com/songquanpeng/one-api/relay/meta"
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
|
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
|
||||||
imageRequest := &relaymodel.ImageRequest{}
|
imageRequest := &relaymodel.ImageRequest{}
|
||||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
if imageRequest.N == 0 {
|
if imageRequest.N == 0 {
|
||||||
imageRequest.N = 1
|
imageRequest.N = 1
|
||||||
@ -155,7 +157,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
switch meta.ChannelType {
|
switch meta.ChannelType {
|
||||||
case channeltype.Zhipu,
|
case channeltype.Zhipu,
|
||||||
channeltype.Ali,
|
channeltype.Ali,
|
||||||
channeltype.Replicate,
|
|
||||||
channeltype.Baidu:
|
channeltype.Baidu:
|
||||||
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -166,6 +167,16 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
case channeltype.Replicate:
|
||||||
|
finalRequest, err := replicate.ConvertImageRequest(c, imageRequest)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
jsonStr, err := json.Marshal(finalRequest)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
|
modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
|
||||||
|
@ -2,7 +2,7 @@ package model
|
|||||||
|
|
||||||
type ImageRequest struct {
|
type ImageRequest struct {
|
||||||
Model string `json:"model" form:"model"`
|
Model string `json:"model" form:"model"`
|
||||||
Prompt string `json:"prompt" binding:"required" form:"prompt"`
|
Prompt string `json:"prompt" form:"prompt" binding:"required"`
|
||||||
N int `json:"n,omitempty" form:"n"`
|
N int `json:"n,omitempty" form:"n"`
|
||||||
Size string `json:"size,omitempty" form:"size"`
|
Size string `json:"size,omitempty" form:"size"`
|
||||||
Quality string `json:"quality,omitempty" form:"quality"`
|
Quality string `json:"quality,omitempty" form:"quality"`
|
||||||
|
Loading…
Reference in New Issue
Block a user