mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-18 01:26:37 +08:00
Merge remote-tracking branch 'origin/patch/replicate-image' into patch/replicate-flux-inpainting
This commit is contained in:
commit
f6b4ca3936
1
.gitignore
vendored
1
.gitignore
vendored
@ -10,3 +10,4 @@ data
|
|||||||
/web/node_modules
|
/web/node_modules
|
||||||
cmd.md
|
cmd.md
|
||||||
.env
|
.env
|
||||||
|
/one-api
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/palm"
|
"github.com/songquanpeng/one-api/relay/adaptor/palm"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/proxy"
|
"github.com/songquanpeng/one-api/relay/adaptor/proxy"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/replicate"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
|
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
|
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
|
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
|
||||||
@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
|
|||||||
return &vertexai.Adaptor{}
|
return &vertexai.Adaptor{}
|
||||||
case apitype.Proxy:
|
case apitype.Proxy:
|
||||||
return &proxy.Adaptor{}
|
return &proxy.Adaptor{}
|
||||||
|
case apitype.Replicate:
|
||||||
|
return &replicate.Adaptor{}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,16 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import "github.com/songquanpeng/one-api/relay/model"
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
)
|
||||||
|
|
||||||
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
|
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
|
||||||
|
logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
|
||||||
|
|
||||||
Error := model.Error{
|
Error := model.Error{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
Type: "one_api_error",
|
Type: "one_api_error",
|
||||||
|
85
relay/adaptor/replicate/adaptor.go
Normal file
85
relay/adaptor/replicate/adaptor.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Adaptor struct {
|
||||||
|
meta *meta.Meta
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertImageRequest implements adaptor.Adaptor.
|
||||||
|
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||||
|
return DrawImageRequest{
|
||||||
|
Input: ImageInput{
|
||||||
|
Steps: 25,
|
||||||
|
Prompt: request.Prompt,
|
||||||
|
Guidance: 3,
|
||||||
|
Seed: int(time.Now().UnixNano()),
|
||||||
|
SafetyTolerance: 5,
|
||||||
|
NImages: 1, // replicate will always return 1 image
|
||||||
|
Width: 1440,
|
||||||
|
Height: 1440,
|
||||||
|
AspectRatio: "1:1",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||||
|
a.meta = meta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||||
|
if !slices.Contains(ModelList, meta.OriginModelName) {
|
||||||
|
return "", errors.Errorf("model %s not supported", meta.OriginModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||||
|
adaptor.SetupCommonRequestHeader(c, req, meta)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
logger.Info(c, "send image request to replicate")
|
||||||
|
return adaptor.DoRequestHelper(a, c, meta, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||||
|
switch meta.Mode {
|
||||||
|
case relaymode.ImagesGenerations:
|
||||||
|
err, usage = ImageHandler(c, resp)
|
||||||
|
default:
|
||||||
|
err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetModelList() []string {
|
||||||
|
return ModelList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetChannelName() string {
|
||||||
|
return "replicate"
|
||||||
|
}
|
58
relay/adaptor/replicate/constant.go
Normal file
58
relay/adaptor/replicate/constant.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
// ModelList is a list of models that can be used with Replicate.
|
||||||
|
//
|
||||||
|
// https://replicate.com/pricing
|
||||||
|
var ModelList = []string{
|
||||||
|
// -------------------------------------
|
||||||
|
// image model
|
||||||
|
// -------------------------------------
|
||||||
|
"black-forest-labs/flux-1.1-pro",
|
||||||
|
"black-forest-labs/flux-1.1-pro-ultra",
|
||||||
|
"black-forest-labs/flux-canny-dev",
|
||||||
|
"black-forest-labs/flux-canny-pro",
|
||||||
|
"black-forest-labs/flux-depth-dev",
|
||||||
|
"black-forest-labs/flux-depth-pro",
|
||||||
|
"black-forest-labs/flux-dev",
|
||||||
|
"black-forest-labs/flux-dev-lora",
|
||||||
|
"black-forest-labs/flux-fill-dev",
|
||||||
|
"black-forest-labs/flux-fill-pro",
|
||||||
|
"black-forest-labs/flux-pro",
|
||||||
|
"black-forest-labs/flux-redux-dev",
|
||||||
|
"black-forest-labs/flux-redux-schnell",
|
||||||
|
"black-forest-labs/flux-schnell",
|
||||||
|
"black-forest-labs/flux-schnell-lora",
|
||||||
|
"ideogram-ai/ideogram-v2",
|
||||||
|
"ideogram-ai/ideogram-v2-turbo",
|
||||||
|
"recraft-ai/recraft-v3",
|
||||||
|
"recraft-ai/recraft-v3-svg",
|
||||||
|
"stability-ai/stable-diffusion-3",
|
||||||
|
"stability-ai/stable-diffusion-3.5-large",
|
||||||
|
"stability-ai/stable-diffusion-3.5-large-turbo",
|
||||||
|
"stability-ai/stable-diffusion-3.5-medium",
|
||||||
|
// -------------------------------------
|
||||||
|
// language model
|
||||||
|
// -------------------------------------
|
||||||
|
// "ibm-granite/granite-20b-code-instruct-8k", // TODO: implement the adaptor
|
||||||
|
// "ibm-granite/granite-3.0-2b-instruct", // TODO: implement the adaptor
|
||||||
|
// "ibm-granite/granite-3.0-8b-instruct", // TODO: implement the adaptor
|
||||||
|
// "ibm-granite/granite-8b-code-instruct-128k", // TODO: implement the adaptor
|
||||||
|
// "meta/llama-2-13b", // TODO: implement the adaptor
|
||||||
|
// "meta/llama-2-13b-chat", // TODO: implement the adaptor
|
||||||
|
// "meta/llama-2-70b", // TODO: implement the adaptor
|
||||||
|
// "meta/llama-2-70b-chat", // TODO: implement the adaptor
|
||||||
|
// "meta/llama-2-7b", // TODO: implement the adaptor
|
||||||
|
// "meta/llama-2-7b-chat", // TODO: implement the adaptor
|
||||||
|
// "meta/meta-llama-3.1-405b-instruct", // TODO: implement the adaptor
|
||||||
|
// "meta/meta-llama-3-70b", // TODO: implement the adaptor
|
||||||
|
// "meta/meta-llama-3-70b-instruct", // TODO: implement the adaptor
|
||||||
|
// "meta/meta-llama-3-8b", // TODO: implement the adaptor
|
||||||
|
// "meta/meta-llama-3-8b-instruct", // TODO: implement the adaptor
|
||||||
|
// "mistralai/mistral-7b-instruct-v0.2", // TODO: implement the adaptor
|
||||||
|
// "mistralai/mistral-7b-v0.1", // TODO: implement the adaptor
|
||||||
|
// "mistralai/mixtral-8x7b-instruct-v0.1", // TODO: implement the adaptor
|
||||||
|
// -------------------------------------
|
||||||
|
// video model
|
||||||
|
// -------------------------------------
|
||||||
|
// "minimax/video-01", // TODO: implement the adaptor
|
||||||
|
}
|
222
relay/adaptor/replicate/image.go
Normal file
222
relay/adaptor/replicate/image.go
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"image/png"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
"golang.org/x/image/webp"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ImagesEditsHandler just copy response body to client
|
||||||
|
//
|
||||||
|
// https://replicate.com/black-forest-labs/flux-fill-pro
|
||||||
|
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
|
// c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
// for k, v := range resp.Header {
|
||||||
|
// c.Writer.Header().Set(k, v[0])
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||||
|
// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
// }
|
||||||
|
// defer resp.Body.Close()
|
||||||
|
|
||||||
|
// return nil, nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
var errNextLoop = errors.New("next_loop")
|
||||||
|
|
||||||
|
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
payload, _ := io.ReadAll(resp.Body)
|
||||||
|
return openai.ErrorWrapper(
|
||||||
|
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
|
||||||
|
"bad_status_code", http.StatusInternalServerError),
|
||||||
|
nil
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
respData := new(ImageResponse)
|
||||||
|
if err = json.Unmarshal(respBody, respData); err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
err = func() error {
|
||||||
|
// get task
|
||||||
|
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
|
||||||
|
http.MethodGet, respData.URLs.Get, nil)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "new request")
|
||||||
|
}
|
||||||
|
|
||||||
|
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
|
||||||
|
taskResp, err := http.DefaultClient.Do(taskReq)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "get task")
|
||||||
|
}
|
||||||
|
defer taskResp.Body.Close()
|
||||||
|
|
||||||
|
if taskResp.StatusCode != http.StatusOK {
|
||||||
|
payload, _ := io.ReadAll(taskResp.Body)
|
||||||
|
return errors.Errorf("bad status code [%d]%s",
|
||||||
|
taskResp.StatusCode, string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
taskBody, err := io.ReadAll(taskResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "read task response")
|
||||||
|
}
|
||||||
|
|
||||||
|
taskData := new(ImageResponse)
|
||||||
|
if err = json.Unmarshal(taskBody, taskData); err != nil {
|
||||||
|
return errors.Wrap(err, "decode task response")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch taskData.Status {
|
||||||
|
case "succeeded":
|
||||||
|
case "failed", "canceled":
|
||||||
|
return errors.Errorf("task failed: %s", taskData.Status)
|
||||||
|
default:
|
||||||
|
time.Sleep(time.Second * 3)
|
||||||
|
return errNextLoop
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := taskData.GetOutput()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "get output")
|
||||||
|
}
|
||||||
|
if len(output) == 0 {
|
||||||
|
return errors.New("response output is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var pool errgroup.Group
|
||||||
|
respBody := &openai.ImageResponse{
|
||||||
|
Created: taskData.CompletedAt.Unix(),
|
||||||
|
Data: []openai.ImageData{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, imgOut := range output {
|
||||||
|
imgOut := imgOut
|
||||||
|
pool.Go(func() error {
|
||||||
|
// download image
|
||||||
|
downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
|
||||||
|
http.MethodGet, imgOut, nil)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "new request")
|
||||||
|
}
|
||||||
|
|
||||||
|
imgResp, err := http.DefaultClient.Do(downloadReq)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "download image")
|
||||||
|
}
|
||||||
|
defer imgResp.Body.Close()
|
||||||
|
|
||||||
|
if imgResp.StatusCode != http.StatusOK {
|
||||||
|
payload, _ := io.ReadAll(imgResp.Body)
|
||||||
|
return errors.Errorf("bad status code [%d]%s",
|
||||||
|
imgResp.StatusCode, string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
imgData, err := io.ReadAll(imgResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "read image")
|
||||||
|
}
|
||||||
|
|
||||||
|
imgData, err = ConvertImageToPNG(imgData)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "convert image")
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
respBody.Data = append(respBody.Data, openai.ImageData{
|
||||||
|
B64Json: fmt.Sprintf("data:image/png;base64,%s",
|
||||||
|
base64.StdEncoding.EncodeToString(imgData)),
|
||||||
|
})
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pool.Wait(); err != nil {
|
||||||
|
if len(respBody.Data) == 0 {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, respBody)
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errNextLoop) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertImageToPNG converts a WebP image to PNG format
|
||||||
|
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
|
||||||
|
// bypass if it's already a PNG image
|
||||||
|
if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
|
||||||
|
return webpData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if is jpeg, convert to png
|
||||||
|
if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
|
||||||
|
img, _, err := image.Decode(bytes.NewReader(webpData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "decode jpeg")
|
||||||
|
}
|
||||||
|
|
||||||
|
var pngBuffer bytes.Buffer
|
||||||
|
if err := png.Encode(&pngBuffer, img); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "encode png")
|
||||||
|
}
|
||||||
|
|
||||||
|
return pngBuffer.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the WebP image
|
||||||
|
img, err := webp.Decode(bytes.NewReader(webpData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "decode webp")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode the image as PNG
|
||||||
|
var pngBuffer bytes.Buffer
|
||||||
|
if err := png.Encode(&pngBuffer, img); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "encode png")
|
||||||
|
}
|
||||||
|
|
||||||
|
return pngBuffer.Bytes(), nil
|
||||||
|
}
|
111
relay/adaptor/replicate/model.go
Normal file
111
relay/adaptor/replicate/model.go
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DrawImageRequest draw image by fluxpro
|
||||||
|
//
|
||||||
|
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
|
||||||
|
type DrawImageRequest struct {
|
||||||
|
Input ImageInput `json:"input"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageInput is input of DrawImageByFluxProRequest
|
||||||
|
//
|
||||||
|
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
|
||||||
|
type ImageInput struct {
|
||||||
|
Steps int `json:"steps" binding:"required,min=1"`
|
||||||
|
Prompt string `json:"prompt" binding:"required,min=5"`
|
||||||
|
ImagePrompt string `json:"image_prompt"`
|
||||||
|
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
|
||||||
|
Interval int `json:"interval" binding:"required,min=1,max=4"`
|
||||||
|
AspectRatio string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
|
||||||
|
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
|
||||||
|
Seed int `json:"seed"`
|
||||||
|
NImages int `json:"n_images" binding:"required,min=1,max=8"`
|
||||||
|
Width int `json:"width" binding:"required,min=256,max=1440"`
|
||||||
|
Height int `json:"height" binding:"required,min=256,max=1440"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
|
||||||
|
//
|
||||||
|
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
|
||||||
|
type InpaintingImageByFlusReplicateRequest struct {
|
||||||
|
Input FluxInpaintingInput `json:"input"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FluxInpaintingInput is input of DrawImageByFluxProRequest
|
||||||
|
//
|
||||||
|
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
|
||||||
|
type FluxInpaintingInput struct {
|
||||||
|
Mask string `json:"mask" binding:"required"`
|
||||||
|
Image string `json:"image" binding:"required"`
|
||||||
|
Seed int `json:"seed"`
|
||||||
|
Steps int `json:"steps" binding:"required,min=1"`
|
||||||
|
Prompt string `json:"prompt" binding:"required,min=5"`
|
||||||
|
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
|
||||||
|
OutputFormat string `json:"output_format"`
|
||||||
|
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
|
||||||
|
PromptUnsampling bool `json:"prompt_unsampling"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageResponse is response of DrawImageByFluxProRequest
|
||||||
|
//
|
||||||
|
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
|
||||||
|
type ImageResponse struct {
|
||||||
|
CompletedAt time.Time `json:"completed_at"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
DataRemoved bool `json:"data_removed"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Input DrawImageRequest `json:"input"`
|
||||||
|
Logs string `json:"logs"`
|
||||||
|
Metrics FluxMetrics `json:"metrics"`
|
||||||
|
// Output could be `string` or `[]string`
|
||||||
|
Output any `json:"output"`
|
||||||
|
StartedAt time.Time `json:"started_at"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
URLs FluxURLs `json:"urls"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ImageResponse) GetOutput() ([]string, error) {
|
||||||
|
switch v := r.Output.(type) {
|
||||||
|
case string:
|
||||||
|
return []string{v}, nil
|
||||||
|
case []string:
|
||||||
|
return v, nil
|
||||||
|
case nil:
|
||||||
|
return nil, nil
|
||||||
|
case []interface{}:
|
||||||
|
// convert []interface{} to []string
|
||||||
|
ret := make([]string, len(v))
|
||||||
|
for idx, vv := range v {
|
||||||
|
if vvv, ok := vv.(string); ok {
|
||||||
|
ret[idx] = vvv
|
||||||
|
} else {
|
||||||
|
return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
default:
|
||||||
|
return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FluxMetrics is metrics of ImageResponse
|
||||||
|
type FluxMetrics struct {
|
||||||
|
ImageCount int `json:"image_count"`
|
||||||
|
PredictTime float64 `json:"predict_time"`
|
||||||
|
TotalTime float64 `json:"total_time"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FluxURLs is urls of ImageResponse
|
||||||
|
type FluxURLs struct {
|
||||||
|
Get string `json:"get"`
|
||||||
|
Cancel string `json:"cancel"`
|
||||||
|
}
|
@ -19,6 +19,7 @@ const (
|
|||||||
DeepL
|
DeepL
|
||||||
VertexAI
|
VertexAI
|
||||||
Proxy
|
Proxy
|
||||||
|
Replicate
|
||||||
|
|
||||||
Dummy // this one is only for count, do not add any channel after this
|
Dummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
|
@ -211,6 +211,31 @@ var ModelRatio = map[string]float64{
|
|||||||
"deepl-ja": 25.0 / 1000 * USD,
|
"deepl-ja": 25.0 / 1000 * USD,
|
||||||
// https://console.x.ai/
|
// https://console.x.ai/
|
||||||
"grok-beta": 5.0 / 1000 * USD,
|
"grok-beta": 5.0 / 1000 * USD,
|
||||||
|
// replicate charges based on the number of generated images
|
||||||
|
// https://replicate.com/pricing
|
||||||
|
"black-forest-labs/flux-1.1-pro": 0.04 * USD,
|
||||||
|
"black-forest-labs/flux-1.1-pro-ultra": 0.06 * USD,
|
||||||
|
"black-forest-labs/flux-canny-dev": 0.025 * USD,
|
||||||
|
"black-forest-labs/flux-canny-pro": 0.05 * USD,
|
||||||
|
"black-forest-labs/flux-depth-dev": 0.025 * USD,
|
||||||
|
"black-forest-labs/flux-depth-pro": 0.05 * USD,
|
||||||
|
"black-forest-labs/flux-dev": 0.025 * USD,
|
||||||
|
"black-forest-labs/flux-dev-lora": 0.032 * USD,
|
||||||
|
"black-forest-labs/flux-fill-dev": 0.04 * USD,
|
||||||
|
"black-forest-labs/flux-fill-pro": 0.05 * USD,
|
||||||
|
"black-forest-labs/flux-pro": 0.055 * USD,
|
||||||
|
"black-forest-labs/flux-redux-dev": 0.025 * USD,
|
||||||
|
"black-forest-labs/flux-redux-schnell": 0.003 * USD,
|
||||||
|
"black-forest-labs/flux-schnell": 0.003 * USD,
|
||||||
|
"black-forest-labs/flux-schnell-lora": 0.02 * USD,
|
||||||
|
"ideogram-ai/ideogram-v2": 0.08 * USD,
|
||||||
|
"ideogram-ai/ideogram-v2-turbo": 0.05 * USD,
|
||||||
|
"recraft-ai/recraft-v3": 0.04 * USD,
|
||||||
|
"recraft-ai/recraft-v3-svg": 0.08 * USD,
|
||||||
|
"stability-ai/stable-diffusion-3": 0.035 * USD,
|
||||||
|
"stability-ai/stable-diffusion-3.5-large": 0.065 * USD,
|
||||||
|
"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
|
||||||
|
"stability-ai/stable-diffusion-3.5-medium": 0.035 * USD,
|
||||||
}
|
}
|
||||||
|
|
||||||
var CompletionRatio = map[string]float64{
|
var CompletionRatio = map[string]float64{
|
||||||
|
@ -47,5 +47,6 @@ const (
|
|||||||
Proxy
|
Proxy
|
||||||
SiliconFlow
|
SiliconFlow
|
||||||
XAI
|
XAI
|
||||||
|
Replicate
|
||||||
Dummy
|
Dummy
|
||||||
)
|
)
|
||||||
|
@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
|
|||||||
apiType = apitype.DeepL
|
apiType = apitype.DeepL
|
||||||
case VertextAI:
|
case VertextAI:
|
||||||
apiType = apitype.VertexAI
|
apiType = apitype.VertexAI
|
||||||
|
case Replicate:
|
||||||
|
apiType = apitype.Replicate
|
||||||
case Proxy:
|
case Proxy:
|
||||||
apiType = apitype.Proxy
|
apiType = apitype.Proxy
|
||||||
}
|
}
|
||||||
|
@ -47,6 +47,7 @@ var ChannelBaseURLs = []string{
|
|||||||
"", // 43
|
"", // 43
|
||||||
"https://api.siliconflow.cn", // 44
|
"https://api.siliconflow.cn", // 44
|
||||||
"https://api.x.ai", // 45
|
"https://api.x.ai", // 45
|
||||||
|
"https://api.replicate.com/v1/models/", // 46
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -152,12 +152,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
}
|
}
|
||||||
adaptor.Init(meta)
|
adaptor.Init(meta)
|
||||||
|
|
||||||
|
// these adaptors need to convert the request
|
||||||
switch meta.ChannelType {
|
switch meta.ChannelType {
|
||||||
case channeltype.Ali:
|
case channeltype.Zhipu,
|
||||||
fallthrough
|
channeltype.Ali,
|
||||||
case channeltype.Baidu:
|
channeltype.Replicate,
|
||||||
fallthrough
|
channeltype.Baidu:
|
||||||
case channeltype.Zhipu:
|
|
||||||
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
||||||
@ -174,7 +174,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
ratio := modelRatio * groupRatio
|
ratio := modelRatio * groupRatio
|
||||||
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
||||||
|
|
||||||
quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
var quota int64
|
||||||
|
switch meta.ChannelType {
|
||||||
|
case channeltype.Replicate:
|
||||||
|
// replicate always return 1 image
|
||||||
|
quota = int64(ratio * imageCostRatio * 1000)
|
||||||
|
default:
|
||||||
|
quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
||||||
|
}
|
||||||
|
|
||||||
if userQuota-quota < 0 {
|
if userQuota-quota < 0 {
|
||||||
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
@ -188,7 +195,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
if resp != nil &&
|
||||||
|
resp.StatusCode != http.StatusCreated && // replicate returns 201
|
||||||
|
resp.StatusCode != http.StatusOK {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
|
|||||||
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
||||||
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
||||||
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
||||||
|
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
|
||||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||||
|
@ -185,6 +185,12 @@ export const CHANNEL_OPTIONS = {
|
|||||||
value: 45,
|
value: 45,
|
||||||
color: 'primary'
|
color: 'primary'
|
||||||
},
|
},
|
||||||
|
45: {
|
||||||
|
key: 46,
|
||||||
|
text: 'Replicate',
|
||||||
|
value: 46,
|
||||||
|
color: 'primary'
|
||||||
|
},
|
||||||
41: {
|
41: {
|
||||||
key: 41,
|
key: 41,
|
||||||
text: 'Novita',
|
text: 'Novita',
|
||||||
|
@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
|
|||||||
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
||||||
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
||||||
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
||||||
|
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
|
||||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||||
|
Loading…
Reference in New Issue
Block a user