Mirror of https://github.com/songquanpeng/one-api.git (synced 2025-09-17 09:16:36 +08:00)
feat: support llm chat on replicate
commit 48e8b6b5c0
parent 4dd2b9dcb8
@@ -3,9 +3,10 @@ package render
 import (
 	"encoding/json"
+	"fmt"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
-	"strings"
 )
 
 func StringData(c *gin.Context, str string) {
@@ -34,7 +34,7 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
 		strings.Contains(lowerMessage, "credit") ||
 		strings.Contains(lowerMessage, "balance") ||
 		strings.Contains(lowerMessage, "permission denied") ||
 		strings.Contains(lowerMessage, "organization has been restricted") || // groq
 		strings.Contains(lowerMessage, "已欠费") {
 		return true
 	}
@@ -31,8 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 			TopP:             request.TopP,
 			FrequencyPenalty: request.FrequencyPenalty,
 			PresencePenalty:  request.PresencePenalty,
 			NumPredict:       request.MaxTokens,
 			NumCtx:           request.NumCtx,
 		},
 		Stream: request.Stream,
 	}
@@ -122,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 	for scanner.Scan() {
 		data := scanner.Text()
 		if strings.HasPrefix(data, "}") {
 			data = strings.TrimPrefix(data, "}") + "}"
 		}
 
 		var ollamaResponse ChatResponse
@@ -2,15 +2,16 @@ package openai
 
 import (
 	"fmt"
+	"strings"
+
 	"github.com/songquanpeng/one-api/relay/channeltype"
 	"github.com/songquanpeng/one-api/relay/model"
-	"strings"
 )
 
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
+func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
 	usage := &model.Usage{}
 	usage.PromptTokens = promptTokens
-	usage.CompletionTokens = CountTokenText(responseText, modeName)
+	usage.CompletionTokens = CountTokenText(responseText, modelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage
 }
@@ -5,6 +5,7 @@ import (
 	"io"
 	"net/http"
 	"slices"
+	"strings"
 	"time"
 
 	"github.com/gin-gonic/gin"
@@ -39,7 +40,55 @@ func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
 }
 
 func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
-	return nil, errors.New("not implemented")
+	if !request.Stream {
+		// TODO: support non-stream mode
+		return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
+	}
+
+	// Build the prompt from OpenAI messages
+	var promptBuilder strings.Builder
+	for _, message := range request.Messages {
+		switch msgCnt := message.Content.(type) {
+		case string:
+			promptBuilder.WriteString(message.Role)
+			promptBuilder.WriteString(": ")
+			promptBuilder.WriteString(msgCnt)
+			promptBuilder.WriteString("\n")
+		default:
+		}
+	}
+
+	replicateRequest := ReplicateChatRequest{
+		Input: ChatInput{
+			Prompt:           promptBuilder.String(),
+			MaxTokens:        request.MaxTokens,
+			Temperature:      1.0,
+			TopP:             1.0,
+			PresencePenalty:  0.0,
+			FrequencyPenalty: 0.0,
+		},
+	}
+
+	// Map optional fields
+	if request.Temperature != nil {
+		replicateRequest.Input.Temperature = *request.Temperature
+	}
+	if request.TopP != nil {
+		replicateRequest.Input.TopP = *request.TopP
+	}
+	if request.PresencePenalty != nil {
+		replicateRequest.Input.PresencePenalty = *request.PresencePenalty
+	}
+	if request.FrequencyPenalty != nil {
+		replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
+	}
+	if request.MaxTokens > 0 {
+		replicateRequest.Input.MaxTokens = request.MaxTokens
+	} else if request.MaxTokens == 0 {
+		replicateRequest.Input.MaxTokens = 500
+	}
+
+	return replicateRequest, nil
 }
 
 func (a *Adaptor) Init(meta *meta.Meta) {
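Note: Replicate's language models take a single prompt string rather than an OpenAI-style messages array, which is why ConvertRequest above flattens every string message into a "role: content" line. A minimal stand-alone sketch of that flattening, with made-up message values for illustration:

	package main

	import (
		"fmt"
		"strings"
	)

	// Same flattening as ConvertRequest above: each string message becomes one
	// "role: content" line of a single Replicate prompt.
	func main() {
		messages := []struct{ Role, Content string }{
			{"system", "You are a helpful assistant."}, // hypothetical example input
			{"user", "What is Replicate?"},
		}

		var promptBuilder strings.Builder
		for _, m := range messages {
			promptBuilder.WriteString(m.Role)
			promptBuilder.WriteString(": ")
			promptBuilder.WriteString(m.Content)
			promptBuilder.WriteString("\n")
		}

		fmt.Print(promptBuilder.String())
		// Output:
		// system: You are a helpful assistant.
		// user: What is Replicate?
	}

Non-string message contents (for example multimodal parts) fall into the empty default branch and are silently dropped.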
@@ -61,7 +110,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
-	logger.Info(c, "send image request to replicate")
+	logger.Info(c, "send request to replicate")
 	return adaptor.DoRequestHelper(a, c, meta, requestBody)
 }
 
@@ -69,6 +118,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
 	switch meta.Mode {
 	case relaymode.ImagesGenerations:
 		err, usage = ImageHandler(c, resp)
+	case relaymode.ChatCompletions:
+		err, usage = ChatHandler(c, resp)
 	default:
 		err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
 	}
relay/adaptor/replicate/chat.go (new file, 191 lines)
@@ -0,0 +1,191 @@
package replicate

import (
	"bufio"
	"encoding/json"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/meta"
	"github.com/songquanpeng/one-api/relay/model"
)

func ChatHandler(c *gin.Context, resp *http.Response) (
	srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
	if resp.StatusCode != http.StatusCreated {
		payload, _ := io.ReadAll(resp.Body)
		return openai.ErrorWrapper(
				errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
				"bad_status_code", http.StatusInternalServerError),
			nil
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	respData := new(ChatResponse)
	if err = json.Unmarshal(respBody, respData); err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	for {
		err = func() error {
			// get task
			taskReq, err := http.NewRequestWithContext(c.Request.Context(),
				http.MethodGet, respData.URLs.Get, nil)
			if err != nil {
				return errors.Wrap(err, "new request")
			}

			taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
			taskResp, err := http.DefaultClient.Do(taskReq)
			if err != nil {
				return errors.Wrap(err, "get task")
			}
			defer taskResp.Body.Close()

			if taskResp.StatusCode != http.StatusOK {
				payload, _ := io.ReadAll(taskResp.Body)
				return errors.Errorf("bad status code [%d]%s",
					taskResp.StatusCode, string(payload))
			}

			taskBody, err := io.ReadAll(taskResp.Body)
			if err != nil {
				return errors.Wrap(err, "read task response")
			}

			taskData := new(ChatResponse)
			if err = json.Unmarshal(taskBody, taskData); err != nil {
				return errors.Wrap(err, "decode task response")
			}

			switch taskData.Status {
			case "succeeded":
			case "failed", "canceled":
				return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
			default:
				time.Sleep(time.Second * 3)
				return errNextLoop
			}

			if taskData.URLs.Stream == "" {
				return errors.New("stream url is empty")
			}

			// request stream url
			responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
			if err != nil {
				return errors.Wrap(err, "chat stream handler")
			}

			ctxMeta := meta.GetByContext(c)
			usage = openai.ResponseText2Usage(responseText,
				ctxMeta.ActualModelName, ctxMeta.PromptTokens)
			return nil
		}()
		if err != nil {
			if errors.Is(err, errNextLoop) {
				continue
			}

			return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
		}

		break
	}

	return nil, usage
}

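Note: ChatHandler follows Replicate's asynchronous prediction flow: the relayed POST is answered with 201 Created and a prediction object, whose urls.get is polled every three seconds until the status becomes terminal ("succeeded", "failed" or "canceled"), and only then is urls.stream consumed. A small sketch of the fields driving that loop, decoded from a hypothetical prediction payload into a trimmed stand-in for the ChatResponse type defined in model.go below:

	package main

	import (
		"encoding/json"
		"fmt"
	)

	// Trimmed stand-in for ChatResponse/ChatResponseUrl from model.go, kept local
	// so this snippet compiles on its own.
	type prediction struct {
		ID     string `json:"id"`
		Status string `json:"status"`
		URLs   struct {
			Get    string `json:"get"`
			Stream string `json:"stream"`
		} `json:"urls"`
	}

	func main() {
		// Hypothetical prediction object; only the field names match the struct tags above.
		payload := `{
			"id": "p1",
			"status": "processing",
			"urls": {
				"get": "https://api.replicate.com/v1/predictions/p1",
				"stream": "https://stream.replicate.com/v1/predictions/p1"
			}
		}`

		var p prediction
		if err := json.Unmarshal([]byte(payload), &p); err != nil {
			panic(err)
		}

		// ChatHandler keeps polling urls.get while status is non-terminal,
		// then reads urls.stream once the task has succeeded.
		fmt.Println(p.Status, p.URLs.Stream)
	}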
const (
	eventPrefix = "event: "
	dataPrefix  = "data: "
	done        = "[DONE]"
)

func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
	// request stream endpoint
	streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
	if err != nil {
		return "", errors.Wrap(err, "new request to stream")
	}

	streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
	streamReq.Header.Set("Accept", "text/event-stream")
	streamReq.Header.Set("Cache-Control", "no-store")

	resp, err := http.DefaultClient.Do(streamReq)
	if err != nil {
		return "", errors.Wrap(err, "do request to stream")
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		payload, _ := io.ReadAll(resp.Body)
		return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	common.SetEventStreamHeaders(c)
	doneRendered := false
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		// Handle comments starting with ':'
		if strings.HasPrefix(line, ":") {
			continue
		}

		// Parse SSE fields
		if strings.HasPrefix(line, eventPrefix) {
			event := strings.TrimSpace(line[len(eventPrefix):])
			var data string
			// Read the following lines to get data and id
			for scanner.Scan() {
				nextLine := scanner.Text()
				if nextLine == "" {
					break
				}
				if strings.HasPrefix(nextLine, dataPrefix) {
					data = nextLine[len(dataPrefix):]
				} else if strings.HasPrefix(nextLine, "id:") {
					// id = strings.TrimSpace(nextLine[len("id:"):])
				}
			}

			if event == "output" {
				render.StringData(c, data)
				responseText += data
			} else if event == "done" {
				render.Done(c)
				doneRendered = true
				break
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return "", errors.Wrap(err, "scan stream")
	}

	if !doneRendered {
		render.Done(c)
	}

	return responseText, nil
}
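Note: the stream endpoint speaks server-sent events; chatStreamHandler forwards the text of every "output" event to the client and stops at the "done" event. A self-contained sketch of the same parsing applied to a canned stream (the event payloads are invented for illustration):

	package main

	import (
		"bufio"
		"fmt"
		"strings"
	)

	func main() {
		// Canned SSE stream in the shape chatStreamHandler expects:
		// "output" events carry text chunks, "done" terminates the stream.
		raw := "event: output\ndata: Hello\n\n" +
			"event: output\ndata:  world\n\n" +
			"event: done\ndata: {}\n\n"

		var responseText string
		scanner := bufio.NewScanner(strings.NewReader(raw))
		for scanner.Scan() {
			line := strings.TrimSpace(scanner.Text())
			if !strings.HasPrefix(line, "event: ") {
				continue
			}
			event := strings.TrimSpace(line[len("event: "):])

			var data string
			for scanner.Scan() { // collect the fields belonging to this event
				next := scanner.Text()
				if next == "" {
					break
				}
				if strings.HasPrefix(next, "data: ") {
					data = next[len("data: "):]
				}
			}

			if event == "output" {
				responseText += data
			} else if event == "done" {
				break
			}
		}

		fmt.Println(responseText) // Hello world
	}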
@@ -33,24 +33,24 @@ var ModelList = []string{
 	// -------------------------------------
 	// language model
 	// -------------------------------------
-	// "ibm-granite/granite-20b-code-instruct-8k", // TODO: implement the adaptor
-	// "ibm-granite/granite-3.0-2b-instruct", // TODO: implement the adaptor
-	// "ibm-granite/granite-3.0-8b-instruct", // TODO: implement the adaptor
-	// "ibm-granite/granite-8b-code-instruct-128k", // TODO: implement the adaptor
-	// "meta/llama-2-13b", // TODO: implement the adaptor
-	// "meta/llama-2-13b-chat", // TODO: implement the adaptor
-	// "meta/llama-2-70b", // TODO: implement the adaptor
-	// "meta/llama-2-70b-chat", // TODO: implement the adaptor
-	// "meta/llama-2-7b", // TODO: implement the adaptor
-	// "meta/llama-2-7b-chat", // TODO: implement the adaptor
-	// "meta/meta-llama-3.1-405b-instruct", // TODO: implement the adaptor
-	// "meta/meta-llama-3-70b", // TODO: implement the adaptor
-	// "meta/meta-llama-3-70b-instruct", // TODO: implement the adaptor
-	// "meta/meta-llama-3-8b", // TODO: implement the adaptor
-	// "meta/meta-llama-3-8b-instruct", // TODO: implement the adaptor
-	// "mistralai/mistral-7b-instruct-v0.2", // TODO: implement the adaptor
-	// "mistralai/mistral-7b-v0.1", // TODO: implement the adaptor
-	// "mistralai/mixtral-8x7b-instruct-v0.1", // TODO: implement the adaptor
+	"ibm-granite/granite-20b-code-instruct-8k",
+	"ibm-granite/granite-3.0-2b-instruct",
+	"ibm-granite/granite-3.0-8b-instruct",
+	"ibm-granite/granite-8b-code-instruct-128k",
+	"meta/llama-2-13b",
+	"meta/llama-2-13b-chat",
+	"meta/llama-2-70b",
+	"meta/llama-2-70b-chat",
+	"meta/llama-2-7b",
+	"meta/llama-2-7b-chat",
+	"meta/meta-llama-3.1-405b-instruct",
+	"meta/meta-llama-3-70b",
+	"meta/meta-llama-3-70b-instruct",
+	"meta/meta-llama-3-8b",
+	"meta/meta-llama-3-8b-instruct",
+	"mistralai/mistral-7b-instruct-v0.2",
+	"mistralai/mistral-7b-v0.1",
+	"mistralai/mixtral-8x7b-instruct-v0.1",
 	// -------------------------------------
 	// video model
 	// -------------------------------------
@@ -109,3 +109,51 @@ type FluxURLs struct {
 	Get    string `json:"get"`
 	Cancel string `json:"cancel"`
 }
+
+type ReplicateChatRequest struct {
+	Input ChatInput `json:"input" form:"input" binding:"required"`
+}
+
+// ChatInput is input of ChatByReplicateRequest
+//
+// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
+type ChatInput struct {
+	TopK             int     `json:"top_k"`
+	TopP             float64 `json:"top_p"`
+	Prompt           string  `json:"prompt"`
+	MaxTokens        int     `json:"max_tokens"`
+	MinTokens        int     `json:"min_tokens"`
+	Temperature      float64 `json:"temperature"`
+	SystemPrompt     string  `json:"system_prompt"`
+	StopSequences    string  `json:"stop_sequences"`
+	PromptTemplate   string  `json:"prompt_template"`
+	PresencePenalty  float64 `json:"presence_penalty"`
+	FrequencyPenalty float64 `json:"frequency_penalty"`
+}
+
+// ChatResponse is response of ChatByReplicateRequest
+//
+// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
+type ChatResponse struct {
+	CompletedAt time.Time   `json:"completed_at"`
+	CreatedAt   time.Time   `json:"created_at"`
+	DataRemoved bool        `json:"data_removed"`
+	Error       string      `json:"error"`
+	ID          string      `json:"id"`
+	Input       ChatInput   `json:"input"`
+	Logs        string      `json:"logs"`
+	Metrics     FluxMetrics `json:"metrics"`
+	// Output could be `string` or `[]string`
+	Output    []string        `json:"output"`
+	StartedAt time.Time       `json:"started_at"`
+	Status    string          `json:"status"`
+	URLs      ChatResponseUrl `json:"urls"`
+	Version   string          `json:"version"`
+}
+
+// ChatResponseUrl is task urls of ChatResponse
+type ChatResponseUrl struct {
+	Stream string `json:"stream"`
+	Get    string `json:"get"`
+	Cancel string `json:"cancel"`
+}
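Note: ReplicateChatRequest mirrors the request body sent to Replicate: a single "input" object carrying the prompt and sampling parameters. A minimal sketch of the JSON it marshals to, using trimmed local copies of the types above (the field values are hypothetical):

	package main

	import (
		"encoding/json"
		"fmt"
	)

	// Trimmed local copies of ReplicateChatRequest/ChatInput, limited to a few
	// of the fields ConvertRequest sets.
	type chatInput struct {
		TopP        float64 `json:"top_p"`
		Prompt      string  `json:"prompt"`
		MaxTokens   int     `json:"max_tokens"`
		Temperature float64 `json:"temperature"`
	}

	type replicateChatRequest struct {
		Input chatInput `json:"input"`
	}

	func main() {
		// Hypothetical body for a one-message chat with default sampling values.
		req := replicateChatRequest{Input: chatInput{
			TopP:        1.0,
			Prompt:      "user: Hello\n",
			MaxTokens:   500,
			Temperature: 1.0,
		}}

		body, _ := json.MarshalIndent(req, "", "  ")
		fmt.Println(string(body))
	}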
@@ -15,7 +15,7 @@ import (
 )

 var ModelList = []string{
 	"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002",
 }

 type Adaptor struct {
@@ -236,6 +236,25 @@ var ModelRatio = map[string]float64{
 	"stability-ai/stable-diffusion-3.5-large":       0.065 * USD,
 	"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
 	"stability-ai/stable-diffusion-3.5-medium":      0.035 * USD,
+	// replicate chat models
+	"ibm-granite/granite-20b-code-instruct-8k":  0.100 * USD,
+	"ibm-granite/granite-3.0-2b-instruct":       0.030 * USD,
+	"ibm-granite/granite-3.0-8b-instruct":       0.050 * USD,
+	"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
+	"meta/llama-2-13b":                          0.100 * USD,
+	"meta/llama-2-13b-chat":                     0.100 * USD,
+	"meta/llama-2-70b":                          0.650 * USD,
+	"meta/llama-2-70b-chat":                     0.650 * USD,
+	"meta/llama-2-7b":                           0.050 * USD,
+	"meta/llama-2-7b-chat":                      0.050 * USD,
+	"meta/meta-llama-3.1-405b-instruct":         9.500 * USD,
+	"meta/meta-llama-3-70b":                     0.650 * USD,
+	"meta/meta-llama-3-70b-instruct":            0.650 * USD,
+	"meta/meta-llama-3-8b":                      0.050 * USD,
+	"meta/meta-llama-3-8b-instruct":             0.050 * USD,
+	"mistralai/mistral-7b-instruct-v0.2":        0.050 * USD,
+	"mistralai/mistral-7b-v0.1":                 0.050 * USD,
+	"mistralai/mixtral-8x7b-instruct-v0.1":      0.300 * USD,
 }

 var CompletionRatio = map[string]float64{
@@ -387,6 +406,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
 	if strings.HasPrefix(name, "deepseek-") {
 		return 2
 	}

 	switch name {
 	case "llama2-70b-4096":
 		return 0.8 / 0.64
@@ -402,6 +422,35 @@ func GetCompletionRatio(name string, channelType int) float64 {
 		return 5
 	case "grok-beta":
 		return 3
+	// Replicate Models
+	// https://replicate.com/pricing
+	case "ibm-granite/granite-20b-code-instruct-8k":
+		return 5
+	case "ibm-granite/granite-3.0-2b-instruct":
+		return 8.333333333333334
+	case "ibm-granite/granite-3.0-8b-instruct",
+		"ibm-granite/granite-8b-code-instruct-128k":
+		return 5
+	case "meta/llama-2-13b",
+		"meta/llama-2-13b-chat",
+		"meta/llama-2-7b",
+		"meta/llama-2-7b-chat",
+		"meta/meta-llama-3-8b",
+		"meta/meta-llama-3-8b-instruct":
+		return 5
+	case "meta/llama-2-70b",
+		"meta/llama-2-70b-chat",
+		"meta/meta-llama-3-70b",
+		"meta/meta-llama-3-70b-instruct":
+		return 2.750 / 0.650 // ≈ 4.230769
+	case "meta/meta-llama-3.1-405b-instruct":
+		return 1
+	case "mistralai/mistral-7b-instruct-v0.2",
+		"mistralai/mistral-7b-v0.1":
+		return 5
+	case "mistralai/mixtral-8x7b-instruct-v0.1":
+		return 1.000 / 0.300 // ≈ 3.333333
 	}

 	return 1
 }
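Note: the Replicate completion ratios above are simply output-token price divided by input-token price from https://replicate.com/pricing, consistent with the per-model input prices added to ModelRatio earlier in this commit. A quick check of the arithmetic:

	package main

	import "fmt"

	func main() {
		fmt.Println(2.750 / 0.650) // ≈ 4.230769, the llama 70B-class case above
		fmt.Println(1.000 / 0.300) // ≈ 3.333333, the mixtral-8x7b case above
	}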
@@ -147,14 +147,20 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
 		}
 		return true
 	}
-	if resp.StatusCode != http.StatusOK {
+	if resp.StatusCode != http.StatusOK &&
+		// replicate return 201 to create a task
+		resp.StatusCode != http.StatusCreated {
 		return true
 	}
 	if meta.ChannelType == channeltype.DeepL {
 		// skip stream check for deepl
 		return false
 	}
-	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
+
+	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
+		// Even if stream mode is enabled, replicate will first return a task info in JSON format,
+		// requiring the client to request the stream endpoint in the task info
+		meta.ChannelType != channeltype.Replicate {
 		return true
 	}
 	return false
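Note: these relaxed checks exist because Replicate answers the relayed create-prediction call with 201 Created and a JSON task body even when the client requested streaming, so neither condition may be treated as an upstream error for that channel. A condensed, stand-alone restatement of the decision (the function and parameter names are local to this sketch, not from the repository, and the DeepL special case is omitted):

	package main

	import (
		"fmt"
		"net/http"
		"strings"
	)

	// Condensed form of the checks above: 200 and 201 are both acceptable, and a
	// JSON body on a streaming request is only an error for non-Replicate channels.
	func looksLikeError(isReplicate, isStream bool, statusCode int, contentType string) bool {
		if statusCode != http.StatusOK && statusCode != http.StatusCreated {
			return true
		}
		if isStream && strings.HasPrefix(contentType, "application/json") && !isReplicate {
			return true
		}
		return false
	}

	func main() {
		// A Replicate streaming request answered with "201 Created" + JSON task info is fine.
		fmt.Println(looksLikeError(true, true, http.StatusCreated, "application/json")) // false
	}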
@@ -22,7 +22,7 @@ import (
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
 )

-func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
+func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) {
 	imageRequest := &relaymodel.ImageRequest{}
 	err := common.UnmarshalBodyReusable(c, imageRequest)
 	if err != nil {
@@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
 	return 1
 }

-func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
+func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
 	// check prompt length
 	if imageRequest.Prompt == "" {
 		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)