mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 17:16:38 +08:00
Merge remote-tracking branch 'origin/patch/replicate-image' into patch/replicate-flux-inpainting
This commit is contained in:
commit
f6b4ca3936
8
.github/workflows/ci.yml
vendored
8
.github/workflows/ci.yml
vendored
@ -1,13 +1,13 @@
|
||||
name: CI
|
||||
|
||||
# This setup assumes that you run the unit tests with code coverage in the same
|
||||
# workflow that will also print the coverage report as comment to the pull request.
|
||||
# workflow that will also print the coverage report as comment to the pull request.
|
||||
# Therefore, you need to trigger this workflow when a pull request is (re)opened or
|
||||
# when new code is pushed to the branch of the pull request. In addition, you also
|
||||
# need to trigger this workflow when new code is pushed to the main branch because
|
||||
# need to trigger this workflow when new code is pushed to the main branch because
|
||||
# we need to upload the code coverage results as artifact for the main branch as
|
||||
# well since it will be the baseline code coverage.
|
||||
#
|
||||
#
|
||||
# We do not want to trigger the workflow for pushes to *any* branch because this
|
||||
# would trigger our jobs twice on pull requests (once from "push" event and once
|
||||
# from "pull_request->synchronize")
|
||||
@ -31,7 +31,7 @@ jobs:
|
||||
with:
|
||||
go-version: ^1.22
|
||||
|
||||
# When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
|
||||
# When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
|
||||
# coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
|
||||
# in the next step as well as the next job.
|
||||
- name: Test
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -9,4 +9,5 @@ logs
|
||||
data
|
||||
/web/node_modules
|
||||
cmd.md
|
||||
.env
|
||||
.env
|
||||
/one-api
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/palm"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/proxy"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/replicate"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
|
||||
@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
|
||||
return &vertexai.Adaptor{}
|
||||
case apitype.Proxy:
|
||||
return &proxy.Adaptor{}
|
||||
case apitype.Replicate:
|
||||
return &replicate.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -1,8 +1,16 @@
|
||||
package openai
|
||||
|
||||
import "github.com/songquanpeng/one-api/relay/model"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
|
||||
logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
|
||||
|
||||
Error := model.Error{
|
||||
Message: err.Error(),
|
||||
Type: "one_api_error",
|
||||
|
85
relay/adaptor/replicate/adaptor.go
Normal file
85
relay/adaptor/replicate/adaptor.go
Normal file
@ -0,0 +1,85 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
meta *meta.Meta
|
||||
}
|
||||
|
||||
// ConvertImageRequest implements adaptor.Adaptor.
|
||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
return DrawImageRequest{
|
||||
Input: ImageInput{
|
||||
Steps: 25,
|
||||
Prompt: request.Prompt,
|
||||
Guidance: 3,
|
||||
Seed: int(time.Now().UnixNano()),
|
||||
SafetyTolerance: 5,
|
||||
NImages: 1, // replicate will always return 1 image
|
||||
Width: 1440,
|
||||
Height: 1440,
|
||||
AspectRatio: "1:1",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
a.meta = meta
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
if !slices.Contains(ModelList, meta.OriginModelName) {
|
||||
return "", errors.Errorf("model %s not supported", meta.OriginModelName)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
adaptor.SetupCommonRequestHeader(c, req, meta)
|
||||
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
logger.Info(c, "send image request to replicate")
|
||||
return adaptor.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
switch meta.Mode {
|
||||
case relaymode.ImagesGenerations:
|
||||
err, usage = ImageHandler(c, resp)
|
||||
default:
|
||||
err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "replicate"
|
||||
}
|
58
relay/adaptor/replicate/constant.go
Normal file
58
relay/adaptor/replicate/constant.go
Normal file
@ -0,0 +1,58 @@
|
||||
package replicate
|
||||
|
||||
// ModelList is a list of models that can be used with Replicate.
|
||||
//
|
||||
// https://replicate.com/pricing
|
||||
var ModelList = []string{
|
||||
// -------------------------------------
|
||||
// image model
|
||||
// -------------------------------------
|
||||
"black-forest-labs/flux-1.1-pro",
|
||||
"black-forest-labs/flux-1.1-pro-ultra",
|
||||
"black-forest-labs/flux-canny-dev",
|
||||
"black-forest-labs/flux-canny-pro",
|
||||
"black-forest-labs/flux-depth-dev",
|
||||
"black-forest-labs/flux-depth-pro",
|
||||
"black-forest-labs/flux-dev",
|
||||
"black-forest-labs/flux-dev-lora",
|
||||
"black-forest-labs/flux-fill-dev",
|
||||
"black-forest-labs/flux-fill-pro",
|
||||
"black-forest-labs/flux-pro",
|
||||
"black-forest-labs/flux-redux-dev",
|
||||
"black-forest-labs/flux-redux-schnell",
|
||||
"black-forest-labs/flux-schnell",
|
||||
"black-forest-labs/flux-schnell-lora",
|
||||
"ideogram-ai/ideogram-v2",
|
||||
"ideogram-ai/ideogram-v2-turbo",
|
||||
"recraft-ai/recraft-v3",
|
||||
"recraft-ai/recraft-v3-svg",
|
||||
"stability-ai/stable-diffusion-3",
|
||||
"stability-ai/stable-diffusion-3.5-large",
|
||||
"stability-ai/stable-diffusion-3.5-large-turbo",
|
||||
"stability-ai/stable-diffusion-3.5-medium",
|
||||
// -------------------------------------
|
||||
// language model
|
||||
// -------------------------------------
|
||||
// "ibm-granite/granite-20b-code-instruct-8k", // TODO: implement the adaptor
|
||||
// "ibm-granite/granite-3.0-2b-instruct", // TODO: implement the adaptor
|
||||
// "ibm-granite/granite-3.0-8b-instruct", // TODO: implement the adaptor
|
||||
// "ibm-granite/granite-8b-code-instruct-128k", // TODO: implement the adaptor
|
||||
// "meta/llama-2-13b", // TODO: implement the adaptor
|
||||
// "meta/llama-2-13b-chat", // TODO: implement the adaptor
|
||||
// "meta/llama-2-70b", // TODO: implement the adaptor
|
||||
// "meta/llama-2-70b-chat", // TODO: implement the adaptor
|
||||
// "meta/llama-2-7b", // TODO: implement the adaptor
|
||||
// "meta/llama-2-7b-chat", // TODO: implement the adaptor
|
||||
// "meta/meta-llama-3.1-405b-instruct", // TODO: implement the adaptor
|
||||
// "meta/meta-llama-3-70b", // TODO: implement the adaptor
|
||||
// "meta/meta-llama-3-70b-instruct", // TODO: implement the adaptor
|
||||
// "meta/meta-llama-3-8b", // TODO: implement the adaptor
|
||||
// "meta/meta-llama-3-8b-instruct", // TODO: implement the adaptor
|
||||
// "mistralai/mistral-7b-instruct-v0.2", // TODO: implement the adaptor
|
||||
// "mistralai/mistral-7b-v0.1", // TODO: implement the adaptor
|
||||
// "mistralai/mixtral-8x7b-instruct-v0.1", // TODO: implement the adaptor
|
||||
// -------------------------------------
|
||||
// video model
|
||||
// -------------------------------------
|
||||
// "minimax/video-01", // TODO: implement the adaptor
|
||||
}
|
222
relay/adaptor/replicate/image.go
Normal file
222
relay/adaptor/replicate/image.go
Normal file
@ -0,0 +1,222 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"golang.org/x/image/webp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// ImagesEditsHandler just copy response body to client
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-fill-pro
|
||||
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
// c.Writer.WriteHeader(resp.StatusCode)
|
||||
// for k, v := range resp.Header {
|
||||
// c.Writer.Header().Set(k, v[0])
|
||||
// }
|
||||
|
||||
// if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||
// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
|
||||
// return nil, nil
|
||||
// }
|
||||
|
||||
var errNextLoop = errors.New("next_loop")
|
||||
|
||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
payload, _ := io.ReadAll(resp.Body)
|
||||
return openai.ErrorWrapper(
|
||||
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
|
||||
"bad_status_code", http.StatusInternalServerError),
|
||||
nil
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
respData := new(ImageResponse)
|
||||
if err = json.Unmarshal(respBody, respData); err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
for {
|
||||
err = func() error {
|
||||
// get task
|
||||
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
|
||||
http.MethodGet, respData.URLs.Get, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "new request")
|
||||
}
|
||||
|
||||
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
|
||||
taskResp, err := http.DefaultClient.Do(taskReq)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get task")
|
||||
}
|
||||
defer taskResp.Body.Close()
|
||||
|
||||
if taskResp.StatusCode != http.StatusOK {
|
||||
payload, _ := io.ReadAll(taskResp.Body)
|
||||
return errors.Errorf("bad status code [%d]%s",
|
||||
taskResp.StatusCode, string(payload))
|
||||
}
|
||||
|
||||
taskBody, err := io.ReadAll(taskResp.Body)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read task response")
|
||||
}
|
||||
|
||||
taskData := new(ImageResponse)
|
||||
if err = json.Unmarshal(taskBody, taskData); err != nil {
|
||||
return errors.Wrap(err, "decode task response")
|
||||
}
|
||||
|
||||
switch taskData.Status {
|
||||
case "succeeded":
|
||||
case "failed", "canceled":
|
||||
return errors.Errorf("task failed: %s", taskData.Status)
|
||||
default:
|
||||
time.Sleep(time.Second * 3)
|
||||
return errNextLoop
|
||||
}
|
||||
|
||||
output, err := taskData.GetOutput()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get output")
|
||||
}
|
||||
if len(output) == 0 {
|
||||
return errors.New("response output is empty")
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var pool errgroup.Group
|
||||
respBody := &openai.ImageResponse{
|
||||
Created: taskData.CompletedAt.Unix(),
|
||||
Data: []openai.ImageData{},
|
||||
}
|
||||
|
||||
for _, imgOut := range output {
|
||||
imgOut := imgOut
|
||||
pool.Go(func() error {
|
||||
// download image
|
||||
downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
|
||||
http.MethodGet, imgOut, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "new request")
|
||||
}
|
||||
|
||||
imgResp, err := http.DefaultClient.Do(downloadReq)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "download image")
|
||||
}
|
||||
defer imgResp.Body.Close()
|
||||
|
||||
if imgResp.StatusCode != http.StatusOK {
|
||||
payload, _ := io.ReadAll(imgResp.Body)
|
||||
return errors.Errorf("bad status code [%d]%s",
|
||||
imgResp.StatusCode, string(payload))
|
||||
}
|
||||
|
||||
imgData, err := io.ReadAll(imgResp.Body)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read image")
|
||||
}
|
||||
|
||||
imgData, err = ConvertImageToPNG(imgData)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "convert image")
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
respBody.Data = append(respBody.Data, openai.ImageData{
|
||||
B64Json: fmt.Sprintf("data:image/png;base64,%s",
|
||||
base64.StdEncoding.EncodeToString(imgData)),
|
||||
})
|
||||
mu.Unlock()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := pool.Wait(); err != nil {
|
||||
if len(respBody.Data) == 0 {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, respBody)
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
if errors.Is(err, errNextLoop) {
|
||||
continue
|
||||
}
|
||||
|
||||
return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ConvertImageToPNG converts a WebP image to PNG format
|
||||
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
|
||||
// bypass if it's already a PNG image
|
||||
if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
|
||||
return webpData, nil
|
||||
}
|
||||
|
||||
// check if is jpeg, convert to png
|
||||
if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
|
||||
img, _, err := image.Decode(bytes.NewReader(webpData))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decode jpeg")
|
||||
}
|
||||
|
||||
var pngBuffer bytes.Buffer
|
||||
if err := png.Encode(&pngBuffer, img); err != nil {
|
||||
return nil, errors.Wrap(err, "encode png")
|
||||
}
|
||||
|
||||
return pngBuffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// Decode the WebP image
|
||||
img, err := webp.Decode(bytes.NewReader(webpData))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decode webp")
|
||||
}
|
||||
|
||||
// Encode the image as PNG
|
||||
var pngBuffer bytes.Buffer
|
||||
if err := png.Encode(&pngBuffer, img); err != nil {
|
||||
return nil, errors.Wrap(err, "encode png")
|
||||
}
|
||||
|
||||
return pngBuffer.Bytes(), nil
|
||||
}
|
111
relay/adaptor/replicate/model.go
Normal file
111
relay/adaptor/replicate/model.go
Normal file
@ -0,0 +1,111 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// DrawImageRequest draw image by fluxpro
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
|
||||
type DrawImageRequest struct {
|
||||
Input ImageInput `json:"input"`
|
||||
}
|
||||
|
||||
// ImageInput is input of DrawImageByFluxProRequest
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
|
||||
type ImageInput struct {
|
||||
Steps int `json:"steps" binding:"required,min=1"`
|
||||
Prompt string `json:"prompt" binding:"required,min=5"`
|
||||
ImagePrompt string `json:"image_prompt"`
|
||||
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
|
||||
Interval int `json:"interval" binding:"required,min=1,max=4"`
|
||||
AspectRatio string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
|
||||
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
|
||||
Seed int `json:"seed"`
|
||||
NImages int `json:"n_images" binding:"required,min=1,max=8"`
|
||||
Width int `json:"width" binding:"required,min=256,max=1440"`
|
||||
Height int `json:"height" binding:"required,min=256,max=1440"`
|
||||
}
|
||||
|
||||
// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
|
||||
type InpaintingImageByFlusReplicateRequest struct {
|
||||
Input FluxInpaintingInput `json:"input"`
|
||||
}
|
||||
|
||||
// FluxInpaintingInput is input of DrawImageByFluxProRequest
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
|
||||
type FluxInpaintingInput struct {
|
||||
Mask string `json:"mask" binding:"required"`
|
||||
Image string `json:"image" binding:"required"`
|
||||
Seed int `json:"seed"`
|
||||
Steps int `json:"steps" binding:"required,min=1"`
|
||||
Prompt string `json:"prompt" binding:"required,min=5"`
|
||||
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
|
||||
OutputFormat string `json:"output_format"`
|
||||
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
|
||||
PromptUnsampling bool `json:"prompt_unsampling"`
|
||||
}
|
||||
|
||||
// ImageResponse is response of DrawImageByFluxProRequest
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
|
||||
type ImageResponse struct {
|
||||
CompletedAt time.Time `json:"completed_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DataRemoved bool `json:"data_removed"`
|
||||
Error string `json:"error"`
|
||||
ID string `json:"id"`
|
||||
Input DrawImageRequest `json:"input"`
|
||||
Logs string `json:"logs"`
|
||||
Metrics FluxMetrics `json:"metrics"`
|
||||
// Output could be `string` or `[]string`
|
||||
Output any `json:"output"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
Status string `json:"status"`
|
||||
URLs FluxURLs `json:"urls"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
func (r *ImageResponse) GetOutput() ([]string, error) {
|
||||
switch v := r.Output.(type) {
|
||||
case string:
|
||||
return []string{v}, nil
|
||||
case []string:
|
||||
return v, nil
|
||||
case nil:
|
||||
return nil, nil
|
||||
case []interface{}:
|
||||
// convert []interface{} to []string
|
||||
ret := make([]string, len(v))
|
||||
for idx, vv := range v {
|
||||
if vvv, ok := vv.(string); ok {
|
||||
ret[idx] = vvv
|
||||
} else {
|
||||
return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
default:
|
||||
return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// FluxMetrics is metrics of ImageResponse
|
||||
type FluxMetrics struct {
|
||||
ImageCount int `json:"image_count"`
|
||||
PredictTime float64 `json:"predict_time"`
|
||||
TotalTime float64 `json:"total_time"`
|
||||
}
|
||||
|
||||
// FluxURLs is urls of ImageResponse
|
||||
type FluxURLs struct {
|
||||
Get string `json:"get"`
|
||||
Cancel string `json:"cancel"`
|
||||
}
|
@ -19,6 +19,7 @@ const (
|
||||
DeepL
|
||||
VertexAI
|
||||
Proxy
|
||||
Replicate
|
||||
|
||||
Dummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
@ -211,6 +211,31 @@ var ModelRatio = map[string]float64{
|
||||
"deepl-ja": 25.0 / 1000 * USD,
|
||||
// https://console.x.ai/
|
||||
"grok-beta": 5.0 / 1000 * USD,
|
||||
// replicate charges based on the number of generated images
|
||||
// https://replicate.com/pricing
|
||||
"black-forest-labs/flux-1.1-pro": 0.04 * USD,
|
||||
"black-forest-labs/flux-1.1-pro-ultra": 0.06 * USD,
|
||||
"black-forest-labs/flux-canny-dev": 0.025 * USD,
|
||||
"black-forest-labs/flux-canny-pro": 0.05 * USD,
|
||||
"black-forest-labs/flux-depth-dev": 0.025 * USD,
|
||||
"black-forest-labs/flux-depth-pro": 0.05 * USD,
|
||||
"black-forest-labs/flux-dev": 0.025 * USD,
|
||||
"black-forest-labs/flux-dev-lora": 0.032 * USD,
|
||||
"black-forest-labs/flux-fill-dev": 0.04 * USD,
|
||||
"black-forest-labs/flux-fill-pro": 0.05 * USD,
|
||||
"black-forest-labs/flux-pro": 0.055 * USD,
|
||||
"black-forest-labs/flux-redux-dev": 0.025 * USD,
|
||||
"black-forest-labs/flux-redux-schnell": 0.003 * USD,
|
||||
"black-forest-labs/flux-schnell": 0.003 * USD,
|
||||
"black-forest-labs/flux-schnell-lora": 0.02 * USD,
|
||||
"ideogram-ai/ideogram-v2": 0.08 * USD,
|
||||
"ideogram-ai/ideogram-v2-turbo": 0.05 * USD,
|
||||
"recraft-ai/recraft-v3": 0.04 * USD,
|
||||
"recraft-ai/recraft-v3-svg": 0.08 * USD,
|
||||
"stability-ai/stable-diffusion-3": 0.035 * USD,
|
||||
"stability-ai/stable-diffusion-3.5-large": 0.065 * USD,
|
||||
"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
|
||||
"stability-ai/stable-diffusion-3.5-medium": 0.035 * USD,
|
||||
}
|
||||
|
||||
var CompletionRatio = map[string]float64{
|
||||
|
@ -47,5 +47,6 @@ const (
|
||||
Proxy
|
||||
SiliconFlow
|
||||
XAI
|
||||
Replicate
|
||||
Dummy
|
||||
)
|
||||
|
@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
|
||||
apiType = apitype.DeepL
|
||||
case VertextAI:
|
||||
apiType = apitype.VertexAI
|
||||
case Replicate:
|
||||
apiType = apitype.Replicate
|
||||
case Proxy:
|
||||
apiType = apitype.Proxy
|
||||
}
|
||||
|
@ -47,6 +47,7 @@ var ChannelBaseURLs = []string{
|
||||
"", // 43
|
||||
"https://api.siliconflow.cn", // 44
|
||||
"https://api.x.ai", // 45
|
||||
"https://api.replicate.com/v1/models/", // 46
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -152,12 +152,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
adaptor.Init(meta)
|
||||
|
||||
// these adaptors need to convert the request
|
||||
switch meta.ChannelType {
|
||||
case channeltype.Ali:
|
||||
fallthrough
|
||||
case channeltype.Baidu:
|
||||
fallthrough
|
||||
case channeltype.Zhipu:
|
||||
case channeltype.Zhipu,
|
||||
channeltype.Ali,
|
||||
channeltype.Replicate,
|
||||
channeltype.Baidu:
|
||||
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
||||
@ -174,7 +174,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
ratio := modelRatio * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
||||
|
||||
quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
||||
var quota int64
|
||||
switch meta.ChannelType {
|
||||
case channeltype.Replicate:
|
||||
// replicate always return 1 image
|
||||
quota = int64(ratio * imageCostRatio * 1000)
|
||||
default:
|
||||
quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
||||
}
|
||||
|
||||
if userQuota-quota < 0 {
|
||||
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
@ -188,7 +195,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
if resp != nil &&
|
||||
resp.StatusCode != http.StatusCreated && // replicate returns 201
|
||||
resp.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
|
||||
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
||||
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
||||
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
||||
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
|
||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||
|
@ -185,6 +185,12 @@ export const CHANNEL_OPTIONS = {
|
||||
value: 45,
|
||||
color: 'primary'
|
||||
},
|
||||
45: {
|
||||
key: 46,
|
||||
text: 'Replicate',
|
||||
value: 46,
|
||||
color: 'primary'
|
||||
},
|
||||
41: {
|
||||
key: 41,
|
||||
text: 'Novita',
|
||||
|
@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
|
||||
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
||||
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
||||
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
||||
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
|
||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||
|
Loading…
Reference in New Issue
Block a user