mirror of
https://github.com/songquanpeng/one-api.git
synced 2026-04-06 02:14:25 +08:00
Merge d2bc9eb5ae into 36c8f4f15c
This commit is contained in:
8
.github/workflows/ci.yml
vendored
8
.github/workflows/ci.yml
vendored
@@ -1,13 +1,13 @@
|
||||
name: CI
|
||||
|
||||
# This setup assumes that you run the unit tests with code coverage in the same
|
||||
# workflow that will also print the coverage report as comment to the pull request.
|
||||
# workflow that will also print the coverage report as comment to the pull request.
|
||||
# Therefore, you need to trigger this workflow when a pull request is (re)opened or
|
||||
# when new code is pushed to the branch of the pull request. In addition, you also
|
||||
# need to trigger this workflow when new code is pushed to the main branch because
|
||||
# need to trigger this workflow when new code is pushed to the main branch because
|
||||
# we need to upload the code coverage results as artifact for the main branch as
|
||||
# well since it will be the baseline code coverage.
|
||||
#
|
||||
#
|
||||
# We do not want to trigger the workflow for pushes to *any* branch because this
|
||||
# would trigger our jobs twice on pull requests (once from "push" event and once
|
||||
# from "pull_request->synchronize")
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
with:
|
||||
go-version: ^1.22
|
||||
|
||||
# When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
|
||||
# When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
|
||||
# coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
|
||||
# in the next step as well as the next job.
|
||||
- name: Test
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -9,4 +9,5 @@ logs
|
||||
data
|
||||
/web/node_modules
|
||||
cmd.md
|
||||
.env
|
||||
.env
|
||||
/one-api
|
||||
|
||||
@@ -3,6 +3,7 @@ package ctxkey
|
||||
const (
|
||||
Config = "config"
|
||||
Id = "id"
|
||||
RequestId = "X-Oneapi-Request-Id"
|
||||
Username = "username"
|
||||
Role = "role"
|
||||
Status = "status"
|
||||
@@ -15,6 +16,7 @@ const (
|
||||
Group = "group"
|
||||
ModelMapping = "model_mapping"
|
||||
ChannelName = "channel_name"
|
||||
ContentType = "content_type"
|
||||
TokenId = "token_id"
|
||||
TokenName = "token_name"
|
||||
BaseURL = "base_url"
|
||||
|
||||
@@ -3,24 +3,27 @@ package common
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
)
|
||||
|
||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||
requestBody, _ := c.Get(ctxkey.KeyRequestBody)
|
||||
if requestBody != nil {
|
||||
return requestBody.([]byte), nil
|
||||
func GetRequestBody(c *gin.Context) (requestBody []byte, err error) {
|
||||
if requestBodyCache, _ := c.Get(ctxkey.KeyRequestBody); requestBodyCache != nil {
|
||||
return requestBodyCache.([]byte), nil
|
||||
}
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
requestBody, err = io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = c.Request.Body.Close()
|
||||
c.Set(ctxkey.KeyRequestBody, requestBody)
|
||||
return requestBody.([]byte), nil
|
||||
|
||||
return requestBody, nil
|
||||
}
|
||||
|
||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
@@ -28,18 +31,25 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check v should be a pointer
|
||||
if v == nil || reflect.TypeOf(v).Kind() != reflect.Ptr {
|
||||
return errors.Errorf("UnmarshalBodyReusable only accept pointer, got %v", reflect.TypeOf(v))
|
||||
}
|
||||
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
err = json.Unmarshal(requestBody, v)
|
||||
} else {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
err = c.ShouldBind(&v)
|
||||
err = c.ShouldBind(v)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, "unmarshal request body failed")
|
||||
}
|
||||
// Reset request body
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,8 @@ import (
|
||||
func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||
var err *model.ErrorWithStatusCode
|
||||
switch relayMode {
|
||||
case relaymode.ImagesGenerations:
|
||||
case relaymode.ImagesGenerations,
|
||||
relaymode.ImagesEdits:
|
||||
err = controller.RelayImageHelper(c, relayMode)
|
||||
case relaymode.AudioSpeech:
|
||||
fallthrough
|
||||
@@ -45,10 +46,6 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||
func Relay(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
relayMode := relaymode.GetByPath(c.Request.URL.Path)
|
||||
if config.DebugEnabled {
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
logger.Debugf(ctx, "request body: %s", string(requestBody))
|
||||
}
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
bizErr := relayHelper(c, relayMode)
|
||||
@@ -60,7 +57,7 @@ func Relay(c *gin.Context) {
|
||||
channelName := c.GetString(ctxkey.ChannelName)
|
||||
group := c.GetString(ctxkey.Group)
|
||||
originalModel := c.GetString(ctxkey.OriginalModel)
|
||||
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
|
||||
go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
|
||||
requestId := c.GetString(helper.RequestIdKey)
|
||||
retryTimes := config.RetryTimes
|
||||
if !shouldRetry(c, bizErr.StatusCode) {
|
||||
@@ -87,9 +84,9 @@ func Relay(c *gin.Context) {
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
lastFailedChannelId = channelId
|
||||
channelName := c.GetString(ctxkey.ChannelName)
|
||||
// BUG: bizErr is in race condition
|
||||
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
|
||||
go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
|
||||
}
|
||||
|
||||
if bizErr != nil {
|
||||
if bizErr.StatusCode == http.StatusTooManyRequests {
|
||||
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
@@ -122,7 +119,10 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
|
||||
func processChannelRelayError(ctx context.Context,
|
||||
userId int, channelId int, channelName string,
|
||||
// FIX: err should not use a pointer to avoid data race in concurrent situations
|
||||
err model.ErrorWithStatusCode) {
|
||||
logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
|
||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
|
||||
|
||||
2
go.mod
2
go.mod
@@ -27,6 +27,7 @@ require (
|
||||
github.com/stretchr/testify v1.9.0
|
||||
golang.org/x/crypto v0.24.0
|
||||
golang.org/x/image v0.18.0
|
||||
golang.org/x/sync v0.7.0
|
||||
google.golang.org/api v0.187.0
|
||||
gorm.io/driver/mysql v1.5.6
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
@@ -99,7 +100,6 @@ require (
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/net v0.26.0 // indirect
|
||||
golang.org/x/oauth2 v0.21.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/sys v0.21.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
|
||||
@@ -64,6 +64,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
|
||||
c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
|
||||
}
|
||||
c.Set(ctxkey.ContentType, c.Request.Header.Get("Content-Type"))
|
||||
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
|
||||
c.Set(ctxkey.OriginalModel, modelName) // for retry
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
@@ -24,28 +25,30 @@ func getRequestModel(c *gin.Context) (string, error) {
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
|
||||
return "", errors.Wrap(err, "common.UnmarshalBodyReusable failed")
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(c.Request.URL.Path, "/v1/moderations"):
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
case strings.HasSuffix(c.Request.URL.Path, "embeddings"):
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
case strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations"),
|
||||
strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits"):
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e-2"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
case strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions"),
|
||||
strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations"):
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
|
||||
return modelRequest.Model, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
|
||||
strings.Contains(lowerMessage, "credit") ||
|
||||
strings.Contains(lowerMessage, "balance") ||
|
||||
strings.Contains(lowerMessage, "permission denied") ||
|
||||
strings.Contains(lowerMessage, "organization has been restricted") || // groq
|
||||
strings.Contains(lowerMessage, "organization has been restricted") || // groq
|
||||
strings.Contains(lowerMessage, "已欠费") {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/palm"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/proxy"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/replicate"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
|
||||
@@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
|
||||
return &vertexai.Adaptor{}
|
||||
case apitype.Proxy:
|
||||
return &proxy.Adaptor{}
|
||||
case apitype.Replicate:
|
||||
return &replicate.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,11 +3,13 @@ package adaptor
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/client"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/client"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
)
|
||||
|
||||
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
|
||||
@@ -27,6 +29,9 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", c.GetString(ctxkey.ContentType))
|
||||
|
||||
err = a.SetupRequestHeader(c, req, meta)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
|
||||
@@ -31,8 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||
TopP: request.TopP,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
NumPredict: request.MaxTokens,
|
||||
NumCtx: request.NumCtx,
|
||||
NumPredict: request.MaxTokens,
|
||||
NumCtx: request.NumCtx,
|
||||
},
|
||||
Stream: request.Stream,
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if strings.HasPrefix(data, "}") {
|
||||
data = strings.TrimPrefix(data, "}") + "}"
|
||||
data = strings.TrimPrefix(data, "}") + "}"
|
||||
}
|
||||
|
||||
var ollamaResponse ChatResponse
|
||||
|
||||
@@ -111,10 +111,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
|
||||
switch meta.Mode {
|
||||
case relaymode.ImagesGenerations:
|
||||
err, _ = ImageHandler(c, resp)
|
||||
case relaymode.ImagesEdits:
|
||||
err, _ = ImagesEditsHandler(c, resp)
|
||||
default:
|
||||
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -3,12 +3,30 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
// ImagesEditsHandler just copy response body to client
|
||||
//
|
||||
// https://platform.openai.com/docs/api-reference/images/createEdit
|
||||
func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
|
||||
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var imageResponse ImageResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
package openai
|
||||
|
||||
import "github.com/songquanpeng/one-api/relay/model"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
|
||||
logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
|
||||
|
||||
Error := model.Error{
|
||||
Message: err.Error(),
|
||||
Type: "one_api_error",
|
||||
|
||||
128
relay/adaptor/replicate/adaptor.go
Normal file
128
relay/adaptor/replicate/adaptor.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
meta *meta.Meta
|
||||
}
|
||||
|
||||
// ConvertImageRequest implements adaptor.Adaptor.
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
return nil, errors.New("should call replicate.ConvertImageRequest instead")
|
||||
}
|
||||
|
||||
func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
meta := meta.GetByContext(c)
|
||||
|
||||
if request.ResponseFormat != "b64_json" {
|
||||
return nil, errors.New("only support b64_json response format")
|
||||
}
|
||||
if request.N != 1 && request.N != 0 {
|
||||
return nil, errors.New("only support N=1")
|
||||
}
|
||||
|
||||
switch meta.Mode {
|
||||
case relaymode.ImagesGenerations:
|
||||
return convertImageCreateRequest(request)
|
||||
case relaymode.ImagesEdits:
|
||||
return convertImageRemixRequest(c)
|
||||
default:
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
|
||||
return DrawImageRequest{
|
||||
Input: ImageInput{
|
||||
Steps: 25,
|
||||
Prompt: request.Prompt,
|
||||
Guidance: 3,
|
||||
Seed: int(time.Now().UnixNano()),
|
||||
SafetyTolerance: 5,
|
||||
NImages: 1, // replicate will always return 1 image
|
||||
Width: 1440,
|
||||
Height: 1440,
|
||||
AspectRatio: "1:1",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertImageRemixRequest(c *gin.Context) (any, error) {
|
||||
// recover request body
|
||||
requestBody, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get request body")
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
|
||||
rawReq := new(OpenaiImageEditRequest)
|
||||
if err := c.ShouldBind(rawReq); err != nil {
|
||||
return nil, errors.Wrap(err, "parse image edit form")
|
||||
}
|
||||
|
||||
return rawReq.toFluxRemixRequest()
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
a.meta = meta
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
if !slices.Contains(ModelList, meta.OriginModelName) {
|
||||
return "", errors.Errorf("model %s not supported", meta.OriginModelName)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
adaptor.SetupCommonRequestHeader(c, req, meta)
|
||||
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
logger.Info(c, "send image request to replicate")
|
||||
return adaptor.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
switch meta.Mode {
|
||||
case relaymode.ImagesGenerations,
|
||||
relaymode.ImagesEdits:
|
||||
err, usage = ImageHandler(c, resp)
|
||||
default:
|
||||
err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "replicate"
|
||||
}
|
||||
58
relay/adaptor/replicate/constant.go
Normal file
58
relay/adaptor/replicate/constant.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package replicate
|
||||
|
||||
// ModelList is a list of models that can be used with Replicate.
|
||||
//
|
||||
// https://replicate.com/pricing
|
||||
var ModelList = []string{
|
||||
// -------------------------------------
|
||||
// image model
|
||||
// -------------------------------------
|
||||
"black-forest-labs/flux-1.1-pro",
|
||||
"black-forest-labs/flux-1.1-pro-ultra",
|
||||
"black-forest-labs/flux-canny-dev",
|
||||
"black-forest-labs/flux-canny-pro",
|
||||
"black-forest-labs/flux-depth-dev",
|
||||
"black-forest-labs/flux-depth-pro",
|
||||
"black-forest-labs/flux-dev",
|
||||
"black-forest-labs/flux-dev-lora",
|
||||
"black-forest-labs/flux-fill-dev",
|
||||
"black-forest-labs/flux-fill-pro",
|
||||
"black-forest-labs/flux-pro",
|
||||
"black-forest-labs/flux-redux-dev",
|
||||
"black-forest-labs/flux-redux-schnell",
|
||||
"black-forest-labs/flux-schnell",
|
||||
"black-forest-labs/flux-schnell-lora",
|
||||
"ideogram-ai/ideogram-v2",
|
||||
"ideogram-ai/ideogram-v2-turbo",
|
||||
"recraft-ai/recraft-v3",
|
||||
"recraft-ai/recraft-v3-svg",
|
||||
"stability-ai/stable-diffusion-3",
|
||||
"stability-ai/stable-diffusion-3.5-large",
|
||||
"stability-ai/stable-diffusion-3.5-large-turbo",
|
||||
"stability-ai/stable-diffusion-3.5-medium",
|
||||
// -------------------------------------
|
||||
// language model
|
||||
// -------------------------------------
|
||||
// "ibm-granite/granite-20b-code-instruct-8k", // TODO: implement the adaptor
|
||||
// "ibm-granite/granite-3.0-2b-instruct", // TODO: implement the adaptor
|
||||
// "ibm-granite/granite-3.0-8b-instruct", // TODO: implement the adaptor
|
||||
// "ibm-granite/granite-8b-code-instruct-128k", // TODO: implement the adaptor
|
||||
// "meta/llama-2-13b", // TODO: implement the adaptor
|
||||
// "meta/llama-2-13b-chat", // TODO: implement the adaptor
|
||||
// "meta/llama-2-70b", // TODO: implement the adaptor
|
||||
// "meta/llama-2-70b-chat", // TODO: implement the adaptor
|
||||
// "meta/llama-2-7b", // TODO: implement the adaptor
|
||||
// "meta/llama-2-7b-chat", // TODO: implement the adaptor
|
||||
// "meta/meta-llama-3.1-405b-instruct", // TODO: implement the adaptor
|
||||
// "meta/meta-llama-3-70b", // TODO: implement the adaptor
|
||||
// "meta/meta-llama-3-70b-instruct", // TODO: implement the adaptor
|
||||
// "meta/meta-llama-3-8b", // TODO: implement the adaptor
|
||||
// "meta/meta-llama-3-8b-instruct", // TODO: implement the adaptor
|
||||
// "mistralai/mistral-7b-instruct-v0.2", // TODO: implement the adaptor
|
||||
// "mistralai/mistral-7b-v0.1", // TODO: implement the adaptor
|
||||
// "mistralai/mixtral-8x7b-instruct-v0.1", // TODO: implement the adaptor
|
||||
// -------------------------------------
|
||||
// video model
|
||||
// -------------------------------------
|
||||
// "minimax/video-01", // TODO: implement the adaptor
|
||||
}
|
||||
223
relay/adaptor/replicate/image.go
Normal file
223
relay/adaptor/replicate/image.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"golang.org/x/image/webp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// // ImagesEditsHandler just copy response body to client
|
||||
// //
|
||||
// // https://replicate.com/black-forest-labs/flux-fill-pro
|
||||
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
// c.Writer.WriteHeader(resp.StatusCode)
|
||||
// for k, v := range resp.Header {
|
||||
// c.Writer.Header().Set(k, v[0])
|
||||
// }
|
||||
|
||||
// if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||
// return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
|
||||
// return nil, nil
|
||||
// }
|
||||
|
||||
var errNextLoop = errors.New("next_loop")
|
||||
|
||||
func ImageHandler(c *gin.Context, resp *http.Response) (
|
||||
*model.ErrorWithStatusCode, *model.Usage) {
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
payload, _ := io.ReadAll(resp.Body)
|
||||
return openai.ErrorWrapper(
|
||||
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
|
||||
"bad_status_code", http.StatusInternalServerError),
|
||||
nil
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
respData := new(ImageResponse)
|
||||
if err = json.Unmarshal(respBody, respData); err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
for {
|
||||
err = func() error {
|
||||
// get task
|
||||
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
|
||||
http.MethodGet, respData.URLs.Get, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "new request")
|
||||
}
|
||||
|
||||
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
|
||||
taskResp, err := http.DefaultClient.Do(taskReq)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get task")
|
||||
}
|
||||
defer taskResp.Body.Close()
|
||||
|
||||
if taskResp.StatusCode != http.StatusOK {
|
||||
payload, _ := io.ReadAll(taskResp.Body)
|
||||
return errors.Errorf("bad status code [%d]%s",
|
||||
taskResp.StatusCode, string(payload))
|
||||
}
|
||||
|
||||
taskBody, err := io.ReadAll(taskResp.Body)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read task response")
|
||||
}
|
||||
|
||||
taskData := new(ImageResponse)
|
||||
if err = json.Unmarshal(taskBody, taskData); err != nil {
|
||||
return errors.Wrap(err, "decode task response")
|
||||
}
|
||||
|
||||
switch taskData.Status {
|
||||
case "succeeded":
|
||||
case "failed", "canceled":
|
||||
return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
|
||||
default:
|
||||
time.Sleep(time.Second * 3)
|
||||
return errNextLoop
|
||||
}
|
||||
|
||||
output, err := taskData.GetOutput()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get output")
|
||||
}
|
||||
if len(output) == 0 {
|
||||
return errors.New("response output is empty")
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var pool errgroup.Group
|
||||
respBody := &openai.ImageResponse{
|
||||
Created: taskData.CompletedAt.Unix(),
|
||||
Data: []openai.ImageData{},
|
||||
}
|
||||
|
||||
for _, imgOut := range output {
|
||||
imgOut := imgOut
|
||||
pool.Go(func() error {
|
||||
// download image
|
||||
downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
|
||||
http.MethodGet, imgOut, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "new request")
|
||||
}
|
||||
|
||||
imgResp, err := http.DefaultClient.Do(downloadReq)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "download image")
|
||||
}
|
||||
defer imgResp.Body.Close()
|
||||
|
||||
if imgResp.StatusCode != http.StatusOK {
|
||||
payload, _ := io.ReadAll(imgResp.Body)
|
||||
return errors.Errorf("bad status code [%d]%s",
|
||||
imgResp.StatusCode, string(payload))
|
||||
}
|
||||
|
||||
imgData, err := io.ReadAll(imgResp.Body)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read image")
|
||||
}
|
||||
|
||||
imgData, err = ConvertImageToPNG(imgData)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "convert image")
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
respBody.Data = append(respBody.Data, openai.ImageData{
|
||||
B64Json: fmt.Sprintf("data:image/png;base64,%s",
|
||||
base64.StdEncoding.EncodeToString(imgData)),
|
||||
})
|
||||
mu.Unlock()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := pool.Wait(); err != nil {
|
||||
if len(respBody.Data) == 0 {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, respBody)
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
if errors.Is(err, errNextLoop) {
|
||||
continue
|
||||
}
|
||||
|
||||
return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ConvertImageToPNG converts a WebP image to PNG format
|
||||
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
|
||||
// bypass if it's already a PNG image
|
||||
if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
|
||||
return webpData, nil
|
||||
}
|
||||
|
||||
// check if is jpeg, convert to png
|
||||
if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
|
||||
img, _, err := image.Decode(bytes.NewReader(webpData))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decode jpeg")
|
||||
}
|
||||
|
||||
var pngBuffer bytes.Buffer
|
||||
if err := png.Encode(&pngBuffer, img); err != nil {
|
||||
return nil, errors.Wrap(err, "encode png")
|
||||
}
|
||||
|
||||
return pngBuffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// Decode the WebP image
|
||||
img, err := webp.Decode(bytes.NewReader(webpData))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decode webp")
|
||||
}
|
||||
|
||||
// Encode the image as PNG
|
||||
var pngBuffer bytes.Buffer
|
||||
if err := png.Encode(&pngBuffer, img); err != nil {
|
||||
return nil, errors.Wrap(err, "encode png")
|
||||
}
|
||||
|
||||
return pngBuffer.Bytes(), nil
|
||||
}
|
||||
229
relay/adaptor/replicate/model.go
Normal file
229
relay/adaptor/replicate/model.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"image"
|
||||
"image/png"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type OpenaiImageEditRequest struct {
|
||||
Image *multipart.FileHeader `json:"image" form:"image" binding:"required"`
|
||||
Prompt string `json:"prompt" form:"prompt" binding:"required"`
|
||||
Mask *multipart.FileHeader `json:"mask" form:"mask" binding:"required"`
|
||||
Model string `json:"model" form:"model" binding:"required"`
|
||||
N int `json:"n" form:"n" binding:"min=0,max=10"`
|
||||
Size string `json:"size" form:"size"`
|
||||
ResponseFormat string `json:"response_format" form:"response_format"`
|
||||
}
|
||||
|
||||
// toFluxRemixRequest convert OpenAI's image edit request to Flux's remix request.
|
||||
//
|
||||
// Note that the mask formats of OpenAI and Flux are different:
|
||||
// OpenAI's mask sets the parts to be modified as transparent (0, 0, 0, 0),
|
||||
// while Flux sets the parts to be modified as black (255, 255, 255, 255),
|
||||
// so we need to convert the format here.
|
||||
//
|
||||
// Both OpenAI's Image and Mask are browser-native ImageData,
|
||||
// which need to be converted to base64 dataURI format.
|
||||
func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) {
|
||||
if r.ResponseFormat != "b64_json" {
|
||||
return nil, errors.New("response_format must be b64_json for replicate models")
|
||||
}
|
||||
|
||||
fluxReq := &InpaintingImageByFlusReplicateRequest{
|
||||
Input: FluxInpaintingInput{
|
||||
Prompt: r.Prompt,
|
||||
Seed: int(time.Now().UnixNano()),
|
||||
Steps: 30,
|
||||
Guidance: 3,
|
||||
SafetyTolerance: 5,
|
||||
PromptUnsampling: false,
|
||||
OutputFormat: "png",
|
||||
},
|
||||
}
|
||||
|
||||
imgFile, err := r.Image.Open()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "open image file")
|
||||
}
|
||||
defer imgFile.Close()
|
||||
imgData, err := io.ReadAll(imgFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read image file")
|
||||
}
|
||||
|
||||
maskFile, err := r.Mask.Open()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "open mask file")
|
||||
}
|
||||
defer maskFile.Close()
|
||||
|
||||
// Convert image to base64
|
||||
imageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgData)
|
||||
fluxReq.Input.Image = imageBase64
|
||||
|
||||
// Convert mask data to RGBA
|
||||
maskPNG, err := png.Decode(maskFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decode mask file")
|
||||
}
|
||||
|
||||
// convert mask to RGBA
|
||||
var maskRGBA *image.RGBA
|
||||
switch converted := maskPNG.(type) {
|
||||
case *image.RGBA:
|
||||
maskRGBA = converted
|
||||
default:
|
||||
// Convert to RGBA
|
||||
bounds := maskPNG.Bounds()
|
||||
maskRGBA = image.NewRGBA(bounds)
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
maskRGBA.Set(x, y, maskPNG.At(x, y))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maskData := maskRGBA.Pix
|
||||
invertedMask := make([]byte, len(maskData))
|
||||
for i := 0; i+4 <= len(maskData); i += 4 {
|
||||
// If pixel is transparent (alpha = 0), make it black (255)
|
||||
if maskData[i+3] == 0 {
|
||||
invertedMask[i] = 255 // R
|
||||
invertedMask[i+1] = 255 // G
|
||||
invertedMask[i+2] = 255 // B
|
||||
invertedMask[i+3] = 255 // A
|
||||
} else {
|
||||
// Copy original pixel
|
||||
copy(invertedMask[i:i+4], maskData[i:i+4])
|
||||
}
|
||||
}
|
||||
|
||||
// Convert inverted mask to base64 encoded png image
|
||||
invertedMaskRGBA := &image.RGBA{
|
||||
Pix: invertedMask,
|
||||
Stride: maskRGBA.Stride,
|
||||
Rect: maskRGBA.Rect,
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = png.Encode(&buf, invertedMaskRGBA)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "encode inverted mask to png")
|
||||
}
|
||||
|
||||
invertedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes())
|
||||
fluxReq.Input.Mask = invertedMaskBase64
|
||||
|
||||
return fluxReq, nil
|
||||
}
|
||||
|
||||
// DrawImageRequest draw image by fluxpro
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
|
||||
type DrawImageRequest struct {
|
||||
Input ImageInput `json:"input"`
|
||||
}
|
||||
|
||||
// ImageInput is input of DrawImageByFluxProRequest
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
|
||||
type ImageInput struct {
|
||||
Steps int `json:"steps" binding:"required,min=1"`
|
||||
Prompt string `json:"prompt" binding:"required,min=5"`
|
||||
ImagePrompt string `json:"image_prompt"`
|
||||
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
|
||||
Interval int `json:"interval" binding:"required,min=1,max=4"`
|
||||
AspectRatio string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
|
||||
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
|
||||
Seed int `json:"seed"`
|
||||
NImages int `json:"n_images" binding:"required,min=1,max=8"`
|
||||
Width int `json:"width" binding:"required,min=256,max=1440"`
|
||||
Height int `json:"height" binding:"required,min=256,max=1440"`
|
||||
}
|
||||
|
||||
// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
|
||||
type InpaintingImageByFlusReplicateRequest struct {
|
||||
Input FluxInpaintingInput `json:"input"`
|
||||
}
|
||||
|
||||
// FluxInpaintingInput is input of DrawImageByFluxProRequest
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
|
||||
type FluxInpaintingInput struct {
|
||||
Mask string `json:"mask" binding:"required"`
|
||||
Image string `json:"image" binding:"required"`
|
||||
Seed int `json:"seed"`
|
||||
Steps int `json:"steps" binding:"required,min=1"`
|
||||
Prompt string `json:"prompt" binding:"required,min=5"`
|
||||
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
|
||||
OutputFormat string `json:"output_format"`
|
||||
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
|
||||
PromptUnsampling bool `json:"prompt_unsampling"`
|
||||
}
|
||||
|
||||
// ImageResponse is response of DrawImageByFluxProRequest
|
||||
//
|
||||
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
|
||||
type ImageResponse struct {
|
||||
CompletedAt time.Time `json:"completed_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DataRemoved bool `json:"data_removed"`
|
||||
Error string `json:"error"`
|
||||
ID string `json:"id"`
|
||||
Input DrawImageRequest `json:"input"`
|
||||
Logs string `json:"logs"`
|
||||
Metrics FluxMetrics `json:"metrics"`
|
||||
// Output could be `string` or `[]string`
|
||||
Output any `json:"output"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
Status string `json:"status"`
|
||||
URLs FluxURLs `json:"urls"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
func (r *ImageResponse) GetOutput() ([]string, error) {
|
||||
switch v := r.Output.(type) {
|
||||
case string:
|
||||
return []string{v}, nil
|
||||
case []string:
|
||||
return v, nil
|
||||
case nil:
|
||||
return nil, nil
|
||||
case []interface{}:
|
||||
// convert []interface{} to []string
|
||||
ret := make([]string, len(v))
|
||||
for idx, vv := range v {
|
||||
if vvv, ok := vv.(string); ok {
|
||||
ret[idx] = vvv
|
||||
} else {
|
||||
return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
default:
|
||||
return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// FluxMetrics is metrics of ImageResponse
|
||||
type FluxMetrics struct {
|
||||
ImageCount int `json:"image_count"`
|
||||
PredictTime float64 `json:"predict_time"`
|
||||
TotalTime float64 `json:"total_time"`
|
||||
}
|
||||
|
||||
// FluxURLs is urls of ImageResponse
|
||||
type FluxURLs struct {
|
||||
Get string `json:"get"`
|
||||
Cancel string `json:"cancel"`
|
||||
}
|
||||
106
relay/adaptor/replicate/model_test.go
Normal file
106
relay/adaptor/replicate/model_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package replicate
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/draw"
|
||||
"image/png"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type nopCloser struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (n nopCloser) Close() error { return nil }
|
||||
|
||||
// Custom FileHeader to override Open method
|
||||
type customFileHeader struct {
|
||||
*multipart.FileHeader
|
||||
openFunc func() (multipart.File, error)
|
||||
}
|
||||
|
||||
func (c *customFileHeader) Open() (multipart.File, error) {
|
||||
return c.openFunc()
|
||||
}
|
||||
|
||||
func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) {
|
||||
// Create a simple image for testing
|
||||
img := image.NewRGBA(image.Rect(0, 0, 10, 10))
|
||||
draw.Draw(img, img.Bounds(), &image.Uniform{C: image.Black}, image.Point{}, draw.Src)
|
||||
var imgBuf bytes.Buffer
|
||||
err := png.Encode(&imgBuf, img)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a simple mask for testing
|
||||
mask := image.NewRGBA(image.Rect(0, 0, 10, 10))
|
||||
draw.Draw(mask, mask.Bounds(), &image.Uniform{C: image.Black}, image.Point{}, draw.Src)
|
||||
var maskBuf bytes.Buffer
|
||||
err = png.Encode(&maskBuf, mask)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a multipart.FileHeader from the image and mask bytes
|
||||
imgFileHeader, err := createFileHeader("image", "test.png", imgBuf.Bytes())
|
||||
require.NoError(t, err)
|
||||
maskFileHeader, err := createFileHeader("mask", "test.png", maskBuf.Bytes())
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &OpenaiImageEditRequest{
|
||||
Image: imgFileHeader,
|
||||
Mask: maskFileHeader,
|
||||
Prompt: "Test prompt",
|
||||
Model: "test-model",
|
||||
ResponseFormat: "b64_json",
|
||||
}
|
||||
|
||||
fluxReq, err := req.toFluxRemixRequest()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, fluxReq)
|
||||
require.Equal(t, req.Prompt, fluxReq.Input.Prompt)
|
||||
require.NotEmpty(t, fluxReq.Input.Image)
|
||||
require.NotEmpty(t, fluxReq.Input.Mask)
|
||||
}
|
||||
|
||||
// createFileHeader creates a multipart.FileHeader from file bytes
|
||||
func createFileHeader(fieldname, filename string, fileBytes []byte) (*multipart.FileHeader, error) {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
// Create a form file field
|
||||
part, err := writer.CreateFormFile(fieldname, filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write the file bytes to the form file field
|
||||
_, err = part.Write(fileBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Close the writer to finalize the form
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the multipart form
|
||||
req := &http.Request{
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(body),
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
err = req.ParseMultipartForm(int64(body.Len()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Retrieve the file header from the parsed form
|
||||
fileHeader := req.MultipartForm.File[fieldname][0]
|
||||
return fileHeader, nil
|
||||
}
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
var ModelList = []string{
|
||||
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002",
|
||||
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002",
|
||||
}
|
||||
|
||||
type Adaptor struct {
|
||||
|
||||
@@ -19,6 +19,7 @@ const (
|
||||
DeepL
|
||||
VertexAI
|
||||
Proxy
|
||||
Replicate
|
||||
|
||||
Dummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
||||
@@ -211,6 +211,31 @@ var ModelRatio = map[string]float64{
|
||||
"deepl-ja": 25.0 / 1000 * USD,
|
||||
// https://console.x.ai/
|
||||
"grok-beta": 5.0 / 1000 * USD,
|
||||
// replicate charges based on the number of generated images
|
||||
// https://replicate.com/pricing
|
||||
"black-forest-labs/flux-1.1-pro": 0.04 * USD,
|
||||
"black-forest-labs/flux-1.1-pro-ultra": 0.06 * USD,
|
||||
"black-forest-labs/flux-canny-dev": 0.025 * USD,
|
||||
"black-forest-labs/flux-canny-pro": 0.05 * USD,
|
||||
"black-forest-labs/flux-depth-dev": 0.025 * USD,
|
||||
"black-forest-labs/flux-depth-pro": 0.05 * USD,
|
||||
"black-forest-labs/flux-dev": 0.025 * USD,
|
||||
"black-forest-labs/flux-dev-lora": 0.032 * USD,
|
||||
"black-forest-labs/flux-fill-dev": 0.04 * USD,
|
||||
"black-forest-labs/flux-fill-pro": 0.05 * USD,
|
||||
"black-forest-labs/flux-pro": 0.055 * USD,
|
||||
"black-forest-labs/flux-redux-dev": 0.025 * USD,
|
||||
"black-forest-labs/flux-redux-schnell": 0.003 * USD,
|
||||
"black-forest-labs/flux-schnell": 0.003 * USD,
|
||||
"black-forest-labs/flux-schnell-lora": 0.02 * USD,
|
||||
"ideogram-ai/ideogram-v2": 0.08 * USD,
|
||||
"ideogram-ai/ideogram-v2-turbo": 0.05 * USD,
|
||||
"recraft-ai/recraft-v3": 0.04 * USD,
|
||||
"recraft-ai/recraft-v3-svg": 0.08 * USD,
|
||||
"stability-ai/stable-diffusion-3": 0.035 * USD,
|
||||
"stability-ai/stable-diffusion-3.5-large": 0.065 * USD,
|
||||
"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
|
||||
"stability-ai/stable-diffusion-3.5-medium": 0.035 * USD,
|
||||
}
|
||||
|
||||
var CompletionRatio = map[string]float64{
|
||||
|
||||
@@ -47,5 +47,6 @@ const (
|
||||
Proxy
|
||||
SiliconFlow
|
||||
XAI
|
||||
Replicate
|
||||
Dummy
|
||||
)
|
||||
|
||||
@@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
|
||||
apiType = apitype.DeepL
|
||||
case VertextAI:
|
||||
apiType = apitype.VertexAI
|
||||
case Replicate:
|
||||
apiType = apitype.Replicate
|
||||
case Proxy:
|
||||
apiType = apitype.Proxy
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ var ChannelBaseURLs = []string{
|
||||
"", // 43
|
||||
"https://api.siliconflow.cn", // 44
|
||||
"https://api.x.ai", // 45
|
||||
"https://api.replicate.com/v1/models/", // 46
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -4,18 +4,20 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/replicate"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
@@ -26,7 +28,7 @@ func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, e
|
||||
imageRequest := &relaymodel.ImageRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
@@ -134,7 +136,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
c.Set("response_format", imageRequest.ResponseFormat)
|
||||
|
||||
var requestBody io.Reader
|
||||
if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
|
||||
if strings.ToLower(c.GetString(ctxkey.ContentType)) == "application/json" &&
|
||||
isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
|
||||
jsonStr, err := json.Marshal(imageRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
|
||||
@@ -150,12 +153,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
adaptor.Init(meta)
|
||||
|
||||
// these adaptors need to convert the request
|
||||
switch meta.ChannelType {
|
||||
case channeltype.Ali:
|
||||
fallthrough
|
||||
case channeltype.Baidu:
|
||||
fallthrough
|
||||
case channeltype.Zhipu:
|
||||
case channeltype.Zhipu,
|
||||
channeltype.Ali,
|
||||
channeltype.Baidu:
|
||||
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
||||
@@ -165,6 +167,16 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case channeltype.Replicate:
|
||||
finalRequest, err := replicate.ConvertImageRequest(c, imageRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
jsonStr, err := json.Marshal(finalRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
}
|
||||
|
||||
modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
|
||||
@@ -172,7 +184,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
ratio := modelRatio * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
||||
|
||||
quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
||||
var quota int64
|
||||
switch meta.ChannelType {
|
||||
case channeltype.Replicate:
|
||||
// replicate always return 1 image
|
||||
quota = int64(ratio * imageCostRatio * 1000)
|
||||
default:
|
||||
quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
||||
}
|
||||
|
||||
if userQuota-quota < 0 {
|
||||
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
@@ -186,7 +205,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
if resp != nil &&
|
||||
resp.StatusCode != http.StatusCreated && // replicate returns 201
|
||||
resp.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package model
|
||||
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Model string `json:"model" form:"model"`
|
||||
Prompt string `json:"prompt" form:"prompt" binding:"required"`
|
||||
N int `json:"n,omitempty" form:"n"`
|
||||
Size string `json:"size,omitempty" form:"size"`
|
||||
Quality string `json:"quality,omitempty" form:"quality"`
|
||||
ResponseFormat string `json:"response_format,omitempty" form:"response_format"`
|
||||
Style string `json:"style,omitempty" form:"style"`
|
||||
User string `json:"user,omitempty" form:"user"`
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ const (
|
||||
AudioSpeech
|
||||
AudioTranscription
|
||||
AudioTranslation
|
||||
ImagesEdits
|
||||
// Proxy is a special relay mode for proxying requests to custom upstream
|
||||
Proxy
|
||||
)
|
||||
|
||||
@@ -24,8 +24,11 @@ func GetByPath(path string) int {
|
||||
relayMode = AudioTranscription
|
||||
} else if strings.HasPrefix(path, "/v1/audio/translations") {
|
||||
relayMode = AudioTranslation
|
||||
} else if strings.HasPrefix(path, "/v1/images/edits") {
|
||||
relayMode = ImagesEdits
|
||||
} else if strings.HasPrefix(path, "/v1/oneapi/proxy") {
|
||||
relayMode = Proxy
|
||||
}
|
||||
|
||||
return relayMode
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayV1Router.POST("/chat/completions", controller.Relay)
|
||||
relayV1Router.POST("/edits", controller.Relay)
|
||||
relayV1Router.POST("/images/generations", controller.Relay)
|
||||
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/images/edits", controller.Relay)
|
||||
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/embeddings", controller.Relay)
|
||||
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
||||
|
||||
@@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
|
||||
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
||||
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
||||
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
||||
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
|
||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||
|
||||
@@ -185,6 +185,12 @@ export const CHANNEL_OPTIONS = {
|
||||
value: 45,
|
||||
color: 'primary'
|
||||
},
|
||||
45: {
|
||||
key: 46,
|
||||
text: 'Replicate',
|
||||
value: 46,
|
||||
color: 'primary'
|
||||
},
|
||||
41: {
|
||||
key: 41,
|
||||
text: 'Novita',
|
||||
|
||||
@@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
|
||||
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
||||
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
||||
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
||||
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
|
||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||
|
||||
Reference in New Issue
Block a user