mirror of https://github.com/songquanpeng/one-api.git

feat: support replicate chat models (#1989)

* feat: add Replicate adaptor and integrate into channel and API types
* feat: support llm chat on replicate
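
The adaptor below rejects non-stream chat requests, so a Replicate-backed call through one-api must set stream to true. A minimal client-side sketch (not part of the commit; base URL, port, model, and token are placeholders for your own deployment):

package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// Placeholders: point these at your own one-api deployment.
	const baseURL = "http://localhost:3000"
	const token = "sk-your-one-api-token"

	// "stream": true is required; ConvertRequest in the Replicate adaptor
	// returns an error for non-stream requests.
	body := `{
		"model": "meta/meta-llama-3-8b-instruct",
		"stream": true,
		"messages": [{"role": "user", "content": "Hello!"}]
	}`

	req, err := http.NewRequest(http.MethodPost, baseURL+"/v1/chat/completions", strings.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The relay responds with OpenAI-style SSE chunks.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		fmt.Println(scanner.Text())
	}
}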
.github/workflows/ci.yml (vendored, 8 lines changed)

@@ -1,13 +1,13 @@
 name: CI

 # This setup assumes that you run the unit tests with code coverage in the same
-# workflow that will also print the coverage report as comment to the pull request. 
+# workflow that will also print the coverage report as comment to the pull request.
 # Therefore, you need to trigger this workflow when a pull request is (re)opened or
 # when new code is pushed to the branch of the pull request. In addition, you also
-# need to trigger this workflow when new code is pushed to the main branch because 
+# need to trigger this workflow when new code is pushed to the main branch because
 # we need to upload the code coverage results as artifact for the main branch as
 # well since it will be the baseline code coverage.
-# 
+#
 # We do not want to trigger the workflow for pushes to *any* branch because this
 # would trigger our jobs twice on pull requests (once from "push" event and once
 # from "pull_request->synchronize")
@@ -31,7 +31,7 @@ jobs:
         with:
           go-version: ^1.22

-      # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a 
+      # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
       # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
       # in the next step as well as the next job.
       - name: Test

.gitignore (vendored, 3 lines changed)

@@ -9,4 +9,5 @@ logs
 data
 /web/node_modules
 cmd.md
-.env
+.env
+/one-api
@@ -3,9 +3,10 @@ package render
 import (
 	"encoding/json"
 	"fmt"
+	"strings"
+
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
-	"strings"
 )

 func StringData(c *gin.Context, str string) {
@@ -34,7 +34,7 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
 		strings.Contains(lowerMessage, "credit") ||
 		strings.Contains(lowerMessage, "balance") ||
 		strings.Contains(lowerMessage, "permission denied") ||
-  	strings.Contains(lowerMessage, "organization has been restricted") || // groq
+		strings.Contains(lowerMessage, "organization has been restricted") || // groq
 		strings.Contains(lowerMessage, "已欠费") {
 		return true
 	}
@@ -16,6 +16,7 @@ import (
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/adaptor/palm"
 	"github.com/songquanpeng/one-api/relay/adaptor/proxy"
+	"github.com/songquanpeng/one-api/relay/adaptor/replicate"
 	"github.com/songquanpeng/one-api/relay/adaptor/tencent"
 	"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
 	"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
@@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
 		return &vertexai.Adaptor{}
 	case apitype.Proxy:
 		return &proxy.Adaptor{}
+	case apitype.Replicate:
+		return &replicate.Adaptor{}
 	}
 	return nil
 }
@@ -31,8 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 			TopP:             request.TopP,
 			FrequencyPenalty: request.FrequencyPenalty,
 			PresencePenalty:  request.PresencePenalty,
-			NumPredict:  	  request.MaxTokens,
-			NumCtx:  	  request.NumCtx,
+			NumPredict:       request.MaxTokens,
+			NumCtx:           request.NumCtx,
 		},
 		Stream: request.Stream,
 	}
@@ -122,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 	for scanner.Scan() {
 		data := scanner.Text()
 		if strings.HasPrefix(data, "}") {
-		    data = strings.TrimPrefix(data, "}") + "}"
+			data = strings.TrimPrefix(data, "}") + "}"
 		}

 		var ollamaResponse ChatResponse
@@ -2,15 +2,16 @@ package openai

 import (
 	"fmt"
+	"strings"
+
 	"github.com/songquanpeng/one-api/relay/channeltype"
 	"github.com/songquanpeng/one-api/relay/model"
-	"strings"
 )

-func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
+func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
 	usage := &model.Usage{}
 	usage.PromptTokens = promptTokens
-	usage.CompletionTokens = CountTokenText(responseText, modeName)
+	usage.CompletionTokens = CountTokenText(responseText, modelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage
 }
@@ -1,8 +1,16 @@
 package openai

-import "github.com/songquanpeng/one-api/relay/model"
+import (
+	"context"
+	"fmt"
+
+	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/relay/model"
+)

 func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
+	logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
+
 	Error := model.Error{
 		Message: err.Error(),
 		Type:    "one_api_error",

relay/adaptor/replicate/adaptor.go (new file, 136 lines)

@@ -0,0 +1,136 @@
package replicate

import (
	"fmt"
	"io"
	"net/http"
	"slices"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/adaptor"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

type Adaptor struct {
	meta *meta.Meta
}

// ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
	return DrawImageRequest{
		Input: ImageInput{
			Steps:           25,
			Prompt:          request.Prompt,
			Guidance:        3,
			Seed:            int(time.Now().UnixNano()),
			SafetyTolerance: 5,
			NImages:         1, // replicate will always return 1 image
			Width:           1440,
			Height:          1440,
			AspectRatio:     "1:1",
		},
	}, nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
	if !request.Stream {
		// TODO: support non-stream mode
		return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
	}

	// Build the prompt from OpenAI messages
	var promptBuilder strings.Builder
	for _, message := range request.Messages {
		switch msgCnt := message.Content.(type) {
		case string:
			promptBuilder.WriteString(message.Role)
			promptBuilder.WriteString(": ")
			promptBuilder.WriteString(msgCnt)
			promptBuilder.WriteString("\n")
		default:
		}
	}

	replicateRequest := ReplicateChatRequest{
		Input: ChatInput{
			Prompt:           promptBuilder.String(),
			MaxTokens:        request.MaxTokens,
			Temperature:      1.0,
			TopP:             1.0,
			PresencePenalty:  0.0,
			FrequencyPenalty: 0.0,
		},
	}

	// Map optional fields
	if request.Temperature != nil {
		replicateRequest.Input.Temperature = *request.Temperature
	}
	if request.TopP != nil {
		replicateRequest.Input.TopP = *request.TopP
	}
	if request.PresencePenalty != nil {
		replicateRequest.Input.PresencePenalty = *request.PresencePenalty
	}
	if request.FrequencyPenalty != nil {
		replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
	}
	if request.MaxTokens > 0 {
		replicateRequest.Input.MaxTokens = request.MaxTokens
	} else if request.MaxTokens == 0 {
		replicateRequest.Input.MaxTokens = 500
	}

	return replicateRequest, nil
}

func (a *Adaptor) Init(meta *meta.Meta) {
	a.meta = meta
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	if !slices.Contains(ModelList, meta.OriginModelName) {
		return "", errors.Errorf("model %s not supported", meta.OriginModelName)
	}

	return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
	adaptor.SetupCommonRequestHeader(c, req, meta)
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
	return nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	logger.Info(c, "send request to replicate")
	return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
	switch meta.Mode {
	case relaymode.ImagesGenerations:
		err, usage = ImageHandler(c, resp)
	case relaymode.ChatCompletions:
		err, usage = ChatHandler(c, resp)
	default:
		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
	}

	return
}

func (a *Adaptor) GetModelList() []string {
	return ModelList
}

func (a *Adaptor) GetChannelName() string {
	return "replicate"
}
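
For reference, ConvertRequest above flattens the OpenAI-style message list into one role-prefixed prompt string. The sketch below (not part of the commit) marshals trimmed local copies of the request types from model.go further down, to show roughly what JSON body the adaptor posts; all values are illustrative.

package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the request types from model.go, trimmed to the fields
// ConvertRequest populates; for illustration only.
type chatInput struct {
	Prompt           string  `json:"prompt"`
	MaxTokens        int     `json:"max_tokens"`
	Temperature      float64 `json:"temperature"`
	TopP             float64 `json:"top_p"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
}

type chatRequest struct {
	Input chatInput `json:"input"`
}

func main() {
	// "system: ...\nuser: ..." is how ConvertRequest flattens the message list.
	body, _ := json.MarshalIndent(chatRequest{Input: chatInput{
		Prompt:      "system: You are a helpful assistant.\nuser: Hello!\n",
		MaxTokens:   500, // default applied when the client omits max_tokens
		Temperature: 1.0,
		TopP:        1.0,
	}}, "", "  ")
	fmt.Println(string(body))
}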
							
								
								
									
relay/adaptor/replicate/chat.go (new file, 191 lines)

@@ -0,0 +1,191 @@
package replicate

import (
	"bufio"
	"encoding/json"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)

func ChatHandler(c *gin.Context, resp *http.Response) (
	srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
	if resp.StatusCode != http.StatusCreated {
		payload, _ := io.ReadAll(resp.Body)
		return openai.ErrorWrapper(
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
				"bad_status_code", http.StatusInternalServerError),
			nil
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	respData := new(ChatResponse)
	if err = json.Unmarshal(respBody, respData); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	for {
		err = func() error {
			// get task
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
				http.MethodGet, respData.URLs.Get, nil)
			if err != nil {
				return errors.Wrap(err, "new request")
			}

			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
			taskResp, err := http.DefaultClient.Do(taskReq)
			if err != nil {
				return errors.Wrap(err, "get task")
			}
			defer taskResp.Body.Close()

			if taskResp.StatusCode != http.StatusOK {
				payload, _ := io.ReadAll(taskResp.Body)
				return errors.Errorf("bad status code [%d]%s",
					taskResp.StatusCode, string(payload))
			}

			taskBody, err := io.ReadAll(taskResp.Body)
			if err != nil {
				return errors.Wrap(err, "read task response")
			}

			taskData := new(ChatResponse)
			if err = json.Unmarshal(taskBody, taskData); err != nil {
				return errors.Wrap(err, "decode task response")
			}

			switch taskData.Status {
			case "succeeded":
			case "failed", "canceled":
				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
			default:
				time.Sleep(time.Second * 3)
				return errNextLoop
			}

			if taskData.URLs.Stream == "" {
				return errors.New("stream url is empty")
			}

			// request stream url
			responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
			if err != nil {
				return errors.Wrap(err, "chat stream handler")
			}

			ctxMeta := meta.GetByContext(c)
			usage = openai.ResponseText2Usage(responseText,
				ctxMeta.ActualModelName, ctxMeta.PromptTokens)
			return nil
		}()
		if err != nil {
			if errors.Is(err, errNextLoop) {
				continue
			}

			return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
		}

		break
	}

	return nil, usage
}

const (
	eventPrefix = "event: "
	dataPrefix  = "data: "
	done        = "[DONE]"
)

func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
	// request stream endpoint
	streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
	if err != nil {
		return "", errors.Wrap(err, "new request to stream")
	}

	streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
	streamReq.Header.Set("Accept", "text/event-stream")
	streamReq.Header.Set("Cache-Control", "no-store")

	resp, err := http.DefaultClient.Do(streamReq)
	if err != nil {
		return "", errors.Wrap(err, "do request to stream")
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		payload, _ := io.ReadAll(resp.Body)
		return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)
	doneRendered := false
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		// Handle comments starting with ':'
		if strings.HasPrefix(line, ":") {
			continue
		}

		// Parse SSE fields
		if strings.HasPrefix(line, eventPrefix) {
			event := strings.TrimSpace(line[len(eventPrefix):])
			var data string
			// Read the following lines to get data and id
			for scanner.Scan() {
				nextLine := scanner.Text()
				if nextLine == "" {
					break
				}
				if strings.HasPrefix(nextLine, dataPrefix) {
					data = nextLine[len(dataPrefix):]
				} else if strings.HasPrefix(nextLine, "id:") {
					// id = strings.TrimSpace(nextLine[len("id:"):])
				}
			}

			if event == "output" {
				render.StringData(c, data)
				responseText += data
			} else if event == "done" {
				render.Done(c)
				doneRendered = true
				break
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return "", errors.Wrap(err, "scan stream")
	}

	if !doneRendered {
		render.Done(c)
	}

	return responseText, nil
}
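
chatStreamHandler above walks Replicate's SSE framing: it concatenates the data of "output" events and stops at "done". A self-contained sketch of that parsing loop against an invented transcript (not a captured response, and not part of the commit):

package main

import (
	"bufio"
	"fmt"
	"strings"
)

// Illustrative SSE transcript; the "output" and "done" event names are the
// ones the handler above reacts to.
const sampleStream = `event: output
data: Hello

event: output
data:  world

event: done
data: {}
`

func main() {
	scanner := bufio.NewScanner(strings.NewReader(sampleStream))
	var text string
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "event: ") {
			continue
		}
		event := strings.TrimSpace(line[len("event: "):])
		var data string
		// Read this event's remaining fields up to the blank separator line.
		for scanner.Scan() {
			next := scanner.Text()
			if next == "" {
				break
			}
			if strings.HasPrefix(next, "data: ") {
				data = next[len("data: "):]
			}
		}
		if event == "output" {
			text += data
		} else if event == "done" {
			break
		}
	}
	fmt.Println(text) // "Hello world"
}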
							
								
								
									
relay/adaptor/replicate/constant.go (new file, 58 lines)

@@ -0,0 +1,58 @@
package replicate

// ModelList is a list of models that can be used with Replicate.
//
// https://replicate.com/pricing
var ModelList = []string{
	// -------------------------------------
	// image model
	// -------------------------------------
	"black-forest-labs/flux-1.1-pro",
	"black-forest-labs/flux-1.1-pro-ultra",
	"black-forest-labs/flux-canny-dev",
	"black-forest-labs/flux-canny-pro",
	"black-forest-labs/flux-depth-dev",
	"black-forest-labs/flux-depth-pro",
	"black-forest-labs/flux-dev",
	"black-forest-labs/flux-dev-lora",
	"black-forest-labs/flux-fill-dev",
	"black-forest-labs/flux-fill-pro",
	"black-forest-labs/flux-pro",
	"black-forest-labs/flux-redux-dev",
	"black-forest-labs/flux-redux-schnell",
	"black-forest-labs/flux-schnell",
	"black-forest-labs/flux-schnell-lora",
	"ideogram-ai/ideogram-v2",
	"ideogram-ai/ideogram-v2-turbo",
	"recraft-ai/recraft-v3",
	"recraft-ai/recraft-v3-svg",
	"stability-ai/stable-diffusion-3",
	"stability-ai/stable-diffusion-3.5-large",
	"stability-ai/stable-diffusion-3.5-large-turbo",
	"stability-ai/stable-diffusion-3.5-medium",
	// -------------------------------------
	// language model
	// -------------------------------------
	"ibm-granite/granite-20b-code-instruct-8k",
	"ibm-granite/granite-3.0-2b-instruct",
	"ibm-granite/granite-3.0-8b-instruct",
	"ibm-granite/granite-8b-code-instruct-128k",
	"meta/llama-2-13b",
	"meta/llama-2-13b-chat",
	"meta/llama-2-70b",
	"meta/llama-2-70b-chat",
	"meta/llama-2-7b",
	"meta/llama-2-7b-chat",
	"meta/meta-llama-3.1-405b-instruct",
	"meta/meta-llama-3-70b",
	"meta/meta-llama-3-70b-instruct",
	"meta/meta-llama-3-8b",
	"meta/meta-llama-3-8b-instruct",
	"mistralai/mistral-7b-instruct-v0.2",
	"mistralai/mistral-7b-v0.1",
	"mistralai/mixtral-8x7b-instruct-v0.1",
	// -------------------------------------
	// video model
	// -------------------------------------
	// "minimax/video-01",  // TODO: implement the adaptor
}
							
								
								
									
relay/adaptor/replicate/image.go (new file, 222 lines)

@@ -0,0 +1,222 @@
package replicate

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"image"
	"image/png"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"golang.org/x/image/webp"
	"golang.org/x/sync/errgroup"
)

// ImagesEditsHandler just copy response body to client
//
// https://replicate.com/black-forest-labs/flux-fill-pro
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
// 	c.Writer.WriteHeader(resp.StatusCode)
// 	for k, v := range resp.Header {
// 		c.Writer.Header().Set(k, v[0])
// 	}

// 	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
// 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
// 	}
// 	defer resp.Body.Close()

// 	return nil, nil
// }

var errNextLoop = errors.New("next_loop")

func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	if resp.StatusCode != http.StatusCreated {
		payload, _ := io.ReadAll(resp.Body)
		return openai.ErrorWrapper(
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
				"bad_status_code", http.StatusInternalServerError),
			nil
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	respData := new(ImageResponse)
	if err = json.Unmarshal(respBody, respData); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	for {
		err = func() error {
			// get task
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
				http.MethodGet, respData.URLs.Get, nil)
			if err != nil {
				return errors.Wrap(err, "new request")
			}

			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
			taskResp, err := http.DefaultClient.Do(taskReq)
			if err != nil {
				return errors.Wrap(err, "get task")
			}
			defer taskResp.Body.Close()

			if taskResp.StatusCode != http.StatusOK {
				payload, _ := io.ReadAll(taskResp.Body)
				return errors.Errorf("bad status code [%d]%s",
					taskResp.StatusCode, string(payload))
			}

			taskBody, err := io.ReadAll(taskResp.Body)
			if err != nil {
				return errors.Wrap(err, "read task response")
			}

			taskData := new(ImageResponse)
			if err = json.Unmarshal(taskBody, taskData); err != nil {
				return errors.Wrap(err, "decode task response")
			}

			switch taskData.Status {
			case "succeeded":
			case "failed", "canceled":
				return errors.Errorf("task failed: %s", taskData.Status)
			default:
				time.Sleep(time.Second * 3)
				return errNextLoop
			}

			output, err := taskData.GetOutput()
			if err != nil {
				return errors.Wrap(err, "get output")
			}
			if len(output) == 0 {
				return errors.New("response output is empty")
			}

			var mu sync.Mutex
			var pool errgroup.Group
			respBody := &openai.ImageResponse{
				Created: taskData.CompletedAt.Unix(),
				Data:    []openai.ImageData{},
			}

			for _, imgOut := range output {
				imgOut := imgOut
				pool.Go(func() error {
					// download image
					downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
						http.MethodGet, imgOut, nil)
					if err != nil {
						return errors.Wrap(err, "new request")
					}

					imgResp, err := http.DefaultClient.Do(downloadReq)
					if err != nil {
						return errors.Wrap(err, "download image")
					}
					defer imgResp.Body.Close()

					if imgResp.StatusCode != http.StatusOK {
						payload, _ := io.ReadAll(imgResp.Body)
						return errors.Errorf("bad status code [%d]%s",
							imgResp.StatusCode, string(payload))
					}

					imgData, err := io.ReadAll(imgResp.Body)
					if err != nil {
						return errors.Wrap(err, "read image")
					}

					imgData, err = ConvertImageToPNG(imgData)
					if err != nil {
						return errors.Wrap(err, "convert image")
					}

					mu.Lock()
					respBody.Data = append(respBody.Data, openai.ImageData{
						B64Json: fmt.Sprintf("data:image/png;base64,%s",
							base64.StdEncoding.EncodeToString(imgData)),
					})
					mu.Unlock()

					return nil
				})
			}

			if err := pool.Wait(); err != nil {
				if len(respBody.Data) == 0 {
					return errors.WithStack(err)
				}

				logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
			}

			c.JSON(http.StatusOK, respBody)
			return nil
		}()
		if err != nil {
			if errors.Is(err, errNextLoop) {
				continue
			}

			return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
		}

		break
	}

	return nil, nil
}

// ConvertImageToPNG converts a WebP image to PNG format
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
	// bypass if it's already a PNG image
	if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
		return webpData, nil
	}

	// check if is jpeg, convert to png
	if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
		img, _, err := image.Decode(bytes.NewReader(webpData))
		if err != nil {
			return nil, errors.Wrap(err, "decode jpeg")
		}

		var pngBuffer bytes.Buffer
		if err := png.Encode(&pngBuffer, img); err != nil {
			return nil, errors.Wrap(err, "encode png")
		}

		return pngBuffer.Bytes(), nil
	}

	// Decode the WebP image
	img, err := webp.Decode(bytes.NewReader(webpData))
	if err != nil {
		return nil, errors.Wrap(err, "decode webp")
	}

	// Encode the image as PNG
	var pngBuffer bytes.Buffer
	if err := png.Encode(&pngBuffer, img); err != nil {
		return nil, errors.Wrap(err, "encode png")
	}

	return pngBuffer.Bytes(), nil
}
							
								
								
									
relay/adaptor/replicate/model.go (new file, 159 lines)

@@ -0,0 +1,159 @@
package replicate

import (
	"time"

	"github.com/pkg/errors"
)

// DrawImageRequest draw image by fluxpro
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type DrawImageRequest struct {
	Input ImageInput `json:"input"`
}

// ImageInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
type ImageInput struct {
	Steps           int    `json:"steps" binding:"required,min=1"`
	Prompt          string `json:"prompt" binding:"required,min=5"`
	ImagePrompt     string `json:"image_prompt"`
	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"`
	Interval        int    `json:"interval" binding:"required,min=1,max=4"`
	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	Seed            int    `json:"seed"`
	NImages         int    `json:"n_images" binding:"required,min=1,max=8"`
	Width           int    `json:"width" binding:"required,min=256,max=1440"`
	Height          int    `json:"height" binding:"required,min=256,max=1440"`
}

// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type InpaintingImageByFlusReplicateRequest struct {
	Input FluxInpaintingInput `json:"input"`
}

// FluxInpaintingInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type FluxInpaintingInput struct {
	Mask             string `json:"mask" binding:"required"`
	Image            string `json:"image" binding:"required"`
	Seed             int    `json:"seed"`
	Steps            int    `json:"steps" binding:"required,min=1"`
	Prompt           string `json:"prompt" binding:"required,min=5"`
	Guidance         int    `json:"guidance" binding:"required,min=2,max=5"`
	OutputFormat     string `json:"output_format"`
	SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	PromptUnsampling bool   `json:"prompt_unsampling"`
}

// ImageResponse is response of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type ImageResponse struct {
	CompletedAt time.Time        `json:"completed_at"`
	CreatedAt   time.Time        `json:"created_at"`
	DataRemoved bool             `json:"data_removed"`
	Error       string           `json:"error"`
	ID          string           `json:"id"`
	Input       DrawImageRequest `json:"input"`
	Logs        string           `json:"logs"`
	Metrics     FluxMetrics      `json:"metrics"`
	// Output could be `string` or `[]string`
	Output    any       `json:"output"`
	StartedAt time.Time `json:"started_at"`
	Status    string    `json:"status"`
	URLs      FluxURLs  `json:"urls"`
	Version   string    `json:"version"`
}

func (r *ImageResponse) GetOutput() ([]string, error) {
	switch v := r.Output.(type) {
	case string:
		return []string{v}, nil
	case []string:
		return v, nil
	case nil:
		return nil, nil
	case []interface{}:
		// convert []interface{} to []string
		ret := make([]string, len(v))
		for idx, vv := range v {
			if vvv, ok := vv.(string); ok {
				ret[idx] = vvv
			} else {
				return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
			}
		}

		return ret, nil
	default:
		return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
	}
}

// FluxMetrics is metrics of ImageResponse
type FluxMetrics struct {
	ImageCount  int     `json:"image_count"`
	PredictTime float64 `json:"predict_time"`
	TotalTime   float64 `json:"total_time"`
}

// FluxURLs is urls of ImageResponse
type FluxURLs struct {
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}

type ReplicateChatRequest struct {
	Input ChatInput `json:"input" form:"input" binding:"required"`
}

// ChatInput is input of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
type ChatInput struct {
	TopK             int     `json:"top_k"`
	TopP             float64 `json:"top_p"`
	Prompt           string  `json:"prompt"`
	MaxTokens        int     `json:"max_tokens"`
	MinTokens        int     `json:"min_tokens"`
	Temperature      float64 `json:"temperature"`
	SystemPrompt     string  `json:"system_prompt"`
	StopSequences    string  `json:"stop_sequences"`
	PromptTemplate   string  `json:"prompt_template"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
}

// ChatResponse is response of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
type ChatResponse struct {
	CompletedAt time.Time   `json:"completed_at"`
	CreatedAt   time.Time   `json:"created_at"`
	DataRemoved bool        `json:"data_removed"`
	Error       string      `json:"error"`
	ID          string      `json:"id"`
	Input       ChatInput   `json:"input"`
	Logs        string      `json:"logs"`
	Metrics     FluxMetrics `json:"metrics"`
	// Output could be `string` or `[]string`
	Output    []string        `json:"output"`
	StartedAt time.Time       `json:"started_at"`
	Status    string          `json:"status"`
	URLs      ChatResponseUrl `json:"urls"`
	Version   string          `json:"version"`
}

// ChatResponseUrl is task urls of ChatResponse
type ChatResponseUrl struct {
	Stream string `json:"stream"`
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}
@@ -15,7 +15,7 @@ import (
 )

 var ModelList = []string{
-	"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002", 
+	"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002",
 }

 type Adaptor struct {
@@ -19,6 +19,7 @@ const (
 	DeepL
 	VertexAI
 	Proxy
+	Replicate

 	Dummy // this one is only for count, do not add any channel after this
 )
@@ -211,6 +211,50 @@ var ModelRatio = map[string]float64{
 	"deepl-ja": 25.0 / 1000 * USD,
 	// https://console.x.ai/
 	"grok-beta": 5.0 / 1000 * USD,
+	// replicate charges based on the number of generated images
+	// https://replicate.com/pricing
+	"black-forest-labs/flux-1.1-pro":                0.04 * USD,
+	"black-forest-labs/flux-1.1-pro-ultra":          0.06 * USD,
+	"black-forest-labs/flux-canny-dev":              0.025 * USD,
+	"black-forest-labs/flux-canny-pro":              0.05 * USD,
+	"black-forest-labs/flux-depth-dev":              0.025 * USD,
+	"black-forest-labs/flux-depth-pro":              0.05 * USD,
+	"black-forest-labs/flux-dev":                    0.025 * USD,
+	"black-forest-labs/flux-dev-lora":               0.032 * USD,
+	"black-forest-labs/flux-fill-dev":               0.04 * USD,
+	"black-forest-labs/flux-fill-pro":               0.05 * USD,
+	"black-forest-labs/flux-pro":                    0.055 * USD,
+	"black-forest-labs/flux-redux-dev":              0.025 * USD,
+	"black-forest-labs/flux-redux-schnell":          0.003 * USD,
+	"black-forest-labs/flux-schnell":                0.003 * USD,
+	"black-forest-labs/flux-schnell-lora":           0.02 * USD,
+	"ideogram-ai/ideogram-v2":                       0.08 * USD,
+	"ideogram-ai/ideogram-v2-turbo":                 0.05 * USD,
+	"recraft-ai/recraft-v3":                         0.04 * USD,
+	"recraft-ai/recraft-v3-svg":                     0.08 * USD,
+	"stability-ai/stable-diffusion-3":               0.035 * USD,
+	"stability-ai/stable-diffusion-3.5-large":       0.065 * USD,
+	"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
+	"stability-ai/stable-diffusion-3.5-medium":      0.035 * USD,
+	// replicate chat models
+	"ibm-granite/granite-20b-code-instruct-8k":  0.100 * USD,
+	"ibm-granite/granite-3.0-2b-instruct":       0.030 * USD,
+	"ibm-granite/granite-3.0-8b-instruct":       0.050 * USD,
+	"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
+	"meta/llama-2-13b":                          0.100 * USD,
+	"meta/llama-2-13b-chat":                     0.100 * USD,
+	"meta/llama-2-70b":                          0.650 * USD,
+	"meta/llama-2-70b-chat":                     0.650 * USD,
+	"meta/llama-2-7b":                           0.050 * USD,
+	"meta/llama-2-7b-chat":                      0.050 * USD,
+	"meta/meta-llama-3.1-405b-instruct":         9.500 * USD,
+	"meta/meta-llama-3-70b":                     0.650 * USD,
+	"meta/meta-llama-3-70b-instruct":            0.650 * USD,
+	"meta/meta-llama-3-8b":                      0.050 * USD,
+	"meta/meta-llama-3-8b-instruct":             0.050 * USD,
+	"mistralai/mistral-7b-instruct-v0.2":        0.050 * USD,
+	"mistralai/mistral-7b-v0.1":                 0.050 * USD,
+	"mistralai/mixtral-8x7b-instruct-v0.1":      0.300 * USD,
 }

 var CompletionRatio = map[string]float64{
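
As a back-of-envelope check on the image entries above, assuming one-api's usual conversion constant (USD = 500 ratio units, i.e. one unit is $0.002) and the quota formula from RelayImageHelper further down (quota = ratio * imageCostRatio * 1000); the snippet is illustrative and not part of the commit:

package main

import "fmt"

func main() {
	const USD = 500.0 // assumed one-api conversion: 1 ratio unit = $0.002, so $1 = 500 units

	modelRatio := 0.04 * USD       // the "black-forest-labs/flux-1.1-pro" entry above = 20
	quota := modelRatio * 1 * 1000 // quota = ratio * imageCostRatio * 1000, with groupRatio = 1
	// USD*1000 quota units per dollar, so this prints 0.04: $0.04 charged per image.
	fmt.Printf("$%.2f per image\n", quota/(USD*1000))
}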
@@ -362,6 +406,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
 	if strings.HasPrefix(name, "deepseek-") {
 		return 2
 	}
+
 	switch name {
 	case "llama2-70b-4096":
 		return 0.8 / 0.64
@@ -377,6 +422,35 @@ func GetCompletionRatio(name string, channelType int) float64 {
 		return 5
 	case "grok-beta":
 		return 3
+	// Replicate Models
+	// https://replicate.com/pricing
+	case "ibm-granite/granite-20b-code-instruct-8k":
+		return 5
+	case "ibm-granite/granite-3.0-2b-instruct":
+		return 8.333333333333334
+	case "ibm-granite/granite-3.0-8b-instruct",
+		"ibm-granite/granite-8b-code-instruct-128k":
+		return 5
+	case "meta/llama-2-13b",
+		"meta/llama-2-13b-chat",
+		"meta/llama-2-7b",
+		"meta/llama-2-7b-chat",
+		"meta/meta-llama-3-8b",
+		"meta/meta-llama-3-8b-instruct":
+		return 5
+	case "meta/llama-2-70b",
+		"meta/llama-2-70b-chat",
+		"meta/meta-llama-3-70b",
+		"meta/meta-llama-3-70b-instruct":
+		return 2.750 / 0.650 // ≈4.230769
+	case "meta/meta-llama-3.1-405b-instruct":
+		return 1
+	case "mistralai/mistral-7b-instruct-v0.2",
+		"mistralai/mistral-7b-v0.1":
+		return 5
+	case "mistralai/mixtral-8x7b-instruct-v0.1":
+		return 1.000 / 0.300 // ≈3.333333
 	}

 	return 1
 }
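
The fractional cases above are Replicate's output price divided by its input price, so multiplying a model's input-based ratio by its completion ratio recovers the output price. A worked example under the same assumed USD constant (not part of the commit):

package main

import "fmt"

func main() {
	const USD = 500.0 // assumed conversion constant, as above

	inputRatio := 0.650 * USD        // ModelRatio["meta/llama-2-70b"] = 325
	completionRatio := 2.750 / 0.650 // GetCompletionRatio case above, ≈4.230769
	// ≈1375 = 2.750 * USD: completion tokens end up billed at the output price.
	fmt.Println(inputRatio * completionRatio)
}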
@@ -47,5 +47,6 @@ const (
 	Proxy
 	SiliconFlow
 	XAI
+	Replicate
 	Dummy
 )
@@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
 		apiType = apitype.DeepL
 	case VertextAI:
 		apiType = apitype.VertexAI
+	case Replicate:
+		apiType = apitype.Replicate
 	case Proxy:
 		apiType = apitype.Proxy
 	}
@@ -47,6 +47,7 @@ var ChannelBaseURLs = []string{
 	"",                                          // 43
 	"https://api.siliconflow.cn",                // 44
 	"https://api.x.ai",                          // 45
+	"https://api.replicate.com/v1/models/",      // 46
 }

 func init() {
@@ -147,14 +147,20 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
 		}
 		return true
 	}
-	if resp.StatusCode != http.StatusOK {
+	if resp.StatusCode != http.StatusOK &&
+		// replicate return 201 to create a task
+		resp.StatusCode != http.StatusCreated {
 		return true
 	}
 	if meta.ChannelType == channeltype.DeepL {
 		// skip stream check for deepl
 		return false
 	}
-	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
+
+	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
+		// Even if stream mode is enabled, replicate will first return a task info in JSON format,
+		// requiring the client to request the stream endpoint in the task info
+		meta.ChannelType != channeltype.Replicate {
 		return true
 	}
 	return false
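
Both relaxations above exist because Replicate answers the initial POST with HTTP 201 and a JSON task description even when the caller wants a stream; the relay then polls the get URL and finally reads the stream URL (see ChatHandler earlier). A sketch of that created-task payload, shaped after the ChatResponse struct in model.go (values illustrative, not part of the commit):

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed-down copies of the fields the relay actually reads from the task.
type taskURLs struct {
	Stream string `json:"stream"`
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}

type task struct {
	ID     string   `json:"id"`
	Status string   `json:"status"`
	URLs   taskURLs `json:"urls"`
}

func main() {
	// Illustrative 201 response body; real URLs and IDs differ.
	created := []byte(`{
		"id": "illustrative-task-id",
		"status": "starting",
		"urls": {
			"stream": "https://stream.example/illustrative",
			"get": "https://api.replicate.com/v1/predictions/illustrative-task-id",
			"cancel": "https://api.replicate.com/v1/predictions/illustrative-task-id/cancel"
		}
	}`)

	var t task
	if err := json.Unmarshal(created, &t); err != nil {
		panic(err)
	}
	// The relay keeps polling t.URLs.Get until the status is "succeeded",
	// then streams tokens from t.URLs.Stream.
	fmt.Println(t.Status, t.URLs.Stream)
}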
@@ -22,7 +22,7 @@ import (
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
 )

-func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
+func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) {
 	imageRequest := &relaymodel.ImageRequest{}
 	err := common.UnmarshalBodyReusable(c, imageRequest)
 	if err != nil {
@@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
 	return 1
 }

-func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
+func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
 	// check prompt length
 	if imageRequest.Prompt == "" {
 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
@@ -150,12 +150,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 	}
 	adaptor.Init(meta)

+	// these adaptors need to convert the request
 	switch meta.ChannelType {
-	case channeltype.Ali:
-		fallthrough
-	case channeltype.Baidu:
-		fallthrough
-	case channeltype.Zhipu:
+	case channeltype.Zhipu,
+		channeltype.Ali,
+		channeltype.Replicate,
+		channeltype.Baidu:
 		finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
 		if err != nil {
 			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
@@ -172,7 +172,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 	ratio := modelRatio * groupRatio
 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)

-	quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
+	var quota int64
+	switch meta.ChannelType {
+	case channeltype.Replicate:
+		// replicate always return 1 image
+		quota = int64(ratio * imageCostRatio * 1000)
+	default:
+		quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
+	}

 	if userQuota-quota < 0 {
 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@@ -186,7 +193,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 	}

 	defer func(ctx context.Context) {
-		if resp != nil && resp.StatusCode != http.StatusOK {
+		if resp != nil &&
+			resp.StatusCode != http.StatusCreated && // replicate returns 201
+			resp.StatusCode != http.StatusOK {
 			return
 		}
@@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
   { key: 43, text: 'Proxy', value: 43, color: 'blue' },
   { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
   { key: 45, text: 'xAI', value: 45, color: 'blue' },
+  { key: 46, text: 'Replicate', value: 46, color: 'blue' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },

@@ -185,6 +185,12 @@ export const CHANNEL_OPTIONS = {
     value: 45,
     color: 'primary'
   },
+  45: {
+    key: 46,
+    text: 'Replicate',
+    value: 46,
+    color: 'primary'
+  },
   41: {
     key: 41,
     text: 'Novita',

@@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
     { key: 43, text: 'Proxy', value: 43, color: 'blue' },
     { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
     { key: 45, text: 'xAI', value: 45, color: 'blue' },
+    { key: 46, text: 'Replicate', value: 46, color: 'blue' },
     { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
     { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
     { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },