mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 09:16:36 +08:00
feat: support replicate chat models (#1989)
* feat: add Replicate adaptor and integrate into channel and API types * feat: support llm chat on replicate
This commit is contained in:
parent
36c8f4f15c
commit
305ce14fe3
1
.gitignore
vendored
1
.gitignore
vendored
@ -10,3 +10,4 @@ data
|
|||||||
/web/node_modules
|
/web/node_modules
|
||||||
cmd.md
|
cmd.md
|
||||||
.env
|
.env
|
||||||
|
/one-api
|
||||||
|
@ -3,9 +3,10 @@ package render
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func StringData(c *gin.Context, str string) {
|
func StringData(c *gin.Context, str string) {
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/palm"
|
"github.com/songquanpeng/one-api/relay/adaptor/palm"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/proxy"
|
"github.com/songquanpeng/one-api/relay/adaptor/proxy"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/replicate"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
|
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
|
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
|
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
|
||||||
@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
|
|||||||
return &vertexai.Adaptor{}
|
return &vertexai.Adaptor{}
|
||||||
case apitype.Proxy:
|
case apitype.Proxy:
|
||||||
return &proxy.Adaptor{}
|
return &proxy.Adaptor{}
|
||||||
|
case apitype.Replicate:
|
||||||
|
return &replicate.Adaptor{}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -2,15 +2,16 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
|
func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
|
||||||
usage := &model.Usage{}
|
usage := &model.Usage{}
|
||||||
usage.PromptTokens = promptTokens
|
usage.PromptTokens = promptTokens
|
||||||
usage.CompletionTokens = CountTokenText(responseText, modeName)
|
usage.CompletionTokens = CountTokenText(responseText, modelName)
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
return usage
|
return usage
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,16 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import "github.com/songquanpeng/one-api/relay/model"
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
)
|
||||||
|
|
||||||
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
|
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
|
||||||
|
logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
|
||||||
|
|
||||||
Error := model.Error{
|
Error := model.Error{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
Type: "one_api_error",
|
Type: "one_api_error",
|
||||||
|
136
relay/adaptor/replicate/adaptor.go
Normal file
136
relay/adaptor/replicate/adaptor.go
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Adaptor struct {
|
||||||
|
meta *meta.Meta
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertImageRequest implements adaptor.Adaptor.
|
||||||
|
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||||
|
return DrawImageRequest{
|
||||||
|
Input: ImageInput{
|
||||||
|
Steps: 25,
|
||||||
|
Prompt: request.Prompt,
|
||||||
|
Guidance: 3,
|
||||||
|
Seed: int(time.Now().UnixNano()),
|
||||||
|
SafetyTolerance: 5,
|
||||||
|
NImages: 1, // replicate will always return 1 image
|
||||||
|
Width: 1440,
|
||||||
|
Height: 1440,
|
||||||
|
AspectRatio: "1:1",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||||
|
if !request.Stream {
|
||||||
|
// TODO: support non-stream mode
|
||||||
|
return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the prompt from OpenAI messages
|
||||||
|
var promptBuilder strings.Builder
|
||||||
|
for _, message := range request.Messages {
|
||||||
|
switch msgCnt := message.Content.(type) {
|
||||||
|
case string:
|
||||||
|
promptBuilder.WriteString(message.Role)
|
||||||
|
promptBuilder.WriteString(": ")
|
||||||
|
promptBuilder.WriteString(msgCnt)
|
||||||
|
promptBuilder.WriteString("\n")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
replicateRequest := ReplicateChatRequest{
|
||||||
|
Input: ChatInput{
|
||||||
|
Prompt: promptBuilder.String(),
|
||||||
|
MaxTokens: request.MaxTokens,
|
||||||
|
Temperature: 1.0,
|
||||||
|
TopP: 1.0,
|
||||||
|
PresencePenalty: 0.0,
|
||||||
|
FrequencyPenalty: 0.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map optional fields
|
||||||
|
if request.Temperature != nil {
|
||||||
|
replicateRequest.Input.Temperature = *request.Temperature
|
||||||
|
}
|
||||||
|
if request.TopP != nil {
|
||||||
|
replicateRequest.Input.TopP = *request.TopP
|
||||||
|
}
|
||||||
|
if request.PresencePenalty != nil {
|
||||||
|
replicateRequest.Input.PresencePenalty = *request.PresencePenalty
|
||||||
|
}
|
||||||
|
if request.FrequencyPenalty != nil {
|
||||||
|
replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
|
||||||
|
}
|
||||||
|
if request.MaxTokens > 0 {
|
||||||
|
replicateRequest.Input.MaxTokens = request.MaxTokens
|
||||||
|
} else if request.MaxTokens == 0 {
|
||||||
|
replicateRequest.Input.MaxTokens = 500
|
||||||
|
}
|
||||||
|
|
||||||
|
return replicateRequest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||||
|
a.meta = meta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||||
|
if !slices.Contains(ModelList, meta.OriginModelName) {
|
||||||
|
return "", errors.Errorf("model %s not supported", meta.OriginModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||||
|
adaptor.SetupCommonRequestHeader(c, req, meta)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
logger.Info(c, "send request to replicate")
|
||||||
|
return adaptor.DoRequestHelper(a, c, meta, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||||
|
switch meta.Mode {
|
||||||
|
case relaymode.ImagesGenerations:
|
||||||
|
err, usage = ImageHandler(c, resp)
|
||||||
|
case relaymode.ChatCompletions:
|
||||||
|
err, usage = ChatHandler(c, resp)
|
||||||
|
default:
|
||||||
|
err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetModelList() []string {
|
||||||
|
return ModelList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetChannelName() string {
|
||||||
|
return "replicate"
|
||||||
|
}
|
191
relay/adaptor/replicate/chat.go
Normal file
191
relay/adaptor/replicate/chat.go
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/render"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ChatHandler(c *gin.Context, resp *http.Response) (
|
||||||
|
srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
payload, _ := io.ReadAll(resp.Body)
|
||||||
|
return openai.ErrorWrapper(
|
||||||
|
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
|
||||||
|
"bad_status_code", http.StatusInternalServerError),
|
||||||
|
nil
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
respData := new(ChatResponse)
|
||||||
|
if err = json.Unmarshal(respBody, respData); err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
err = func() error {
|
||||||
|
// get task
|
||||||
|
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
|
||||||
|
http.MethodGet, respData.URLs.Get, nil)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "new request")
|
||||||
|
}
|
||||||
|
|
||||||
|
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
|
||||||
|
taskResp, err := http.DefaultClient.Do(taskReq)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "get task")
|
||||||
|
}
|
||||||
|
defer taskResp.Body.Close()
|
||||||
|
|
||||||
|
if taskResp.StatusCode != http.StatusOK {
|
||||||
|
payload, _ := io.ReadAll(taskResp.Body)
|
||||||
|
return errors.Errorf("bad status code [%d]%s",
|
||||||
|
taskResp.StatusCode, string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
taskBody, err := io.ReadAll(taskResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "read task response")
|
||||||
|
}
|
||||||
|
|
||||||
|
taskData := new(ChatResponse)
|
||||||
|
if err = json.Unmarshal(taskBody, taskData); err != nil {
|
||||||
|
return errors.Wrap(err, "decode task response")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch taskData.Status {
|
||||||
|
case "succeeded":
|
||||||
|
case "failed", "canceled":
|
||||||
|
return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
|
||||||
|
default:
|
||||||
|
time.Sleep(time.Second * 3)
|
||||||
|
return errNextLoop
|
||||||
|
}
|
||||||
|
|
||||||
|
if taskData.URLs.Stream == "" {
|
||||||
|
return errors.New("stream url is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// request stream url
|
||||||
|
responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "chat stream handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxMeta := meta.GetByContext(c)
|
||||||
|
usage = openai.ResponseText2Usage(responseText,
|
||||||
|
ctxMeta.ActualModelName, ctxMeta.PromptTokens)
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errNextLoop) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, usage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Literal markers used when reading the prediction's event stream.
const (
	// eventPrefix starts a line that names an event.
	eventPrefix = "event: "
	// dataPrefix starts a line that carries an event payload.
	dataPrefix = "data: "
	// done is the "[DONE]" marker string.
	done = "[DONE]"
)
|
||||||
|
|
||||||
|
func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
|
||||||
|
// request stream endpoint
|
||||||
|
streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "new request to stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
|
||||||
|
streamReq.Header.Set("Accept", "text/event-stream")
|
||||||
|
streamReq.Header.Set("Cache-Control", "no-store")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(streamReq)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "do request to stream")
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
payload, _ := io.ReadAll(resp.Body)
|
||||||
|
return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(bufio.ScanLines)
|
||||||
|
|
||||||
|
common.SetEventStreamHeaders(c)
|
||||||
|
doneRendered := false
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle comments starting with ':'
|
||||||
|
if strings.HasPrefix(line, ":") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse SSE fields
|
||||||
|
if strings.HasPrefix(line, eventPrefix) {
|
||||||
|
event := strings.TrimSpace(line[len(eventPrefix):])
|
||||||
|
var data string
|
||||||
|
// Read the following lines to get data and id
|
||||||
|
for scanner.Scan() {
|
||||||
|
nextLine := scanner.Text()
|
||||||
|
if nextLine == "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(nextLine, dataPrefix) {
|
||||||
|
data = nextLine[len(dataPrefix):]
|
||||||
|
} else if strings.HasPrefix(nextLine, "id:") {
|
||||||
|
// id = strings.TrimSpace(nextLine[len("id:"):])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if event == "output" {
|
||||||
|
render.StringData(c, data)
|
||||||
|
responseText += data
|
||||||
|
} else if event == "done" {
|
||||||
|
render.Done(c)
|
||||||
|
doneRendered = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return "", errors.Wrap(err, "scan stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !doneRendered {
|
||||||
|
render.Done(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return responseText, nil
|
||||||
|
}
|
58
relay/adaptor/replicate/constant.go
Normal file
58
relay/adaptor/replicate/constant.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
// ModelList enumerates the Replicate models this adaptor accepts, grouped by
// kind.
//
// https://replicate.com/pricing
var ModelList = []string{
	// ---- image models ----
	"black-forest-labs/flux-1.1-pro",
	"black-forest-labs/flux-1.1-pro-ultra",
	"black-forest-labs/flux-canny-dev",
	"black-forest-labs/flux-canny-pro",
	"black-forest-labs/flux-depth-dev",
	"black-forest-labs/flux-depth-pro",
	"black-forest-labs/flux-dev",
	"black-forest-labs/flux-dev-lora",
	"black-forest-labs/flux-fill-dev",
	"black-forest-labs/flux-fill-pro",
	"black-forest-labs/flux-pro",
	"black-forest-labs/flux-redux-dev",
	"black-forest-labs/flux-redux-schnell",
	"black-forest-labs/flux-schnell",
	"black-forest-labs/flux-schnell-lora",
	"ideogram-ai/ideogram-v2",
	"ideogram-ai/ideogram-v2-turbo",
	"recraft-ai/recraft-v3",
	"recraft-ai/recraft-v3-svg",
	"stability-ai/stable-diffusion-3",
	"stability-ai/stable-diffusion-3.5-large",
	"stability-ai/stable-diffusion-3.5-large-turbo",
	"stability-ai/stable-diffusion-3.5-medium",

	// ---- language models ----
	"ibm-granite/granite-20b-code-instruct-8k",
	"ibm-granite/granite-3.0-2b-instruct",
	"ibm-granite/granite-3.0-8b-instruct",
	"ibm-granite/granite-8b-code-instruct-128k",
	"meta/llama-2-13b",
	"meta/llama-2-13b-chat",
	"meta/llama-2-70b",
	"meta/llama-2-70b-chat",
	"meta/llama-2-7b",
	"meta/llama-2-7b-chat",
	"meta/meta-llama-3.1-405b-instruct",
	"meta/meta-llama-3-70b",
	"meta/meta-llama-3-70b-instruct",
	"meta/meta-llama-3-8b",
	"meta/meta-llama-3-8b-instruct",
	"mistralai/mistral-7b-instruct-v0.2",
	"mistralai/mistral-7b-v0.1",
	"mistralai/mixtral-8x7b-instruct-v0.1",

	// ---- video models ----
	// "minimax/video-01", // TODO: implement the adaptor
}
|
222
relay/adaptor/replicate/image.go
Normal file
222
relay/adaptor/replicate/image.go
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"image"
	_ "image/jpeg" // registers the JPEG decoder used through image.Decode
	"image/png"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
	"golang.org/x/image/webp"
	"golang.org/x/sync/errgroup"
)
|
||||||
|
|
||||||
|
// ImagesEditsHandler just copy response body to client
|
||||||
|
//
|
||||||
|
// https://replicate.com/black-forest-labs/flux-fill-pro
|
||||||
|
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
|
// c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
// for k, v := range resp.Header {
|
||||||
|
// c.Writer.Header().Set(k, v[0])
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||||
|
// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
// }
|
||||||
|
// defer resp.Body.Close()
|
||||||
|
|
||||||
|
// return nil, nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
// errNextLoop is returned by a polling step to signal that the task has not
// finished yet and the caller should poll again.
var errNextLoop = errors.New("next_loop")
|
||||||
|
|
||||||
|
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
payload, _ := io.ReadAll(resp.Body)
|
||||||
|
return openai.ErrorWrapper(
|
||||||
|
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
|
||||||
|
"bad_status_code", http.StatusInternalServerError),
|
||||||
|
nil
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
respData := new(ImageResponse)
|
||||||
|
if err = json.Unmarshal(respBody, respData); err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
err = func() error {
|
||||||
|
// get task
|
||||||
|
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
|
||||||
|
http.MethodGet, respData.URLs.Get, nil)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "new request")
|
||||||
|
}
|
||||||
|
|
||||||
|
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
|
||||||
|
taskResp, err := http.DefaultClient.Do(taskReq)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "get task")
|
||||||
|
}
|
||||||
|
defer taskResp.Body.Close()
|
||||||
|
|
||||||
|
if taskResp.StatusCode != http.StatusOK {
|
||||||
|
payload, _ := io.ReadAll(taskResp.Body)
|
||||||
|
return errors.Errorf("bad status code [%d]%s",
|
||||||
|
taskResp.StatusCode, string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
taskBody, err := io.ReadAll(taskResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "read task response")
|
||||||
|
}
|
||||||
|
|
||||||
|
taskData := new(ImageResponse)
|
||||||
|
if err = json.Unmarshal(taskBody, taskData); err != nil {
|
||||||
|
return errors.Wrap(err, "decode task response")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch taskData.Status {
|
||||||
|
case "succeeded":
|
||||||
|
case "failed", "canceled":
|
||||||
|
return errors.Errorf("task failed: %s", taskData.Status)
|
||||||
|
default:
|
||||||
|
time.Sleep(time.Second * 3)
|
||||||
|
return errNextLoop
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := taskData.GetOutput()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "get output")
|
||||||
|
}
|
||||||
|
if len(output) == 0 {
|
||||||
|
return errors.New("response output is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var pool errgroup.Group
|
||||||
|
respBody := &openai.ImageResponse{
|
||||||
|
Created: taskData.CompletedAt.Unix(),
|
||||||
|
Data: []openai.ImageData{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, imgOut := range output {
|
||||||
|
imgOut := imgOut
|
||||||
|
pool.Go(func() error {
|
||||||
|
// download image
|
||||||
|
downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
|
||||||
|
http.MethodGet, imgOut, nil)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "new request")
|
||||||
|
}
|
||||||
|
|
||||||
|
imgResp, err := http.DefaultClient.Do(downloadReq)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "download image")
|
||||||
|
}
|
||||||
|
defer imgResp.Body.Close()
|
||||||
|
|
||||||
|
if imgResp.StatusCode != http.StatusOK {
|
||||||
|
payload, _ := io.ReadAll(imgResp.Body)
|
||||||
|
return errors.Errorf("bad status code [%d]%s",
|
||||||
|
imgResp.StatusCode, string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
imgData, err := io.ReadAll(imgResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "read image")
|
||||||
|
}
|
||||||
|
|
||||||
|
imgData, err = ConvertImageToPNG(imgData)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "convert image")
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
respBody.Data = append(respBody.Data, openai.ImageData{
|
||||||
|
B64Json: fmt.Sprintf("data:image/png;base64,%s",
|
||||||
|
base64.StdEncoding.EncodeToString(imgData)),
|
||||||
|
})
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pool.Wait(); err != nil {
|
||||||
|
if len(respBody.Data) == 0 {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, respBody)
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errNextLoop) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertImageToPNG converts a WebP image to PNG format
|
||||||
|
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
|
||||||
|
// bypass if it's already a PNG image
|
||||||
|
if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
|
||||||
|
return webpData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if is jpeg, convert to png
|
||||||
|
if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
|
||||||
|
img, _, err := image.Decode(bytes.NewReader(webpData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "decode jpeg")
|
||||||
|
}
|
||||||
|
|
||||||
|
var pngBuffer bytes.Buffer
|
||||||
|
if err := png.Encode(&pngBuffer, img); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "encode png")
|
||||||
|
}
|
||||||
|
|
||||||
|
return pngBuffer.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the WebP image
|
||||||
|
img, err := webp.Decode(bytes.NewReader(webpData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "decode webp")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode the image as PNG
|
||||||
|
var pngBuffer bytes.Buffer
|
||||||
|
if err := png.Encode(&pngBuffer, img); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "encode png")
|
||||||
|
}
|
||||||
|
|
||||||
|
return pngBuffer.Bytes(), nil
|
||||||
|
}
|
159
relay/adaptor/replicate/model.go
Normal file
159
relay/adaptor/replicate/model.go
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DrawImageRequest draw image by fluxpro
|
||||||
|
//
|
||||||
|
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
|
||||||
|
type DrawImageRequest struct {
|
||||||
|
Input ImageInput `json:"input"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageInput is the input section of DrawImageRequest. The binding tags state
// the accepted range of each field.
//
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
type ImageInput struct {
	Steps           int    `json:"steps" binding:"required,min=1"`
	Prompt          string `json:"prompt" binding:"required,min=5"`
	ImagePrompt     string `json:"image_prompt"`
	Guidance        int    `json:"guidance" binding:"required,min=2,max=5"`
	Interval        int    `json:"interval" binding:"required,min=1,max=4"`
	AspectRatio     string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
	SafetyTolerance int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	Seed            int    `json:"seed"`
	NImages         int    `json:"n_images" binding:"required,min=1,max=8"`
	Width           int    `json:"width" binding:"required,min=256,max=1440"`
	Height          int    `json:"height" binding:"required,min=256,max=1440"`
}
|
||||||
|
|
||||||
|
// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
|
||||||
|
//
|
||||||
|
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
|
||||||
|
type InpaintingImageByFlusReplicateRequest struct {
|
||||||
|
Input FluxInpaintingInput `json:"input"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FluxInpaintingInput is the input section of
// InpaintingImageByFlusReplicateRequest. The binding tags state which fields
// are required and the accepted range of each.
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type FluxInpaintingInput struct {
	Mask             string `json:"mask" binding:"required"`
	Image            string `json:"image" binding:"required"`
	Seed             int    `json:"seed"`
	Steps            int    `json:"steps" binding:"required,min=1"`
	Prompt           string `json:"prompt" binding:"required,min=5"`
	Guidance         int    `json:"guidance" binding:"required,min=2,max=5"`
	OutputFormat     string `json:"output_format"`
	SafetyTolerance  int    `json:"safety_tolerance" binding:"required,min=1,max=5"`
	PromptUnsampling bool   `json:"prompt_unsampling"`
}
|
||||||
|
|
||||||
|
// ImageResponse is response of DrawImageByFluxProRequest
|
||||||
|
//
|
||||||
|
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
|
||||||
|
type ImageResponse struct {
|
||||||
|
CompletedAt time.Time `json:"completed_at"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
DataRemoved bool `json:"data_removed"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Input DrawImageRequest `json:"input"`
|
||||||
|
Logs string `json:"logs"`
|
||||||
|
Metrics FluxMetrics `json:"metrics"`
|
||||||
|
// Output could be `string` or `[]string`
|
||||||
|
Output any `json:"output"`
|
||||||
|
StartedAt time.Time `json:"started_at"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
URLs FluxURLs `json:"urls"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ImageResponse) GetOutput() ([]string, error) {
|
||||||
|
switch v := r.Output.(type) {
|
||||||
|
case string:
|
||||||
|
return []string{v}, nil
|
||||||
|
case []string:
|
||||||
|
return v, nil
|
||||||
|
case nil:
|
||||||
|
return nil, nil
|
||||||
|
case []interface{}:
|
||||||
|
// convert []interface{} to []string
|
||||||
|
ret := make([]string, len(v))
|
||||||
|
for idx, vv := range v {
|
||||||
|
if vvv, ok := vv.(string); ok {
|
||||||
|
ret[idx] = vvv
|
||||||
|
} else {
|
||||||
|
return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
default:
|
||||||
|
return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FluxMetrics is the "metrics" section of a prediction response.
type FluxMetrics struct {
	ImageCount  int     `json:"image_count"`
	PredictTime float64 `json:"predict_time"`
	TotalTime   float64 `json:"total_time"`
}
|
||||||
|
|
||||||
|
// FluxURLs is the "urls" section of ImageResponse: the endpoints for fetching
// and cancelling the prediction.
type FluxURLs struct {
	Get    string `json:"get"`
	Cancel string `json:"cancel"`
}
|
||||||
|
|
||||||
|
type ReplicateChatRequest struct {
|
||||||
|
Input ChatInput `json:"input" form:"input" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatInput is the input section of ReplicateChatRequest.
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
type ChatInput struct {
	TopK             int     `json:"top_k"`
	TopP             float64 `json:"top_p"`
	Prompt           string  `json:"prompt"`
	MaxTokens        int     `json:"max_tokens"`
	MinTokens        int     `json:"min_tokens"`
	Temperature      float64 `json:"temperature"`
	SystemPrompt     string  `json:"system_prompt"`
	StopSequences    string  `json:"stop_sequences"`
	PromptTemplate   string  `json:"prompt_template"`
	PresencePenalty  float64 `json:"presence_penalty"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
}
|
||||||
|
|
||||||
|
// ChatResponse is response of ChatByReplicateRequest
|
||||||
|
//
|
||||||
|
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
|
||||||
|
type ChatResponse struct {
|
||||||
|
CompletedAt time.Time `json:"completed_at"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
DataRemoved bool `json:"data_removed"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Input ChatInput `json:"input"`
|
||||||
|
Logs string `json:"logs"`
|
||||||
|
Metrics FluxMetrics `json:"metrics"`
|
||||||
|
// Output could be `string` or `[]string`
|
||||||
|
Output []string `json:"output"`
|
||||||
|
StartedAt time.Time `json:"started_at"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
URLs ChatResponseUrl `json:"urls"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatResponseUrl holds the task URLs found in the "urls" object of a
// ChatResponse.
type ChatResponseUrl struct {
	// Stream is decoded from the "stream" key.
	Stream string `json:"stream"`
	// Get is decoded from the "get" key.
	Get string `json:"get"`
	// Cancel is decoded from the "cancel" key.
	Cancel string `json:"cancel"`
}
|
@ -19,6 +19,7 @@ const (
|
|||||||
DeepL
|
DeepL
|
||||||
VertexAI
|
VertexAI
|
||||||
Proxy
|
Proxy
|
||||||
|
Replicate
|
||||||
|
|
||||||
Dummy // this one is only for count, do not add any channel after this
|
Dummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
|
@ -211,6 +211,50 @@ var ModelRatio = map[string]float64{
|
|||||||
"deepl-ja": 25.0 / 1000 * USD,
|
"deepl-ja": 25.0 / 1000 * USD,
|
||||||
// https://console.x.ai/
|
// https://console.x.ai/
|
||||||
"grok-beta": 5.0 / 1000 * USD,
|
"grok-beta": 5.0 / 1000 * USD,
|
||||||
|
// replicate charges based on the number of generated images
|
||||||
|
// https://replicate.com/pricing
|
||||||
|
"black-forest-labs/flux-1.1-pro": 0.04 * USD,
|
||||||
|
"black-forest-labs/flux-1.1-pro-ultra": 0.06 * USD,
|
||||||
|
"black-forest-labs/flux-canny-dev": 0.025 * USD,
|
||||||
|
"black-forest-labs/flux-canny-pro": 0.05 * USD,
|
||||||
|
"black-forest-labs/flux-depth-dev": 0.025 * USD,
|
||||||
|
"black-forest-labs/flux-depth-pro": 0.05 * USD,
|
||||||
|
"black-forest-labs/flux-dev": 0.025 * USD,
|
||||||
|
"black-forest-labs/flux-dev-lora": 0.032 * USD,
|
||||||
|
"black-forest-labs/flux-fill-dev": 0.04 * USD,
|
||||||
|
"black-forest-labs/flux-fill-pro": 0.05 * USD,
|
||||||
|
"black-forest-labs/flux-pro": 0.055 * USD,
|
||||||
|
"black-forest-labs/flux-redux-dev": 0.025 * USD,
|
||||||
|
"black-forest-labs/flux-redux-schnell": 0.003 * USD,
|
||||||
|
"black-forest-labs/flux-schnell": 0.003 * USD,
|
||||||
|
"black-forest-labs/flux-schnell-lora": 0.02 * USD,
|
||||||
|
"ideogram-ai/ideogram-v2": 0.08 * USD,
|
||||||
|
"ideogram-ai/ideogram-v2-turbo": 0.05 * USD,
|
||||||
|
"recraft-ai/recraft-v3": 0.04 * USD,
|
||||||
|
"recraft-ai/recraft-v3-svg": 0.08 * USD,
|
||||||
|
"stability-ai/stable-diffusion-3": 0.035 * USD,
|
||||||
|
"stability-ai/stable-diffusion-3.5-large": 0.065 * USD,
|
||||||
|
"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
|
||||||
|
"stability-ai/stable-diffusion-3.5-medium": 0.035 * USD,
|
||||||
|
// replicate chat models
|
||||||
|
"ibm-granite/granite-20b-code-instruct-8k": 0.100 * USD,
|
||||||
|
"ibm-granite/granite-3.0-2b-instruct": 0.030 * USD,
|
||||||
|
"ibm-granite/granite-3.0-8b-instruct": 0.050 * USD,
|
||||||
|
"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
|
||||||
|
"meta/llama-2-13b": 0.100 * USD,
|
||||||
|
"meta/llama-2-13b-chat": 0.100 * USD,
|
||||||
|
"meta/llama-2-70b": 0.650 * USD,
|
||||||
|
"meta/llama-2-70b-chat": 0.650 * USD,
|
||||||
|
"meta/llama-2-7b": 0.050 * USD,
|
||||||
|
"meta/llama-2-7b-chat": 0.050 * USD,
|
||||||
|
"meta/meta-llama-3.1-405b-instruct": 9.500 * USD,
|
||||||
|
"meta/meta-llama-3-70b": 0.650 * USD,
|
||||||
|
"meta/meta-llama-3-70b-instruct": 0.650 * USD,
|
||||||
|
"meta/meta-llama-3-8b": 0.050 * USD,
|
||||||
|
"meta/meta-llama-3-8b-instruct": 0.050 * USD,
|
||||||
|
"mistralai/mistral-7b-instruct-v0.2": 0.050 * USD,
|
||||||
|
"mistralai/mistral-7b-v0.1": 0.050 * USD,
|
||||||
|
"mistralai/mixtral-8x7b-instruct-v0.1": 0.300 * USD,
|
||||||
}
|
}
|
||||||
|
|
||||||
var CompletionRatio = map[string]float64{
|
var CompletionRatio = map[string]float64{
|
||||||
@ -362,6 +406,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
|
|||||||
if strings.HasPrefix(name, "deepseek-") {
|
if strings.HasPrefix(name, "deepseek-") {
|
||||||
return 2
|
return 2
|
||||||
}
|
}
|
||||||
|
|
||||||
switch name {
|
switch name {
|
||||||
case "llama2-70b-4096":
|
case "llama2-70b-4096":
|
||||||
return 0.8 / 0.64
|
return 0.8 / 0.64
|
||||||
@ -377,6 +422,35 @@ func GetCompletionRatio(name string, channelType int) float64 {
|
|||||||
return 5
|
return 5
|
||||||
case "grok-beta":
|
case "grok-beta":
|
||||||
return 3
|
return 3
|
||||||
|
// Replicate Models
|
||||||
|
// https://replicate.com/pricing
|
||||||
|
case "ibm-granite/granite-20b-code-instruct-8k":
|
||||||
|
return 5
|
||||||
|
case "ibm-granite/granite-3.0-2b-instruct":
|
||||||
|
return 8.333333333333334
|
||||||
|
case "ibm-granite/granite-3.0-8b-instruct",
|
||||||
|
"ibm-granite/granite-8b-code-instruct-128k":
|
||||||
|
return 5
|
||||||
|
case "meta/llama-2-13b",
|
||||||
|
"meta/llama-2-13b-chat",
|
||||||
|
"meta/llama-2-7b",
|
||||||
|
"meta/llama-2-7b-chat",
|
||||||
|
"meta/meta-llama-3-8b",
|
||||||
|
"meta/meta-llama-3-8b-instruct":
|
||||||
|
return 5
|
||||||
|
case "meta/llama-2-70b",
|
||||||
|
"meta/llama-2-70b-chat",
|
||||||
|
"meta/meta-llama-3-70b",
|
||||||
|
"meta/meta-llama-3-70b-instruct":
|
||||||
|
return 2.750 / 0.650 // ≈4.230769
|
||||||
|
case "meta/meta-llama-3.1-405b-instruct":
|
||||||
|
return 1
|
||||||
|
case "mistralai/mistral-7b-instruct-v0.2",
|
||||||
|
"mistralai/mistral-7b-v0.1":
|
||||||
|
return 5
|
||||||
|
case "mistralai/mixtral-8x7b-instruct-v0.1":
|
||||||
|
return 1.000 / 0.300 // ≈3.333333
|
||||||
}
|
}
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
@ -47,5 +47,6 @@ const (
|
|||||||
Proxy
|
Proxy
|
||||||
SiliconFlow
|
SiliconFlow
|
||||||
XAI
|
XAI
|
||||||
|
Replicate
|
||||||
Dummy
|
Dummy
|
||||||
)
|
)
|
||||||
|
@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
|
|||||||
apiType = apitype.DeepL
|
apiType = apitype.DeepL
|
||||||
case VertextAI:
|
case VertextAI:
|
||||||
apiType = apitype.VertexAI
|
apiType = apitype.VertexAI
|
||||||
|
case Replicate:
|
||||||
|
apiType = apitype.Replicate
|
||||||
case Proxy:
|
case Proxy:
|
||||||
apiType = apitype.Proxy
|
apiType = apitype.Proxy
|
||||||
}
|
}
|
||||||
|
@ -47,6 +47,7 @@ var ChannelBaseURLs = []string{
|
|||||||
"", // 43
|
"", // 43
|
||||||
"https://api.siliconflow.cn", // 44
|
"https://api.siliconflow.cn", // 44
|
||||||
"https://api.x.ai", // 45
|
"https://api.x.ai", // 45
|
||||||
|
"https://api.replicate.com/v1/models/", // 46
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -147,14 +147,20 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK &&
|
||||||
|
// Replicate returns 201 (Created) when it creates a task
|
||||||
|
resp.StatusCode != http.StatusCreated {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if meta.ChannelType == channeltype.DeepL {
|
if meta.ChannelType == channeltype.DeepL {
|
||||||
// skip stream check for deepl
|
// skip stream check for deepl
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
|
|
||||||
|
if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
|
||||||
|
// Even if stream mode is enabled, replicate will first return a task info in JSON format,
|
||||||
|
// requiring the client to request the stream endpoint in the task info
|
||||||
|
meta.ChannelType != channeltype.Replicate {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
@ -22,7 +22,7 @@ import (
|
|||||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
|
func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) {
|
||||||
imageRequest := &relaymodel.ImageRequest{}
|
imageRequest := &relaymodel.ImageRequest{}
|
||||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
|
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
|
||||||
// check prompt length
|
// check prompt length
|
||||||
if imageRequest.Prompt == "" {
|
if imageRequest.Prompt == "" {
|
||||||
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
||||||
@ -150,12 +150,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
}
|
}
|
||||||
adaptor.Init(meta)
|
adaptor.Init(meta)
|
||||||
|
|
||||||
|
// these adaptors need to convert the request
|
||||||
switch meta.ChannelType {
|
switch meta.ChannelType {
|
||||||
case channeltype.Ali:
|
case channeltype.Zhipu,
|
||||||
fallthrough
|
channeltype.Ali,
|
||||||
case channeltype.Baidu:
|
channeltype.Replicate,
|
||||||
fallthrough
|
channeltype.Baidu:
|
||||||
case channeltype.Zhipu:
|
|
||||||
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
||||||
@ -172,7 +172,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
ratio := modelRatio * groupRatio
|
ratio := modelRatio * groupRatio
|
||||||
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
||||||
|
|
||||||
quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
var quota int64
|
||||||
|
switch meta.ChannelType {
|
||||||
|
case channeltype.Replicate:
|
||||||
|
// Replicate always returns exactly 1 image
|
||||||
|
quota = int64(ratio * imageCostRatio * 1000)
|
||||||
|
default:
|
||||||
|
quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
||||||
|
}
|
||||||
|
|
||||||
if userQuota-quota < 0 {
|
if userQuota-quota < 0 {
|
||||||
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
@ -186,7 +193,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
if resp != nil &&
|
||||||
|
resp.StatusCode != http.StatusCreated && // replicate returns 201
|
||||||
|
resp.StatusCode != http.StatusOK {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
|
|||||||
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
||||||
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
||||||
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
||||||
|
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
|
||||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||||
|
@ -185,6 +185,12 @@ export const CHANNEL_OPTIONS = {
|
|||||||
value: 45,
|
value: 45,
|
||||||
color: 'primary'
|
color: 'primary'
|
||||||
},
|
},
|
||||||
|
45: {
|
||||||
|
key: 46,
|
||||||
|
text: 'Replicate',
|
||||||
|
value: 46,
|
||||||
|
color: 'primary'
|
||||||
|
},
|
||||||
41: {
|
41: {
|
||||||
key: 41,
|
key: 41,
|
||||||
text: 'Novita',
|
text: 'Novita',
|
||||||
|
@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [
|
|||||||
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
||||||
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
||||||
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
||||||
|
{ key: 46, text: 'Replicate', value: 46, color: 'blue' },
|
||||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||||
|
Loading…
Reference in New Issue
Block a user