diff --git a/common/render/render.go b/common/render/render.go index 646b3777..e565c0b7 100644 --- a/common/render/render.go +++ b/common/render/render.go @@ -3,9 +3,10 @@ package render import ( "encoding/json" "fmt" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" - "strings" ) func StringData(c *gin.Context, str string) { diff --git a/relay/adaptor/openai/helper.go b/relay/adaptor/openai/helper.go index 7d73303b..47c2a882 100644 --- a/relay/adaptor/openai/helper.go +++ b/relay/adaptor/openai/helper.go @@ -2,15 +2,16 @@ package openai import ( "fmt" + "strings" + "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/model" - "strings" ) -func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { +func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage { usage := &model.Usage{} usage.PromptTokens = promptTokens - usage.CompletionTokens = CountTokenText(responseText, modeName) + usage.CompletionTokens = CountTokenText(responseText, modelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return usage } diff --git a/relay/adaptor/replicate/adaptor.go b/relay/adaptor/replicate/adaptor.go index 34e42de5..ea91eb73 100644 --- a/relay/adaptor/replicate/adaptor.go +++ b/relay/adaptor/replicate/adaptor.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "slices" + "strings" "time" "github.com/gin-gonic/gin" @@ -80,8 +81,57 @@ func convertImageRemixRequest(c *gin.Context) (any, error) { return rawReq.toFluxRemixRequest() } +// ConvertRequest converts the request to the format that the target API expects. 
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - return nil, errors.New("not implemented") + if !request.Stream { + // TODO: support non-stream mode + return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true") + } + + // Build the prompt from OpenAI messages + var promptBuilder strings.Builder + for _, message := range request.Messages { + switch msgCnt := message.Content.(type) { + case string: + promptBuilder.WriteString(message.Role) + promptBuilder.WriteString(": ") + promptBuilder.WriteString(msgCnt) + promptBuilder.WriteString("\n") + default: + } + } + + replicateRequest := ReplicateChatRequest{ + Input: ChatInput{ + Prompt: promptBuilder.String(), + MaxTokens: request.MaxTokens, + Temperature: 1.0, + TopP: 1.0, + PresencePenalty: 0.0, + FrequencyPenalty: 0.0, + }, + } + + // Map optional fields + if request.Temperature != nil { + replicateRequest.Input.Temperature = *request.Temperature + } + if request.TopP != nil { + replicateRequest.Input.TopP = *request.TopP + } + if request.PresencePenalty != nil { + replicateRequest.Input.PresencePenalty = *request.PresencePenalty + } + if request.FrequencyPenalty != nil { + replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty + } + if request.MaxTokens > 0 { + replicateRequest.Input.MaxTokens = request.MaxTokens + } else if request.MaxTokens == 0 { + replicateRequest.Input.MaxTokens = 500 + } + + return replicateRequest, nil } func (a *Adaptor) Init(meta *meta.Meta) { @@ -103,7 +153,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me } func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { - logger.Info(c, "send image request to replicate") + logger.Info(c, "send request to replicate") return adaptor.DoRequestHelper(a, c, meta, requestBody) } @@ -112,6 +162,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, 
resp *http.Response, meta *meta.Met case relaymode.ImagesGenerations, relaymode.ImagesEdits: err, usage = ImageHandler(c, resp) + case relaymode.ChatCompletions: + err, usage = ChatHandler(c, resp) default: err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) } diff --git a/relay/adaptor/replicate/chat.go b/relay/adaptor/replicate/chat.go new file mode 100644 index 00000000..4051f85c --- /dev/null +++ b/relay/adaptor/replicate/chat.go @@ -0,0 +1,191 @@ +package replicate + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +func ChatHandler(c *gin.Context, resp *http.Response) ( + srvErr *model.ErrorWithStatusCode, usage *model.Usage) { + if resp.StatusCode != http.StatusCreated { + payload, _ := io.ReadAll(resp.Body) + return openai.ErrorWrapper( + errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), + "bad_status_code", http.StatusInternalServerError), + nil + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + + respData := new(ChatResponse) + if err = json.Unmarshal(respBody, respData); err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + for { + err = func() error { + // get task + taskReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, respData.URLs.Get, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + taskResp, err := 
http.DefaultClient.Do(taskReq) + if err != nil { + return errors.Wrap(err, "get task") + } + defer taskResp.Body.Close() + + if taskResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(taskResp.Body) + return errors.Errorf("bad status code [%d]%s", + taskResp.StatusCode, string(payload)) + } + + taskBody, err := io.ReadAll(taskResp.Body) + if err != nil { + return errors.Wrap(err, "read task response") + } + + taskData := new(ChatResponse) + if err = json.Unmarshal(taskBody, taskData); err != nil { + return errors.Wrap(err, "decode task response") + } + + switch taskData.Status { + case "succeeded": + case "failed", "canceled": + return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) + default: + time.Sleep(time.Second * 3) + return errNextLoop + } + + if taskData.URLs.Stream == "" { + return errors.New("stream url is empty") + } + + // request stream url + responseText, err := chatStreamHandler(c, taskData.URLs.Stream) + if err != nil { + return errors.Wrap(err, "chat stream handler") + } + + ctxMeta := meta.GetByContext(c) + usage = openai.ResponseText2Usage(responseText, + ctxMeta.ActualModelName, ctxMeta.PromptTokens) + return nil + }() + if err != nil { + if errors.Is(err, errNextLoop) { + continue + } + + return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil + } + + break + } + + return nil, usage +} + +const ( + eventPrefix = "event: " + dataPrefix = "data: " + done = "[DONE]" +) + +func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) { + // request stream endpoint + streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil) + if err != nil { + return "", errors.Wrap(err, "new request to stream") + } + + streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + streamReq.Header.Set("Accept", "text/event-stream") + streamReq.Header.Set("Cache-Control", "no-store") + + resp, err := 
http.DefaultClient.Do(streamReq) + if err != nil { + return "", errors.Wrap(err, "do request to stream") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(resp.Body) + return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + common.SetEventStreamHeaders(c) + doneRendered := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + // Handle comments starting with ':' + if strings.HasPrefix(line, ":") { + continue + } + + // Parse SSE fields + if strings.HasPrefix(line, eventPrefix) { + event := strings.TrimSpace(line[len(eventPrefix):]) + var data string + // Read the following lines to get data and id + for scanner.Scan() { + nextLine := scanner.Text() + if nextLine == "" { + break + } + if strings.HasPrefix(nextLine, dataPrefix) { + data = nextLine[len(dataPrefix):] + } else if strings.HasPrefix(nextLine, "id:") { + // id = strings.TrimSpace(nextLine[len("id:"):]) + } + } + + if event == "output" { + render.StringData(c, data) + responseText += data + } else if event == "done" { + render.Done(c) + doneRendered = true + break + } + } + } + + if err := scanner.Err(); err != nil { + return "", errors.Wrap(err, "scan stream") + } + + if !doneRendered { + render.Done(c) + } + + return responseText, nil +} diff --git a/relay/adaptor/replicate/constant.go b/relay/adaptor/replicate/constant.go index 157cc045..989142c9 100644 --- a/relay/adaptor/replicate/constant.go +++ b/relay/adaptor/replicate/constant.go @@ -33,24 +33,24 @@ var ModelList = []string{ // ------------------------------------- // language model // ------------------------------------- - // "ibm-granite/granite-20b-code-instruct-8k", // TODO: implement the adaptor - // "ibm-granite/granite-3.0-2b-instruct", // TODO: implement the adaptor - // "ibm-granite/granite-3.0-8b-instruct", // TODO: 
implement the adaptor - // "ibm-granite/granite-8b-code-instruct-128k", // TODO: implement the adaptor - // "meta/llama-2-13b", // TODO: implement the adaptor - // "meta/llama-2-13b-chat", // TODO: implement the adaptor - // "meta/llama-2-70b", // TODO: implement the adaptor - // "meta/llama-2-70b-chat", // TODO: implement the adaptor - // "meta/llama-2-7b", // TODO: implement the adaptor - // "meta/llama-2-7b-chat", // TODO: implement the adaptor - // "meta/meta-llama-3.1-405b-instruct", // TODO: implement the adaptor - // "meta/meta-llama-3-70b", // TODO: implement the adaptor - // "meta/meta-llama-3-70b-instruct", // TODO: implement the adaptor - // "meta/meta-llama-3-8b", // TODO: implement the adaptor - // "meta/meta-llama-3-8b-instruct", // TODO: implement the adaptor - // "mistralai/mistral-7b-instruct-v0.2", // TODO: implement the adaptor - // "mistralai/mistral-7b-v0.1", // TODO: implement the adaptor - // "mistralai/mixtral-8x7b-instruct-v0.1", // TODO: implement the adaptor + "ibm-granite/granite-20b-code-instruct-8k", + "ibm-granite/granite-3.0-2b-instruct", + "ibm-granite/granite-3.0-8b-instruct", + "ibm-granite/granite-8b-code-instruct-128k", + "meta/llama-2-13b", + "meta/llama-2-13b-chat", + "meta/llama-2-70b", + "meta/llama-2-70b-chat", + "meta/llama-2-7b", + "meta/llama-2-7b-chat", + "meta/meta-llama-3.1-405b-instruct", + "meta/meta-llama-3-70b", + "meta/meta-llama-3-70b-instruct", + "meta/meta-llama-3-8b", + "meta/meta-llama-3-8b-instruct", + "mistralai/mistral-7b-instruct-v0.2", + "mistralai/mistral-7b-v0.1", + "mistralai/mixtral-8x7b-instruct-v0.1", // ------------------------------------- // video model // ------------------------------------- diff --git a/relay/adaptor/replicate/model.go b/relay/adaptor/replicate/model.go index 93bdcda2..04fe5277 100644 --- a/relay/adaptor/replicate/model.go +++ b/relay/adaptor/replicate/model.go @@ -227,3 +227,51 @@ type FluxURLs struct { Get string `json:"get"` Cancel string `json:"cancel"` } + +type 
ReplicateChatRequest struct { + Input ChatInput `json:"input" form:"input" binding:"required"` +} + +// ChatInput is input of ReplicateChatRequest +// +// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema +type ChatInput struct { + TopK int `json:"top_k"` + TopP float64 `json:"top_p"` + Prompt string `json:"prompt"` + MaxTokens int `json:"max_tokens"` + MinTokens int `json:"min_tokens"` + Temperature float64 `json:"temperature"` + SystemPrompt string `json:"system_prompt"` + StopSequences string `json:"stop_sequences"` + PromptTemplate string `json:"prompt_template"` + PresencePenalty float64 `json:"presence_penalty"` + FrequencyPenalty float64 `json:"frequency_penalty"` +} + +// ChatResponse is response of ReplicateChatRequest +// +// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json +type ChatResponse struct { + CompletedAt time.Time `json:"completed_at"` + CreatedAt time.Time `json:"created_at"` + DataRemoved bool `json:"data_removed"` + Error string `json:"error"` + ID string `json:"id"` + Input ChatInput `json:"input"` + Logs string `json:"logs"` + Metrics FluxMetrics `json:"metrics"` + // Output could be `string` or `[]string` + Output []string `json:"output"` + StartedAt time.Time `json:"started_at"` + Status string `json:"status"` + URLs ChatResponseUrl `json:"urls"` + Version string `json:"version"` +} + +// ChatResponseUrl is task URLs of ChatResponse +type ChatResponseUrl struct { + Stream string `json:"stream"` + Get string `json:"get"` + Cancel string `json:"cancel"` +} diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 9b799be7..2afa9717 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -245,6 +245,25 @@ var ModelRatio = map[string]float64{ "stability-ai/stable-diffusion-3.5-large": 0.065 * USD, "stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD, "stability-ai/stable-diffusion-3.5-medium": 0.035 * USD, + // replicate chat 
models + "ibm-granite/granite-20b-code-instruct-8k": 0.100 * USD, + "ibm-granite/granite-3.0-2b-instruct": 0.030 * USD, + "ibm-granite/granite-3.0-8b-instruct": 0.050 * USD, + "ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD, + "meta/llama-2-13b": 0.100 * USD, + "meta/llama-2-13b-chat": 0.100 * USD, + "meta/llama-2-70b": 0.650 * USD, + "meta/llama-2-70b-chat": 0.650 * USD, + "meta/llama-2-7b": 0.050 * USD, + "meta/llama-2-7b-chat": 0.050 * USD, + "meta/meta-llama-3.1-405b-instruct": 9.500 * USD, + "meta/meta-llama-3-70b": 0.650 * USD, + "meta/meta-llama-3-70b-instruct": 0.650 * USD, + "meta/meta-llama-3-8b": 0.050 * USD, + "meta/meta-llama-3-8b-instruct": 0.050 * USD, + "mistralai/mistral-7b-instruct-v0.2": 0.050 * USD, + "mistralai/mistral-7b-v0.1": 0.050 * USD, + "mistralai/mixtral-8x7b-instruct-v0.1": 0.300 * USD, } var CompletionRatio = map[string]float64{ @@ -402,6 +421,7 @@ func GetCompletionRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "deepseek-") { return 2 } + switch name { case "llama2-70b-4096": return 0.8 / 0.64 @@ -417,6 +437,35 @@ func GetCompletionRatio(name string, channelType int) float64 { return 5 case "grok-beta": return 3 + // Replicate Models + // https://replicate.com/pricing + case "ibm-granite/granite-20b-code-instruct-8k": + return 5 + case "ibm-granite/granite-3.0-2b-instruct": + return 8.333333333333334 + case "ibm-granite/granite-3.0-8b-instruct", + "ibm-granite/granite-8b-code-instruct-128k": + return 5 + case "meta/llama-2-13b", + "meta/llama-2-13b-chat", + "meta/llama-2-7b", + "meta/llama-2-7b-chat", + "meta/meta-llama-3-8b", + "meta/meta-llama-3-8b-instruct": + return 5 + case "meta/llama-2-70b", + "meta/llama-2-70b-chat", + "meta/meta-llama-3-70b", + "meta/meta-llama-3-70b-instruct": + return 2.750 / 0.650 // ≈4.230769 + case "meta/meta-llama-3.1-405b-instruct": + return 1 + case "mistralai/mistral-7b-instruct-v0.2", + "mistralai/mistral-7b-v0.1": + return 5 + case 
"mistralai/mixtral-8x7b-instruct-v0.1": + return 1.000 / 0.300 // ≈3.333333 } + return 1 } diff --git a/relay/controller/helper.go b/relay/controller/helper.go index d859532e..86a614cb 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -149,14 +149,20 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { } return true } - if resp.StatusCode != http.StatusOK { + if resp.StatusCode != http.StatusOK && + // replicate return 201 to create a task + resp.StatusCode != http.StatusCreated { return true } if meta.ChannelType == channeltype.DeepL { // skip stream check for deepl return false } - if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + + if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") && + // Even if stream mode is enabled, replicate will first return a task info in JSON format, + // requiring the client to request the stream endpoint in the task info + meta.ChannelType != channeltype.Replicate { return true } return false diff --git a/relay/controller/image.go b/relay/controller/image.go index 16988491..0c21c543 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -24,7 +24,7 @@ import ( relaymodel "github.com/songquanpeng/one-api/relay/model" ) -func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { +func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) { imageRequest := &relaymodel.ImageRequest{} err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { @@ -67,7 +67,7 @@ func getImageSizeRatio(model string, size string) float64 { return 1 } -func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { +func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode { // check prompt length if imageRequest.Prompt == "" { return 
openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)