Mirror of https://github.com/songquanpeng/one-api.git (synced 2025-09-17 17:16:38 +08:00)

feat: support llm chat on replicate

parent 4dd2b9dcb8 · commit 48e8b6b5c0
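
With this change, Replicate's language models become reachable through the relay's OpenAI-compatible chat endpoint. Because the adaptor below rejects non-stream requests, a client call would look roughly like this (illustrative payload; the model name is one of those enabled in ModelList further down):

POST /v1/chat/completions
{
  "model": "meta/meta-llama-3-8b-instruct",
  "stream": true,
  "messages": [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Say hi."}
  ]
}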
@@ -3,9 +3,10 @@ package render
 import (
     "encoding/json"
     "fmt"
+    "strings"
+
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
-    "strings"
 )
 
 func StringData(c *gin.Context, str string) {
@@ -34,7 +34,7 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
         strings.Contains(lowerMessage, "credit") ||
         strings.Contains(lowerMessage, "balance") ||
         strings.Contains(lowerMessage, "permission denied") ||
         strings.Contains(lowerMessage, "organization has been restricted") || // groq
         strings.Contains(lowerMessage, "已欠费") {
         return true
     }
@@ -31,8 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
             TopP:             request.TopP,
             FrequencyPenalty: request.FrequencyPenalty,
             PresencePenalty:  request.PresencePenalty,
             NumPredict:       request.MaxTokens,
             NumCtx:           request.NumCtx,
         },
         Stream: request.Stream,
     }
@@ -122,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
     for scanner.Scan() {
         data := scanner.Text()
         if strings.HasPrefix(data, "}") {
             data = strings.TrimPrefix(data, "}") + "}"
         }
 
         var ollamaResponse ChatResponse
@@ -2,15 +2,16 @@ package openai
 
 import (
     "fmt"
+    "strings"
+
     "github.com/songquanpeng/one-api/relay/channeltype"
     "github.com/songquanpeng/one-api/relay/model"
-    "strings"
 )
 
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
+func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
     usage := &model.Usage{}
     usage.PromptTokens = promptTokens
-    usage.CompletionTokens = CountTokenText(responseText, modeName)
+    usage.CompletionTokens = CountTokenText(responseText, modelName)
     usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
     return usage
 }
@@ -5,6 +5,7 @@ import (
     "io"
     "net/http"
     "slices"
+    "strings"
     "time"
 
     "github.com/gin-gonic/gin"
@@ -39,7 +40,55 @@ func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 }
 
 func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
-    return nil, errors.New("not implemented")
+    if !request.Stream {
+        // TODO: support non-stream mode
+        return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
+    }
+
+    // Build the prompt from OpenAI messages
+    var promptBuilder strings.Builder
+    for _, message := range request.Messages {
+        switch msgCnt := message.Content.(type) {
+        case string:
+            promptBuilder.WriteString(message.Role)
+            promptBuilder.WriteString(": ")
+            promptBuilder.WriteString(msgCnt)
+            promptBuilder.WriteString("\n")
+        default:
+        }
+    }
+
+    replicateRequest := ReplicateChatRequest{
+        Input: ChatInput{
+            Prompt:           promptBuilder.String(),
+            MaxTokens:        request.MaxTokens,
+            Temperature:      1.0,
+            TopP:             1.0,
+            PresencePenalty:  0.0,
+            FrequencyPenalty: 0.0,
+        },
+    }
+
+    // Map optional fields
+    if request.Temperature != nil {
+        replicateRequest.Input.Temperature = *request.Temperature
+    }
+    if request.TopP != nil {
+        replicateRequest.Input.TopP = *request.TopP
+    }
+    if request.PresencePenalty != nil {
+        replicateRequest.Input.PresencePenalty = *request.PresencePenalty
+    }
+    if request.FrequencyPenalty != nil {
+        replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
+    }
+    if request.MaxTokens > 0 {
+        replicateRequest.Input.MaxTokens = request.MaxTokens
+    } else if request.MaxTokens == 0 {
+        replicateRequest.Input.MaxTokens = 500
+    }
+
+    return replicateRequest, nil
 }
 
 func (a *Adaptor) Init(meta *meta.Meta) {
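
The message flattening above is easy to see with a toy example. A minimal standalone sketch of the same logic, assuming string-typed contents only (the switch ignores everything else):

package main

import (
    "fmt"
    "strings"
)

type msg struct{ Role, Content string }

// flatten mirrors the promptBuilder loop in ConvertRequest:
// each message becomes "<role>: <content>\n".
func flatten(msgs []msg) string {
    var b strings.Builder
    for _, m := range msgs {
        b.WriteString(m.Role)
        b.WriteString(": ")
        b.WriteString(m.Content)
        b.WriteString("\n")
    }
    return b.String()
}

func main() {
    fmt.Print(flatten([]msg{
        {"system", "You are concise."},
        {"user", "Say hi."},
    }))
    // Output:
    // system: You are concise.
    // user: Say hi.
}

Note that the default branch silently drops non-string contents (e.g. multimodal message parts), so image inputs are ignored rather than rejected.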
@@ -61,7 +110,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
-    logger.Info(c, "send image request to replicate")
+    logger.Info(c, "send request to replicate")
     return adaptor.DoRequestHelper(a, c, meta, requestBody)
 }
@@ -69,6 +118,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
     switch meta.Mode {
     case relaymode.ImagesGenerations:
         err, usage = ImageHandler(c, resp)
+    case relaymode.ChatCompletions:
+        err, usage = ChatHandler(c, resp)
     default:
         err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
     }
relay/adaptor/replicate/chat.go — new file, 191 lines
@@ -0,0 +1,191 @@
+package replicate
+
+import (
+    "bufio"
+    "encoding/json"
+    "io"
+    "net/http"
+    "strings"
+    "time"
+
+    "github.com/gin-gonic/gin"
+    "github.com/pkg/errors"
+    "github.com/songquanpeng/one-api/common"
+    "github.com/songquanpeng/one-api/common/render"
+    "github.com/songquanpeng/one-api/relay/adaptor/openai"
+    "github.com/songquanpeng/one-api/relay/meta"
+    "github.com/songquanpeng/one-api/relay/model"
+)
+
+func ChatHandler(c *gin.Context, resp *http.Response) (
+    srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
+    if resp.StatusCode != http.StatusCreated {
+        payload, _ := io.ReadAll(resp.Body)
+        return openai.ErrorWrapper(
+                errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
+                "bad_status_code", http.StatusInternalServerError),
+            nil
+    }
+
+    respBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    respData := new(ChatResponse)
+    if err = json.Unmarshal(respBody, respData); err != nil {
+        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    for {
+        err = func() error {
+            // get task
+            taskReq, err := http.NewRequestWithContext(c.Request.Context(),
+                http.MethodGet, respData.URLs.Get, nil)
+            if err != nil {
+                return errors.Wrap(err, "new request")
+            }
+
+            taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
+            taskResp, err := http.DefaultClient.Do(taskReq)
+            if err != nil {
+                return errors.Wrap(err, "get task")
+            }
+            defer taskResp.Body.Close()
+
+            if taskResp.StatusCode != http.StatusOK {
+                payload, _ := io.ReadAll(taskResp.Body)
+                return errors.Errorf("bad status code [%d]%s",
+                    taskResp.StatusCode, string(payload))
+            }
+
+            taskBody, err := io.ReadAll(taskResp.Body)
+            if err != nil {
+                return errors.Wrap(err, "read task response")
+            }
+
+            taskData := new(ChatResponse)
+            if err = json.Unmarshal(taskBody, taskData); err != nil {
+                return errors.Wrap(err, "decode task response")
+            }
+
+            switch taskData.Status {
+            case "succeeded":
+            case "failed", "canceled":
+                return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
+            default:
+                time.Sleep(time.Second * 3)
+                return errNextLoop
+            }
+
+            if taskData.URLs.Stream == "" {
+                return errors.New("stream url is empty")
+            }
+
+            // request stream url
+            responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
+            if err != nil {
+                return errors.Wrap(err, "chat stream handler")
+            }
+
+            ctxMeta := meta.GetByContext(c)
+            usage = openai.ResponseText2Usage(responseText,
+                ctxMeta.ActualModelName, ctxMeta.PromptTokens)
+            return nil
+        }()
+        if err != nil {
+            if errors.Is(err, errNextLoop) {
+                continue
+            }
+
+            return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
+        }
+
+        break
+    }
+
+    return nil, usage
+}
+
+const (
+    eventPrefix = "event: "
+    dataPrefix  = "data: "
+    done        = "[DONE]"
+)
+
+func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
+    // request stream endpoint
+    streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
+    if err != nil {
+        return "", errors.Wrap(err, "new request to stream")
+    }
+
+    streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
+    streamReq.Header.Set("Accept", "text/event-stream")
+    streamReq.Header.Set("Cache-Control", "no-store")
+
+    resp, err := http.DefaultClient.Do(streamReq)
+    if err != nil {
+        return "", errors.Wrap(err, "do request to stream")
+    }
+    defer resp.Body.Close()
+
+    if resp.StatusCode != http.StatusOK {
+        payload, _ := io.ReadAll(resp.Body)
+        return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
+    }
+
+    scanner := bufio.NewScanner(resp.Body)
+    scanner.Split(bufio.ScanLines)
+
+    common.SetEventStreamHeaders(c)
+    doneRendered := false
+    for scanner.Scan() {
+        line := strings.TrimSpace(scanner.Text())
+        if line == "" {
+            continue
+        }
+
+        // Handle comments starting with ':'
+        if strings.HasPrefix(line, ":") {
+            continue
+        }
+
+        // Parse SSE fields
+        if strings.HasPrefix(line, eventPrefix) {
+            event := strings.TrimSpace(line[len(eventPrefix):])
+            var data string
+            // Read the following lines to get data and id
+            for scanner.Scan() {
+                nextLine := scanner.Text()
+                if nextLine == "" {
+                    break
+                }
+                if strings.HasPrefix(nextLine, dataPrefix) {
+                    data = nextLine[len(dataPrefix):]
+                } else if strings.HasPrefix(nextLine, "id:") {
+                    // id = strings.TrimSpace(nextLine[len("id:"):])
+                }
+            }
+
+            if event == "output" {
+                render.StringData(c, data)
+                responseText += data
+            } else if event == "done" {
+                render.Done(c)
+                doneRendered = true
+                break
+            }
+        }
+    }
+
+    if err := scanner.Err(); err != nil {
+        return "", errors.Wrap(err, "scan stream")
+    }
+
+    if !doneRendered {
+        render.Done(c)
+    }
+
+    return responseText, nil
+}
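
ChatHandler implements Replicate's two-phase prediction flow: the relayed POST has already returned 201 with task URLs, the loop polls urls.get until the task reports succeeded (failed or canceled aborts; any other status sleeps three seconds and retries via the errNextLoop sentinel), and only then does chatStreamHandler read urls.stream as server-sent events. An illustrative stream that this parser would accept (payloads invented):

event: output
data: Hello

event: output
data:  world!

event: done
data: {}

Each output event's data is forwarded to the client through render.StringData and accumulated into responseText, which ResponseText2Usage later tokenizes to bill completion tokens.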
@@ -33,24 +33,24 @@ var ModelList = []string{
     // -------------------------------------
     // language model
     // -------------------------------------
-    // "ibm-granite/granite-20b-code-instruct-8k", // TODO: implement the adaptor
+    "ibm-granite/granite-20b-code-instruct-8k",
-    // "ibm-granite/granite-3.0-2b-instruct", // TODO: implement the adaptor
+    "ibm-granite/granite-3.0-2b-instruct",
-    // "ibm-granite/granite-3.0-8b-instruct", // TODO: implement the adaptor
+    "ibm-granite/granite-3.0-8b-instruct",
-    // "ibm-granite/granite-8b-code-instruct-128k", // TODO: implement the adaptor
+    "ibm-granite/granite-8b-code-instruct-128k",
-    // "meta/llama-2-13b", // TODO: implement the adaptor
+    "meta/llama-2-13b",
-    // "meta/llama-2-13b-chat", // TODO: implement the adaptor
+    "meta/llama-2-13b-chat",
-    // "meta/llama-2-70b", // TODO: implement the adaptor
+    "meta/llama-2-70b",
-    // "meta/llama-2-70b-chat", // TODO: implement the adaptor
+    "meta/llama-2-70b-chat",
-    // "meta/llama-2-7b", // TODO: implement the adaptor
+    "meta/llama-2-7b",
-    // "meta/llama-2-7b-chat", // TODO: implement the adaptor
+    "meta/llama-2-7b-chat",
-    // "meta/meta-llama-3.1-405b-instruct", // TODO: implement the adaptor
+    "meta/meta-llama-3.1-405b-instruct",
-    // "meta/meta-llama-3-70b", // TODO: implement the adaptor
+    "meta/meta-llama-3-70b",
-    // "meta/meta-llama-3-70b-instruct", // TODO: implement the adaptor
+    "meta/meta-llama-3-70b-instruct",
-    // "meta/meta-llama-3-8b", // TODO: implement the adaptor
+    "meta/meta-llama-3-8b",
-    // "meta/meta-llama-3-8b-instruct", // TODO: implement the adaptor
+    "meta/meta-llama-3-8b-instruct",
-    // "mistralai/mistral-7b-instruct-v0.2", // TODO: implement the adaptor
+    "mistralai/mistral-7b-instruct-v0.2",
-    // "mistralai/mistral-7b-v0.1", // TODO: implement the adaptor
+    "mistralai/mistral-7b-v0.1",
-    // "mistralai/mixtral-8x7b-instruct-v0.1", // TODO: implement the adaptor
+    "mistralai/mixtral-8x7b-instruct-v0.1",
     // -------------------------------------
     // video model
     // -------------------------------------
@@ -109,3 +109,51 @@ type FluxURLs struct {
     Get    string `json:"get"`
     Cancel string `json:"cancel"`
 }
+
+type ReplicateChatRequest struct {
+    Input ChatInput `json:"input" form:"input" binding:"required"`
+}
+
+// ChatInput is input of ChatByReplicateRequest
+//
+// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
+type ChatInput struct {
+    TopK             int     `json:"top_k"`
+    TopP             float64 `json:"top_p"`
+    Prompt           string  `json:"prompt"`
+    MaxTokens        int     `json:"max_tokens"`
+    MinTokens        int     `json:"min_tokens"`
+    Temperature      float64 `json:"temperature"`
+    SystemPrompt     string  `json:"system_prompt"`
+    StopSequences    string  `json:"stop_sequences"`
+    PromptTemplate   string  `json:"prompt_template"`
+    PresencePenalty  float64 `json:"presence_penalty"`
+    FrequencyPenalty float64 `json:"frequency_penalty"`
+}
+
+// ChatResponse is response of ChatByReplicateRequest
+//
+// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
+type ChatResponse struct {
+    CompletedAt time.Time   `json:"completed_at"`
+    CreatedAt   time.Time   `json:"created_at"`
+    DataRemoved bool        `json:"data_removed"`
+    Error       string      `json:"error"`
+    ID          string      `json:"id"`
+    Input       ChatInput   `json:"input"`
+    Logs        string      `json:"logs"`
+    Metrics     FluxMetrics `json:"metrics"`
+    // Output could be `string` or `[]string`
+    Output    []string        `json:"output"`
+    StartedAt time.Time       `json:"started_at"`
+    Status    string          `json:"status"`
+    URLs      ChatResponseUrl `json:"urls"`
+    Version   string          `json:"version"`
+}
+
+// ChatResponseUrl is task urls of ChatResponse
+type ChatResponseUrl struct {
+    Stream string `json:"stream"`
+    Get    string `json:"get"`
+    Cancel string `json:"cancel"`
+}
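
Serialized with the tags above, the body that ConvertRequest sends to Replicate looks roughly like this (illustrative values; note that ChatInput has no omitempty tags, so zero-valued fields are sent too):

{
  "input": {
    "top_k": 0,
    "top_p": 1.0,
    "prompt": "system: You are concise.\nuser: Say hi.\n",
    "max_tokens": 500,
    "min_tokens": 0,
    "temperature": 1.0,
    "system_prompt": "",
    "stop_sequences": "",
    "prompt_template": "",
    "presence_penalty": 0.0,
    "frequency_penalty": 0.0
  }
}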
@@ -236,6 +236,25 @@ var ModelRatio = map[string]float64{
     "stability-ai/stable-diffusion-3.5-large":       0.065 * USD,
     "stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
     "stability-ai/stable-diffusion-3.5-medium":      0.035 * USD,
+    // replicate chat models
+    "ibm-granite/granite-20b-code-instruct-8k":  0.100 * USD,
+    "ibm-granite/granite-3.0-2b-instruct":       0.030 * USD,
+    "ibm-granite/granite-3.0-8b-instruct":       0.050 * USD,
+    "ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
+    "meta/llama-2-13b":                          0.100 * USD,
+    "meta/llama-2-13b-chat":                     0.100 * USD,
+    "meta/llama-2-70b":                          0.650 * USD,
+    "meta/llama-2-70b-chat":                     0.650 * USD,
+    "meta/llama-2-7b":                           0.050 * USD,
+    "meta/llama-2-7b-chat":                      0.050 * USD,
+    "meta/meta-llama-3.1-405b-instruct":         9.500 * USD,
+    "meta/meta-llama-3-70b":                     0.650 * USD,
+    "meta/meta-llama-3-70b-instruct":            0.650 * USD,
+    "meta/meta-llama-3-8b":                      0.050 * USD,
+    "meta/meta-llama-3-8b-instruct":             0.050 * USD,
+    "mistralai/mistral-7b-instruct-v0.2":        0.050 * USD,
+    "mistralai/mistral-7b-v0.1":                 0.050 * USD,
+    "mistralai/mixtral-8x7b-instruct-v0.1":      0.300 * USD,
 }
 
 var CompletionRatio = map[string]float64{
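
For orientation, these per-model ratios feed the relay's metering together with the completion ratios below; a rough sketch of how a request's quota cost comes together (names hypothetical, the real computation lives elsewhere in the billing code):

// quota ≈ (promptTokens + completionTokens × CompletionRatio) × ModelRatio × groupRatio
func estimateQuota(promptTokens, completionTokens int, modelRatio, completionRatio, groupRatio float64) int64 {
    cost := (float64(promptTokens) + float64(completionTokens)*completionRatio) * modelRatio * groupRatio
    return int64(cost)
}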
@@ -387,6 +406,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
     if strings.HasPrefix(name, "deepseek-") {
         return 2
     }
+
     switch name {
     case "llama2-70b-4096":
         return 0.8 / 0.64
@@ -402,6 +422,35 @@ func GetCompletionRatio(name string, channelType int) float64 {
         return 5
     case "grok-beta":
         return 3
+    // Replicate Models
+    // https://replicate.com/pricing
+    case "ibm-granite/granite-20b-code-instruct-8k":
+        return 5
+    case "ibm-granite/granite-3.0-2b-instruct":
+        return 8.333333333333334
+    case "ibm-granite/granite-3.0-8b-instruct",
+        "ibm-granite/granite-8b-code-instruct-128k":
+        return 5
+    case "meta/llama-2-13b",
+        "meta/llama-2-13b-chat",
+        "meta/llama-2-7b",
+        "meta/llama-2-7b-chat",
+        "meta/meta-llama-3-8b",
+        "meta/meta-llama-3-8b-instruct":
+        return 5
+    case "meta/llama-2-70b",
+        "meta/llama-2-70b-chat",
+        "meta/meta-llama-3-70b",
+        "meta/meta-llama-3-70b-instruct":
+        return 2.750 / 0.650 // ≈4.230769
+    case "meta/meta-llama-3.1-405b-instruct":
+        return 1
+    case "mistralai/mistral-7b-instruct-v0.2",
+        "mistralai/mistral-7b-v0.1":
+        return 5
+    case "mistralai/mixtral-8x7b-instruct-v0.1":
+        return 1.000 / 0.300 // ≈3.333333
     }
 
     return 1
 }
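
The magic numbers mirror the posted per-token prices linked in the comment: each case returns the output price divided by the input price. For meta/llama-2-70b, for example, $2.75 per 1M output tokens over $0.65 per 1M input tokens gives 2.750 / 0.650 ≈ 4.2308, and mixtral-8x7b's $1.00 / $0.30 gives ≈ 3.3333, consistent with the input-side ModelRatio entries above (0.650 * USD and 0.300 * USD respectively).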
@@ -147,14 +147,20 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
         }
         return true
     }
-    if resp.StatusCode != http.StatusOK {
+    if resp.StatusCode != http.StatusOK &&
+        // replicate return 201 to create a task
+        resp.StatusCode != http.StatusCreated {
         return true
     }
     if meta.ChannelType == channeltype.DeepL {
         // skip stream check for deepl
         return false
     }
-    if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
+
+    if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
+        // Even if stream mode is enabled, replicate will first return a task info in JSON format,
+        // requiring the client to request the stream endpoint in the task info
+        meta.ChannelType != channeltype.Replicate {
         return true
     }
     return false
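
These carve-outs reflect Replicate's two-step protocol: the create call answers 201 with a JSON task description rather than the completion itself, roughly like this (illustrative values, shape per the ChatResponse struct above):

{
  "id": "gm3qorzdhgbfurvjtvhg6dckhu",
  "status": "starting",
  "urls": {
    "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu",
    "stream": "https://stream.replicate.com/v1/...",
    "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel"
  }
}

so for this channel neither the 201 status nor the application/json content type on a streaming request signals an error.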
@@ -22,7 +22,7 @@ import (
     relaymodel "github.com/songquanpeng/one-api/relay/model"
 )
 
-func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
+func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) {
     imageRequest := &relaymodel.ImageRequest{}
     err := common.UnmarshalBodyReusable(c, imageRequest)
     if err != nil {
@@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
     return 1
 }
 
-func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
+func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
     // check prompt length
     if imageRequest.Prompt == "" {
         return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)