mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-18 01:26:37 +08:00
* feat: add Replicate adaptor and integrate into channel and API types * feat: support llm chat on replicate
192 lines
4.8 KiB
Go
192 lines
4.8 KiB
Go
package replicate
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/pkg/errors"
|
|
"github.com/songquanpeng/one-api/common"
|
|
"github.com/songquanpeng/one-api/common/render"
|
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
|
"github.com/songquanpeng/one-api/relay/meta"
|
|
"github.com/songquanpeng/one-api/relay/model"
|
|
)
|
|
|
|
func ChatHandler(c *gin.Context, resp *http.Response) (
|
|
srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
|
|
if resp.StatusCode != http.StatusCreated {
|
|
payload, _ := io.ReadAll(resp.Body)
|
|
return openai.ErrorWrapper(
|
|
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
|
|
"bad_status_code", http.StatusInternalServerError),
|
|
nil
|
|
}
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
|
|
respData := new(ChatResponse)
|
|
if err = json.Unmarshal(respBody, respData); err != nil {
|
|
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
|
|
for {
|
|
err = func() error {
|
|
// get task
|
|
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
|
|
http.MethodGet, respData.URLs.Get, nil)
|
|
if err != nil {
|
|
return errors.Wrap(err, "new request")
|
|
}
|
|
|
|
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
|
|
taskResp, err := http.DefaultClient.Do(taskReq)
|
|
if err != nil {
|
|
return errors.Wrap(err, "get task")
|
|
}
|
|
defer taskResp.Body.Close()
|
|
|
|
if taskResp.StatusCode != http.StatusOK {
|
|
payload, _ := io.ReadAll(taskResp.Body)
|
|
return errors.Errorf("bad status code [%d]%s",
|
|
taskResp.StatusCode, string(payload))
|
|
}
|
|
|
|
taskBody, err := io.ReadAll(taskResp.Body)
|
|
if err != nil {
|
|
return errors.Wrap(err, "read task response")
|
|
}
|
|
|
|
taskData := new(ChatResponse)
|
|
if err = json.Unmarshal(taskBody, taskData); err != nil {
|
|
return errors.Wrap(err, "decode task response")
|
|
}
|
|
|
|
switch taskData.Status {
|
|
case "succeeded":
|
|
case "failed", "canceled":
|
|
return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
|
|
default:
|
|
time.Sleep(time.Second * 3)
|
|
return errNextLoop
|
|
}
|
|
|
|
if taskData.URLs.Stream == "" {
|
|
return errors.New("stream url is empty")
|
|
}
|
|
|
|
// request stream url
|
|
responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
|
|
if err != nil {
|
|
return errors.Wrap(err, "chat stream handler")
|
|
}
|
|
|
|
ctxMeta := meta.GetByContext(c)
|
|
usage = openai.ResponseText2Usage(responseText,
|
|
ctxMeta.ActualModelName, ctxMeta.PromptTokens)
|
|
return nil
|
|
}()
|
|
if err != nil {
|
|
if errors.Is(err, errNextLoop) {
|
|
continue
|
|
}
|
|
|
|
return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
|
|
}
|
|
|
|
break
|
|
}
|
|
|
|
return nil, usage
|
|
}
|
|
|
|
// Literal tokens of the server-sent-events text format.
const (
	// dataPrefix starts a line that carries (part of) an event payload.
	dataPrefix = "data: "
	// done is the "[DONE]" marker string. It is not referenced by the code
	// visible in this part of the file; presumably an OpenAI-style stream
	// terminator — confirm against its users.
	done = "[DONE]"
	// eventPrefix starts the line that names an event.
	eventPrefix = "event: "
)
|
|
|
|
func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
|
|
// request stream endpoint
|
|
streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, "new request to stream")
|
|
}
|
|
|
|
streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
|
|
streamReq.Header.Set("Accept", "text/event-stream")
|
|
streamReq.Header.Set("Cache-Control", "no-store")
|
|
|
|
resp, err := http.DefaultClient.Do(streamReq)
|
|
if err != nil {
|
|
return "", errors.Wrap(err, "do request to stream")
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
payload, _ := io.ReadAll(resp.Body)
|
|
return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
|
|
}
|
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Split(bufio.ScanLines)
|
|
|
|
common.SetEventStreamHeaders(c)
|
|
doneRendered := false
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if line == "" {
|
|
continue
|
|
}
|
|
|
|
// Handle comments starting with ':'
|
|
if strings.HasPrefix(line, ":") {
|
|
continue
|
|
}
|
|
|
|
// Parse SSE fields
|
|
if strings.HasPrefix(line, eventPrefix) {
|
|
event := strings.TrimSpace(line[len(eventPrefix):])
|
|
var data string
|
|
// Read the following lines to get data and id
|
|
for scanner.Scan() {
|
|
nextLine := scanner.Text()
|
|
if nextLine == "" {
|
|
break
|
|
}
|
|
if strings.HasPrefix(nextLine, dataPrefix) {
|
|
data = nextLine[len(dataPrefix):]
|
|
} else if strings.HasPrefix(nextLine, "id:") {
|
|
// id = strings.TrimSpace(nextLine[len("id:"):])
|
|
}
|
|
}
|
|
|
|
if event == "output" {
|
|
render.StringData(c, data)
|
|
responseText += data
|
|
} else if event == "done" {
|
|
render.Done(c)
|
|
doneRendered = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return "", errors.Wrap(err, "scan stream")
|
|
}
|
|
|
|
if !doneRendered {
|
|
render.Done(c)
|
|
}
|
|
|
|
return responseText, nil
|
|
}
|