mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	feat: add Replicate adaptor and integrate into channel and API types
This commit is contained in:
		
							
								
								
									
										3
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/ci.yml
									
									
									
									
										vendored
									
									
								
							@@ -5,7 +5,8 @@ on:
 | 
			
		||||
    branches:
 | 
			
		||||
      - "master"
 | 
			
		||||
      - "main"
 | 
			
		||||
      - "test/ci"
 | 
			
		||||
      # - "test/ci"
 | 
			
		||||
      # - "feature/flux"
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  build_latest:
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -11,3 +11,4 @@ node_modules
 | 
			
		||||
/web/node_modules
 | 
			
		||||
cmd.md
 | 
			
		||||
.env
 | 
			
		||||
/one-api
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,7 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/palm"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/proxy"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/replicate"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/tencent"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
 | 
			
		||||
@@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
 | 
			
		||||
		return &vertexai.Adaptor{}
 | 
			
		||||
	case apitype.Proxy:
 | 
			
		||||
		return &proxy.Adaptor{}
 | 
			
		||||
	case apitype.Replicate:
 | 
			
		||||
		return &replicate.Adaptor{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,16 @@
 | 
			
		||||
package openai
 | 
			
		||||
 | 
			
		||||
import "github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
 | 
			
		||||
	logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
 | 
			
		||||
 | 
			
		||||
	Error := model.Error{
 | 
			
		||||
		Message: err.Error(),
 | 
			
		||||
		Type:    "one_api_error",
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										83
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								relay/adaptor/replicate/adaptor.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,83 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/relaymode"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
	meta *meta.Meta
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConvertImageRequest implements adaptor.Adaptor.
 | 
			
		||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 | 
			
		||||
	return DrawImageRequest{
 | 
			
		||||
		Input: ImageInput{
 | 
			
		||||
			Steps:           25,
 | 
			
		||||
			Prompt:          request.Prompt,
 | 
			
		||||
			Guidance:        3,
 | 
			
		||||
			Seed:            int(time.Now().UnixNano()),
 | 
			
		||||
			SafetyTolerance: 5,
 | 
			
		||||
			NImages:         request.N,
 | 
			
		||||
			Width:           1440,
 | 
			
		||||
			Height:          1440,
 | 
			
		||||
			AspectRatio:     "1:1",
 | 
			
		||||
		},
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 | 
			
		||||
	return nil, errors.New("not implemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(meta *meta.Meta) {
 | 
			
		||||
	a.meta = meta
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 | 
			
		||||
	if !slices.Contains(ModelList, meta.OriginModelName) {
 | 
			
		||||
		return "", errors.Errorf("model %s not supported", meta.OriginModelName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
 | 
			
		||||
	adaptor.SetupCommonRequestHeader(c, req, meta)
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case relaymode.ImagesGenerations:
 | 
			
		||||
		err, usage = ImageHandler(c, resp)
 | 
			
		||||
	default:
 | 
			
		||||
		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetModelList() []string {
 | 
			
		||||
	return ModelList
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetChannelName() string {
 | 
			
		||||
	return "replicate"
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								relay/adaptor/replicate/constant.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,58 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
// ModelList is a list of models that can be used with Replicate.
 | 
			
		||||
//
 | 
			
		||||
// https://replicate.com/pricing
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	// image model
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	"black-forest-labs/flux-1.1-pro",
 | 
			
		||||
	"black-forest-labs/flux-1.1-pro-ultra",
 | 
			
		||||
	"black-forest-labs/flux-canny-dev",
 | 
			
		||||
	"black-forest-labs/flux-canny-pro",
 | 
			
		||||
	"black-forest-labs/flux-depth-dev",
 | 
			
		||||
	"black-forest-labs/flux-depth-pro",
 | 
			
		||||
	"black-forest-labs/flux-dev",
 | 
			
		||||
	"black-forest-labs/flux-dev-lora",
 | 
			
		||||
	"black-forest-labs/flux-fill-dev",
 | 
			
		||||
	"black-forest-labs/flux-fill-pro",
 | 
			
		||||
	"black-forest-labs/flux-pro",
 | 
			
		||||
	"black-forest-labs/flux-redux-dev",
 | 
			
		||||
	"black-forest-labs/flux-redux-schnell",
 | 
			
		||||
	"black-forest-labs/flux-schnell",
 | 
			
		||||
	"black-forest-labs/flux-schnell-lora",
 | 
			
		||||
	"ideogram-ai/ideogram-v2",
 | 
			
		||||
	"ideogram-ai/ideogram-v2-turbo",
 | 
			
		||||
	"recraft-ai/recraft-v3",
 | 
			
		||||
	"recraft-ai/recraft-v3-svg",
 | 
			
		||||
	"stability-ai/stable-diffusion-3",
 | 
			
		||||
	"stability-ai/stable-diffusion-3.5-large",
 | 
			
		||||
	"stability-ai/stable-diffusion-3.5-large-turbo",
 | 
			
		||||
	"stability-ai/stable-diffusion-3.5-medium",
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	// language model
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	// "ibm-granite/granite-20b-code-instruct-8k",  // TODO: implement the adaptor
 | 
			
		||||
	// "ibm-granite/granite-3.0-2b-instruct",  // TODO: implement the adaptor
 | 
			
		||||
	// "ibm-granite/granite-3.0-8b-instruct",  // TODO: implement the adaptor
 | 
			
		||||
	// "ibm-granite/granite-8b-code-instruct-128k",  // TODO: implement the adaptor
 | 
			
		||||
	// "meta/llama-2-13b",  // TODO: implement the adaptor
 | 
			
		||||
	// "meta/llama-2-13b-chat",  // TODO: implement the adaptor
 | 
			
		||||
	// "meta/llama-2-70b",  // TODO: implement the adaptor
 | 
			
		||||
	// "meta/llama-2-70b-chat",  // TODO: implement the adaptor
 | 
			
		||||
	// "meta/llama-2-7b",  // TODO: implement the adaptor
 | 
			
		||||
	// "meta/llama-2-7b-chat",  // TODO: implement the adaptor
 | 
			
		||||
	// "meta/meta-llama-3.1-405b-instruct",  // TODO: implement the adaptor
 | 
			
		||||
	// "meta/meta-llama-3-70b",  // TODO: implement the adaptor
 | 
			
		||||
	// "meta/meta-llama-3-70b-instruct",  // TODO: implement the adaptor
 | 
			
		||||
	// "meta/meta-llama-3-8b",  // TODO: implement the adaptor
 | 
			
		||||
	// "meta/meta-llama-3-8b-instruct",  // TODO: implement the adaptor
 | 
			
		||||
	// "mistralai/mistral-7b-instruct-v0.2",  // TODO: implement the adaptor
 | 
			
		||||
	// "mistralai/mistral-7b-v0.1",  // TODO: implement the adaptor
 | 
			
		||||
	// "mistralai/mixtral-8x7b-instruct-v0.1",  // TODO: implement the adaptor
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	// video model
 | 
			
		||||
	// -------------------------------------
 | 
			
		||||
	// "minimax/video-01",  // TODO: implement the adaptor
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										217
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										217
									
								
								relay/adaptor/replicate/image.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,217 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"image"
 | 
			
		||||
	"image/png"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"golang.org/x/image/webp"
 | 
			
		||||
	"golang.org/x/sync/errgroup"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ImagesEditsHandler just copy response body to client
 | 
			
		||||
//
 | 
			
		||||
// https://replicate.com/black-forest-labs/flux-fill-pro
 | 
			
		||||
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
// 	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
// 	for k, v := range resp.Header {
 | 
			
		||||
// 		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
// 	}
 | 
			
		||||
 | 
			
		||||
// 	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
 | 
			
		||||
// 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
// 	}
 | 
			
		||||
// 	defer resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
// 	return nil, nil
 | 
			
		||||
// }
 | 
			
		||||
 | 
			
		||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	if resp.StatusCode != http.StatusCreated {
 | 
			
		||||
		payload, _ := io.ReadAll(resp.Body)
 | 
			
		||||
		return openai.ErrorWrapper(
 | 
			
		||||
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
 | 
			
		||||
				"bad_status_code", http.StatusInternalServerError),
 | 
			
		||||
			nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respData := new(ImageResponse)
 | 
			
		||||
	if err = json.Unmarshal(respBody, respData); err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		err = func() error {
 | 
			
		||||
			// get task
 | 
			
		||||
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
 | 
			
		||||
				http.MethodGet, respData.URLs.Get, nil)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.Wrap(err, "new request")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			logger.Debug(c, "send image request to replicate")
 | 
			
		||||
			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
 | 
			
		||||
			taskResp, err := http.DefaultClient.Do(taskReq)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.Wrap(err, "get task")
 | 
			
		||||
			}
 | 
			
		||||
			defer taskResp.Body.Close()
 | 
			
		||||
 | 
			
		||||
			if taskResp.StatusCode != http.StatusOK {
 | 
			
		||||
				payload, _ := io.ReadAll(taskResp.Body)
 | 
			
		||||
				return errors.Errorf("bad status code [%d]%s",
 | 
			
		||||
					taskResp.StatusCode, string(payload))
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			taskBody, err := io.ReadAll(taskResp.Body)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.Wrap(err, "read task response")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			taskData := new(ImageResponse)
 | 
			
		||||
			if err = json.Unmarshal(taskBody, taskData); err != nil {
 | 
			
		||||
				return errors.Wrap(err, "decode task response")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			switch taskData.Status {
 | 
			
		||||
			case "succeeded":
 | 
			
		||||
			case "failed", "canceled":
 | 
			
		||||
				return errors.Errorf("task failed: %s", taskData.Status)
 | 
			
		||||
			default:
 | 
			
		||||
				time.Sleep(time.Second * 3)
 | 
			
		||||
				return nil
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			output, err := taskData.GetOutput()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.Wrap(err, "get output")
 | 
			
		||||
			}
 | 
			
		||||
			if len(output) == 0 {
 | 
			
		||||
				return errors.New("response output is empty")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var mu sync.Mutex
 | 
			
		||||
			var pool errgroup.Group
 | 
			
		||||
			respBody := &openai.ImageResponse{
 | 
			
		||||
				Created: taskData.CompletedAt.Unix(),
 | 
			
		||||
				Data:    []openai.ImageData{},
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, imgOut := range output {
 | 
			
		||||
				imgOut := imgOut
 | 
			
		||||
				pool.Go(func() error {
 | 
			
		||||
					// download image
 | 
			
		||||
					downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
 | 
			
		||||
						http.MethodGet, imgOut, nil)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.Wrap(err, "new request")
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					imgResp, err := http.DefaultClient.Do(downloadReq)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.Wrap(err, "download image")
 | 
			
		||||
					}
 | 
			
		||||
					defer imgResp.Body.Close()
 | 
			
		||||
 | 
			
		||||
					if imgResp.StatusCode != http.StatusOK {
 | 
			
		||||
						payload, _ := io.ReadAll(imgResp.Body)
 | 
			
		||||
						return errors.Errorf("bad status code [%d]%s",
 | 
			
		||||
							imgResp.StatusCode, string(payload))
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					imgData, err := io.ReadAll(imgResp.Body)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.Wrap(err, "read image")
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					imgData, err = ConvertImageToPNG(imgData)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.Wrap(err, "convert image")
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					mu.Lock()
 | 
			
		||||
					respBody.Data = append(respBody.Data, openai.ImageData{
 | 
			
		||||
						B64Json: fmt.Sprintf("data:image/png;base64,%s",
 | 
			
		||||
							base64.StdEncoding.EncodeToString(imgData)),
 | 
			
		||||
					})
 | 
			
		||||
					mu.Unlock()
 | 
			
		||||
 | 
			
		||||
					return nil
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err := pool.Wait(); err != nil {
 | 
			
		||||
				if len(respBody.Data) == 0 {
 | 
			
		||||
					return errors.WithStack(err)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			c.JSON(http.StatusOK, respBody)
 | 
			
		||||
			return nil
 | 
			
		||||
		}()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		break
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConvertImageToPNG converts a WebP image to PNG format
 | 
			
		||||
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
 | 
			
		||||
	// bypass if it's already a PNG image
 | 
			
		||||
	if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
 | 
			
		||||
		return webpData, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// check if is jpeg, convert to png
 | 
			
		||||
	if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
 | 
			
		||||
		img, _, err := image.Decode(bytes.NewReader(webpData))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, errors.Wrap(err, "decode jpeg")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var pngBuffer bytes.Buffer
 | 
			
		||||
		if err := png.Encode(&pngBuffer, img); err != nil {
 | 
			
		||||
			return nil, errors.Wrap(err, "encode png")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return pngBuffer.Bytes(), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Decode the WebP image
 | 
			
		||||
	img, err := webp.Decode(bytes.NewReader(webpData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "decode webp")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Encode the image as PNG
 | 
			
		||||
	var pngBuffer bytes.Buffer
 | 
			
		||||
	if err := png.Encode(&pngBuffer, img); err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "encode png")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return pngBuffer.Bytes(), nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										111
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								relay/adaptor/replicate/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,111 @@
 | 
			
		||||
package replicate
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DrawImageRequest draw image by fluxpro
 | 
			
		||||
//
 | 
			
		||||
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
 | 
			
		||||
type DrawImageRequest struct {
 | 
			
		||||
	Input ImageInput `json:"input"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ImageInput is input of DrawImageByFluxProRequest
 | 
			
		||||
//
 | 
			
		||||
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
 | 
			
		||||
type ImageInput struct {
 | 
			
		||||
	Steps           int    `json:"steps" binding:"required,min=1"`
 | 
			
		||||
	Prompt          string `json:"prompt" binding:"required,min=5"`
 | 
			
		||||
	ImagePrompt     string `json:"image_prompt"`
 | 
			
		||||
	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"`
 | 
			
		||||
	Interval        int    `json:"interval" binding:"required,min=1,max=4"`
 | 
			
		||||
	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
 | 
			
		||||
	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
 | 
			
		||||
	Seed            int    `json:"seed"`
 | 
			
		||||
	NImages         int    `json:"n_images" binding:"required,min=1,max=8"`
 | 
			
		||||
	Width           int    `json:"width" binding:"required,min=256,max=1440"`
 | 
			
		||||
	Height          int    `json:"height" binding:"required,min=256,max=1440"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
 | 
			
		||||
//
 | 
			
		||||
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
 | 
			
		||||
type InpaintingImageByFlusReplicateRequest struct {
 | 
			
		||||
	Input FluxInpaintingInput `json:"input"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FluxInpaintingInput is input of DrawImageByFluxProRequest
 | 
			
		||||
//
 | 
			
		||||
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
 | 
			
		||||
type FluxInpaintingInput struct {
 | 
			
		||||
	Mask             string `json:"mask" binding:"required"`
 | 
			
		||||
	Image            string `json:"image" binding:"required"`
 | 
			
		||||
	Seed             int    `json:"seed"`
 | 
			
		||||
	Steps            int    `json:"steps" binding:"required,min=1"`
 | 
			
		||||
	Prompt           string `json:"prompt" binding:"required,min=5"`
 | 
			
		||||
	Guidance         int    `json:"guidance" binding:"required,min=2,max=5"`
 | 
			
		||||
	OutputFormat     string `json:"output_format"`
 | 
			
		||||
	SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
 | 
			
		||||
	PromptUnsampling bool   `json:"prompt_unsampling"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ImageResponse is response of DrawImageByFluxProRequest
 | 
			
		||||
//
 | 
			
		||||
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
 | 
			
		||||
type ImageResponse struct {
 | 
			
		||||
	CompletedAt time.Time        `json:"completed_at"`
 | 
			
		||||
	CreatedAt   time.Time        `json:"created_at"`
 | 
			
		||||
	DataRemoved bool             `json:"data_removed"`
 | 
			
		||||
	Error       string           `json:"error"`
 | 
			
		||||
	ID          string           `json:"id"`
 | 
			
		||||
	Input       DrawImageRequest `json:"input"`
 | 
			
		||||
	Logs        string           `json:"logs"`
 | 
			
		||||
	Metrics     FluxMetrics      `json:"metrics"`
 | 
			
		||||
	// Output could be `string` or `[]string`
 | 
			
		||||
	Output    any       `json:"output"`
 | 
			
		||||
	StartedAt time.Time `json:"started_at"`
 | 
			
		||||
	Status    string    `json:"status"`
 | 
			
		||||
	URLs      FluxURLs  `json:"urls"`
 | 
			
		||||
	Version   string    `json:"version"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *ImageResponse) GetOutput() ([]string, error) {
 | 
			
		||||
	switch v := r.Output.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		return []string{v}, nil
 | 
			
		||||
	case []string:
 | 
			
		||||
		return v, nil
 | 
			
		||||
	case nil:
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	case []interface{}:
 | 
			
		||||
		// convert []interface{} to []string
 | 
			
		||||
		ret := make([]string, len(v))
 | 
			
		||||
		for idx, vv := range v {
 | 
			
		||||
			if vvv, ok := vv.(string); ok {
 | 
			
		||||
				ret[idx] = vvv
 | 
			
		||||
			} else {
 | 
			
		||||
				return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return ret, nil
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FluxMetrics is metrics of ImageResponse
 | 
			
		||||
type FluxMetrics struct {
 | 
			
		||||
	ImageCount  int     `json:"image_count"`
 | 
			
		||||
	PredictTime float64 `json:"predict_time"`
 | 
			
		||||
	TotalTime   float64 `json:"total_time"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FluxURLs is urls of ImageResponse
 | 
			
		||||
type FluxURLs struct {
 | 
			
		||||
	Get    string `json:"get"`
 | 
			
		||||
	Cancel string `json:"cancel"`
 | 
			
		||||
}
 | 
			
		||||
@@ -19,6 +19,7 @@ const (
 | 
			
		||||
	DeepL
 | 
			
		||||
	VertexAI
 | 
			
		||||
	Proxy
 | 
			
		||||
	Replicate
 | 
			
		||||
 | 
			
		||||
	Dummy // this one is only for count, do not add any channel after this
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -217,6 +217,31 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"deepl-ja": 25.0 / 1000 * USD,
 | 
			
		||||
	// https://console.x.ai/
 | 
			
		||||
	"grok-beta": 5.0 / 1000 * USD,
 | 
			
		||||
	// replicate charges based on the number of generated images
 | 
			
		||||
	// https://replicate.com/pricing
 | 
			
		||||
	"black-forest-labs/flux-1.1-pro":                0.04 * USD,
 | 
			
		||||
	"black-forest-labs/flux-1.1-pro-ultra":          0.06 * USD,
 | 
			
		||||
	"black-forest-labs/flux-canny-dev":              0.025 * USD,
 | 
			
		||||
	"black-forest-labs/flux-canny-pro":              0.05 * USD,
 | 
			
		||||
	"black-forest-labs/flux-depth-dev":              0.025 * USD,
 | 
			
		||||
	"black-forest-labs/flux-depth-pro":              0.05 * USD,
 | 
			
		||||
	"black-forest-labs/flux-dev":                    0.025 * USD,
 | 
			
		||||
	"black-forest-labs/flux-dev-lora":               0.032 * USD,
 | 
			
		||||
	"black-forest-labs/flux-fill-dev":               0.04 * USD,
 | 
			
		||||
	"black-forest-labs/flux-fill-pro":               0.05 * USD,
 | 
			
		||||
	"black-forest-labs/flux-pro":                    0.055 * USD,
 | 
			
		||||
	"black-forest-labs/flux-redux-dev":              0.025 * USD,
 | 
			
		||||
	"black-forest-labs/flux-redux-schnell":          0.003 * USD,
 | 
			
		||||
	"black-forest-labs/flux-schnell":                0.003 * USD,
 | 
			
		||||
	"black-forest-labs/flux-schnell-lora":           0.02 * USD,
 | 
			
		||||
	"ideogram-ai/ideogram-v2":                       0.08 * USD,
 | 
			
		||||
	"ideogram-ai/ideogram-v2-turbo":                 0.05 * USD,
 | 
			
		||||
	"recraft-ai/recraft-v3":                         0.04 * USD,
 | 
			
		||||
	"recraft-ai/recraft-v3-svg":                     0.08 * USD,
 | 
			
		||||
	"stability-ai/stable-diffusion-3":               0.035 * USD,
 | 
			
		||||
	"stability-ai/stable-diffusion-3.5-large":       0.065 * USD,
 | 
			
		||||
	"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
 | 
			
		||||
	"stability-ai/stable-diffusion-3.5-medium":      0.035 * USD,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var CompletionRatio = map[string]float64{
 | 
			
		||||
 
 | 
			
		||||
@@ -47,5 +47,6 @@ const (
 | 
			
		||||
	Proxy
 | 
			
		||||
	SiliconFlow
 | 
			
		||||
	XAI
 | 
			
		||||
	Replicate
 | 
			
		||||
	Dummy
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
 | 
			
		||||
		apiType = apitype.DeepL
 | 
			
		||||
	case VertextAI:
 | 
			
		||||
		apiType = apitype.VertexAI
 | 
			
		||||
	case Replicate:
 | 
			
		||||
		apiType = apitype.Replicate
 | 
			
		||||
	case Proxy:
 | 
			
		||||
		apiType = apitype.Proxy
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -47,6 +47,7 @@ var ChannelBaseURLs = []string{
 | 
			
		||||
	"",                                          // 43
 | 
			
		||||
	"https://api.siliconflow.cn",                // 44
 | 
			
		||||
	"https://api.x.ai",                          // 45
 | 
			
		||||
	"https://api.replicate.com/v1/models/",      // 46
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
 
 | 
			
		||||
@@ -151,12 +151,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
	}
 | 
			
		||||
	adaptor.Init(meta)
 | 
			
		||||
 | 
			
		||||
	// these adaptors need to convert the request
 | 
			
		||||
	switch meta.ChannelType {
 | 
			
		||||
	case channeltype.Ali:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case channeltype.Baidu:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case channeltype.Zhipu:
 | 
			
		||||
	case channeltype.Zhipu,
 | 
			
		||||
		channeltype.Ali,
 | 
			
		||||
		channeltype.Replicate,
 | 
			
		||||
		channeltype.Baidu:
 | 
			
		||||
		finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
 | 
			
		||||
@@ -189,7 +189,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		if resp != nil && resp.StatusCode != http.StatusOK {
 | 
			
		||||
		if resp != nil &&
 | 
			
		||||
			resp.StatusCode != http.StatusCreated && // replicate returns 201
 | 
			
		||||
			resp.StatusCode != http.StatusOK {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@@ -208,6 +210,16 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
			model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
 | 
			
		||||
			channelId := c.GetInt(ctxkey.ChannelId)
 | 
			
		||||
			model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
 | 
			
		||||
			// also update user request cost
 | 
			
		||||
			docu := model.NewUserRequestCost(
 | 
			
		||||
				c.GetInt(ctxkey.Id),
 | 
			
		||||
				c.GetString(ctxkey.RequestId),
 | 
			
		||||
				quota,
 | 
			
		||||
			)
 | 
			
		||||
			if err = docu.Insert(); err != nil {
 | 
			
		||||
				logger.Errorf(c, "insert user request cost failed: %+v", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
 | 
			
		||||
  { key: 43, text: 'Proxy', value: 43, color: 'blue' },
 | 
			
		||||
  { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
 | 
			
		||||
  { key: 45, text: 'xAI', value: 45, color: 'blue' },
 | 
			
		||||
  { key: 46, text: 'Replicate', value: 46, color: 'blue' },
 | 
			
		||||
  { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
 | 
			
		||||
  { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
 | 
			
		||||
  { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
 | 
			
		||||
 
 | 
			
		||||
@@ -185,6 +185,12 @@ export const CHANNEL_OPTIONS = {
 | 
			
		||||
    value: 45,
 | 
			
		||||
    color: 'primary'
 | 
			
		||||
  },
 | 
			
		||||
  45: {
 | 
			
		||||
    key: 46,
 | 
			
		||||
    text: 'Replicate',
 | 
			
		||||
    value: 46,
 | 
			
		||||
    color: 'primary'
 | 
			
		||||
  },
 | 
			
		||||
  41: {
 | 
			
		||||
    key: 41,
 | 
			
		||||
    text: 'Novita',
 | 
			
		||||
 
 | 
			
		||||
@@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
 | 
			
		||||
    { key: 43, text: 'Proxy', value: 43, color: 'blue' },
 | 
			
		||||
    { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
 | 
			
		||||
    { key: 45, text: 'xAI', value: 45, color: 'blue' },
 | 
			
		||||
    { key: 46, text: 'Replicate', value: 46, color: 'blue' },
 | 
			
		||||
    { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
 | 
			
		||||
    { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
 | 
			
		||||
    { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user