From 5a58426859e6c128392079e80544c340316db307 Mon Sep 17 00:00:00 2001 From: Ghostz <137054651+ye4293@users.noreply.github.com> Date: Sun, 30 Jun 2024 16:09:16 +0800 Subject: [PATCH 01/11] fix minimax empty log (#1560) --- relay/adaptor/openai/main.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 72c675e1..07cb967f 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -4,15 +4,16 @@ import ( "bufio" "bytes" "encoding/json" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "io" - "net/http" - "strings" ) const ( @@ -149,7 +150,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - if textResponse.Usage.TotalTokens == 0 { + if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range textResponse.Choices { completionTokens += CountTokenText(choice.Message.StringContent(), modelName) From 8cc1ee63605d36cc20d096e2be786fc533870833 Mon Sep 17 00:00:00 2001 From: Leo Q Date: Sun, 30 Jun 2024 16:12:16 +0800 Subject: [PATCH 02/11] ci: use codecov to upload coverage report (#1583) --- .github/workflows/ci.yml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 89ba75cd..698acdf1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,17 +45,15 @@ jobs: code_coverage: name: "Code coverage report" - if: github.event_name == 'pull_request' # Do not run when workflow is triggered by push to main branch runs-on: ubuntu-latest needs: unit_tests # Depends on the artifact uploaded by the "unit_tests" job steps: - - uses: fgrosse/go-coverage-report@v1.0.2 # Consider using a Git revision for maximum security - with: - coverage-artifact-name: "code-coverage" # can be omitted if you used this default value - coverage-file-name: "coverage.txt" # can be omitted if you used this default value + - uses: codecov/codecov-action@v4 + with: + use_oidc: true commit_lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - uses: wagoid/commitlint-github-action@v6 \ No newline at end of file + - uses: wagoid/commitlint-github-action@v6 From 34cb147a744e717404ebccd566cdf1b753ef78a1 Mon Sep 17 00:00:00 2001 From: igophper <34326532+igophper@users.noreply.github.com> Date: Sun, 30 Jun 2024 16:13:43 +0800 Subject: [PATCH 03/11] refactor: replace hardcoded string with ctxkey constant (#1579) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 江杭辉 --- common/ctxkey/key.go | 1 + common/gin.go | 7 +++---- controller/relay.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 6c640870..90556b3a 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -19,4 +19,5 @@ const ( TokenName = "token_name" BaseURL = "base_url" AvailableModels = "available_models" + KeyRequestBody = "key_request_body" ) diff --git a/common/gin.go b/common/gin.go index b6ef96a6..549d3279 100644 --- a/common/gin.go +++ b/common/gin.go @@ -4,14 +4,13 @@ import ( "bytes" "encoding/json" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/ctxkey" "io" "strings" ) -const KeyRequestBody = "key_request_body" - func GetRequestBody(c *gin.Context) ([]byte, error) { - requestBody, _ := c.Get(KeyRequestBody) + requestBody, _ := c.Get(ctxkey.KeyRequestBody) if requestBody != nil { return requestBody.([]byte), nil } @@ -20,7 +19,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) { return nil, err } _ = c.Request.Body.Close() - c.Set(KeyRequestBody, requestBody) + c.Set(ctxkey.KeyRequestBody, requestBody) return requestBody.([]byte), nil } diff --git a/controller/relay.go b/controller/relay.go index 5d8ac690..932e023b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -48,7 +48,7 @@ func Relay(c *gin.Context) { logger.Debugf(ctx, "request body: %s", string(requestBody)) } channelId := c.GetInt(ctxkey.ChannelId) - userId := c.GetInt("id") + userId := c.GetInt(ctxkey.Id) bizErr := relayHelper(c, relayMode) if bizErr == nil { monitor.Emit(channelId, true) From b70a07e814c5907e044f45dac32cb02ab1e51efc Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 30 Jun 2024 16:19:49 +0800 Subject: [PATCH 04/11] fix: fix ci --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 698acdf1..30ac5f82 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -50,7 +50,7 @@ jobs: steps: - uses: codecov/codecov-action@v4 with: - use_oidc: true + token: ${{ secrets.CODECOV_TOKEN }} commit_lint: runs-on: ubuntu-latest From f25aaf7752a6f1719f445bb3d2d62863774e626b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 30 Jun 2024 16:21:48 +0800 Subject: [PATCH 05/11] chore(deps): bump golang.org/x/image from 0.16.0 to 0.18.0 (#1568) Bumps [golang.org/x/image](https://github.com/golang/image) from 0.16.0 to 0.18.0. - [Commits](https://github.com/golang/image/compare/v0.16.0...v0.18.0) --- updated-dependencies: - dependency-name: golang.org/x/image dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 7a396314..2d0df03f 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.23.0 - golang.org/x/image v0.16.0 + golang.org/x/image v0.18.0 gorm.io/driver/mysql v1.5.6 gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.5 @@ -80,7 +80,7 @@ require ( golang.org/x/net v0.25.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.20.0 // indirect - golang.org/x/text v0.15.0 // indirect + golang.org/x/text v0.16.0 // indirect google.golang.org/protobuf v1.34.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4c1aac95..ab04845c 100644 --- a/go.sum +++ b/go.sum @@ -154,8 +154,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw= -golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs= +golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= +golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= @@ -164,8 +164,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= From ae1cd29f943b31d4c12dffecd166b621b1ac2400 Mon Sep 17 00:00:00 2001 From: shaoyun Date: Sun, 30 Jun 2024 16:25:25 +0800 Subject: [PATCH 06/11] feat: added support for Claude Sonnet 3.5 (#1567) --- relay/adaptor/anthropic/constants.go | 1 + relay/adaptor/aws/main.go | 13 +++++++------ relay/billing/ratio/model.go | 13 +++++++------ web/air/src/pages/Channel/EditChannel.js | 2 +- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/relay/adaptor/anthropic/constants.go b/relay/adaptor/anthropic/constants.go index cadcedc8..143d1efc 100644 --- a/relay/adaptor/anthropic/constants.go +++ b/relay/adaptor/anthropic/constants.go @@ -5,4 +5,5 @@ var ModelList = []string{ "claude-3-haiku-20240307", "claude-3-sonnet-20240229", "claude-3-opus-20240229", + "claude-3-5-sonnet-20240620", } diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go index 0776f985..5d29597c 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/main.go @@ -33,12 +33,13 @@ func wrapErr(err error) *relaymodel.ErrorWithStatusCode { // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html var awsModelIDMap = map[string]string{ - "claude-instant-1.2": "anthropic.claude-instant-v1", - "claude-2.0": "anthropic.claude-v2", - "claude-2.1": "anthropic.claude-v2:1", - "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", - "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", - "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", + "claude-instant-1.2": "anthropic.claude-instant-v1", + "claude-2.0": "anthropic.claude-v2", + "claude-2.1": "anthropic.claude-v2:1", + "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", } func awsModelID(requestModel string) (string, error) { diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 3b289499..b1a8a5b4 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -70,12 +70,13 @@ var ModelRatio = map[string]float64{ "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image // https://www.anthropic.com/api#pricing - "claude-instant-1.2": 0.8 / 1000 * USD, - "claude-2.0": 8.0 / 1000 * USD, - "claude-2.1": 8.0 / 1000 * USD, - "claude-3-haiku-20240307": 0.25 / 1000 * USD, - "claude-3-sonnet-20240229": 3.0 / 1000 * USD, - "claude-3-opus-20240229": 15.0 / 1000 * USD, + "claude-instant-1.2": 0.8 / 1000 * USD, + "claude-2.0": 8.0 / 1000 * USD, + "claude-2.1": 8.0 / 1000 * USD, + "claude-3-haiku-20240307": 0.25 / 1000 * USD, + "claude-3-sonnet-20240229": 3.0 / 1000 * USD, + "claude-3-5-sonnet-20240620": 3.0 / 1000 * USD, + "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 "ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-3.5-8K": 0.012 * RMB, diff --git a/web/air/src/pages/Channel/EditChannel.js b/web/air/src/pages/Channel/EditChannel.js index efb2cee8..d63fa8fa 100644 --- a/web/air/src/pages/Channel/EditChannel.js +++ b/web/air/src/pages/Channel/EditChannel.js @@ -63,7 +63,7 @@ const EditChannel = (props) => { let localModels = []; switch (value) { case 14: - localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"]; + localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"]; break; case 11: localModels = ['PaLM-2']; From b21b3b5b460502a40217b3973cf6ee5f44c916f9 Mon Sep 17 00:00:00 2001 From: zijiren <84728412+zijiren233@users.noreply.github.com> Date: Sun, 30 Jun 2024 18:36:33 +0800 Subject: [PATCH 07/11] refactor: abusing goroutines and channel (#1561) * refactor: abusing goroutines * fix: trim data prefix * refactor: move functions to render package * refactor: add back trim & flush --------- Co-authored-by: JustSong --- common/render/render.go | 29 +++++++ relay/adaptor/aiproxy/main.go | 97 +++++++++++------------ relay/adaptor/ali/main.go | 91 ++++++++++----------- relay/adaptor/anthropic/main.go | 102 ++++++++++++------------ relay/adaptor/baidu/main.go | 96 ++++++++++------------ relay/adaptor/cloudflare/main.go | 122 +++++++++++++--------------- relay/adaptor/cohere/main.go | 101 +++++++++++------------- relay/adaptor/coze/main.go | 113 ++++++++++++-------------- relay/adaptor/gemini/main.go | 91 +++++++++------------ relay/adaptor/ollama/main.go | 75 +++++++++--------- relay/adaptor/openai/main.go | 131 +++++++++++++------------------ relay/adaptor/palm/palm.go | 93 +++++++++++----------- relay/adaptor/tencent/main.go | 98 ++++++++++------------- relay/adaptor/zhipu/main.go | 103 +++++++++++------------- 14 files changed, 614 insertions(+), 728 deletions(-) create mode 100644 common/render/render.go diff --git a/common/render/render.go b/common/render/render.go new file mode 100644 index 00000000..646b3777 --- /dev/null +++ b/common/render/render.go @@ -0,0 +1,29 @@ +package render + +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "strings" +) + +func StringData(c *gin.Context, str string) { + str = strings.TrimPrefix(str, "data: ") + str = strings.TrimSuffix(str, "\r") + c.Render(-1, common.CustomEvent{Data: "data: " + str}) + c.Writer.Flush() +} + +func ObjectData(c *gin.Context, object interface{}) error { + jsonData, err := json.Marshal(object) + if err != nil { + return fmt.Errorf("error marshalling object: %w", err) + } + StringData(c, string(jsonData)) + return nil +} + +func Done(c *gin.Context) { + StringData(c, "[DONE]") +} diff --git a/relay/adaptor/aiproxy/main.go b/relay/adaptor/aiproxy/main.go index 01a568f6..d64b6809 100644 --- a/relay/adaptor/aiproxy/main.go +++ b/relay/adaptor/aiproxy/main.go @@ -4,6 +4,12 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strconv" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" @@ -12,10 +18,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strconv" - "strings" ) // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 @@ -89,6 +91,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var usage model.Usage + var documents []LibraryDocument scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -102,60 +105,48 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) - var documents []LibraryDocument - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var AIProxyLibraryResponse LibraryStreamResponse - err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if len(AIProxyLibraryResponse.Documents) != 0 { - documents = AIProxyLibraryResponse.Documents - } - response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - response := documentsAIProxyLibrary(documents) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || data[:5] != "data:" { + continue } - }) - err := resp.Body.Close() + data = data[5:] + + var AIProxyLibraryResponse LibraryStreamResponse + err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + if len(AIProxyLibraryResponse.Documents) != 0 { + documents = AIProxyLibraryResponse.Documents + } + response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + response := documentsAIProxyLibrary(documents) + err := render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + render.Done(c) + + err = resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + return nil, &usage } diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go index 0462c26b..f9039dbe 100644 --- a/relay/adaptor/ali/main.go +++ b/relay/adaptor/ali/main.go @@ -3,15 +3,17 @@ package ali import ( "bufio" "encoding/json" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" ) // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r @@ -181,56 +183,43 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) - //lastResponseText := "" - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var aliResponse ChatResponse - err := json.Unmarshal([]byte(data), &aliResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if aliResponse.Usage.OutputTokens != 0 { - usage.PromptTokens = aliResponse.Usage.InputTokens - usage.CompletionTokens = aliResponse.Usage.OutputTokens - usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens - } - response := streamResponseAli2OpenAI(&aliResponse) - if response == nil { - return true - } - //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) - //lastResponseText = aliResponse.Output.Text - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || data[:5] != "data:" { + continue } - }) + data = data[5:] + + var aliResponse ChatResponse + err := json.Unmarshal([]byte(data), &aliResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + if aliResponse.Usage.OutputTokens != 0 { + usage.PromptTokens = aliResponse.Usage.InputTokens + usage.CompletionTokens = aliResponse.Usage.OutputTokens + usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens + } + response := streamResponseAli2OpenAI(&aliResponse) + if response == nil { + continue + } + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index a8de185c..c817a9d1 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -169,64 +170,59 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { - continue - } - if !strings.HasPrefix(data, "data:") { - continue - } - data = strings.TrimPrefix(data, "data:") - dataChan <- data - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) + var usage model.Usage var modelName string var id string - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - // some implementations may add \r at the end of data - data = strings.TrimSpace(data) - var claudeResponse StreamResponse - err := json.Unmarshal([]byte(data), &claudeResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response, meta := StreamResponseClaude2OpenAI(&claudeResponse) - if meta != nil { - usage.PromptTokens += meta.Usage.InputTokens - usage.CompletionTokens += meta.Usage.OutputTokens - modelName = meta.Model - id = fmt.Sprintf("chatcmpl-%s", meta.Id) - return true - } - if response == nil { - return true - } - response.Id = id - response.Model = modelName - response.Created = createdTime - jsonStr, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 || !strings.HasPrefix(data, "data:") { + continue } - }) - _ = resp.Body.Close() + data = strings.TrimPrefix(data, "data:") + data = strings.TrimSpace(data) + + var claudeResponse StreamResponse + err := json.Unmarshal([]byte(data), &claudeResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response, meta := StreamResponseClaude2OpenAI(&claudeResponse) + if meta != nil { + usage.PromptTokens += meta.Usage.InputTokens + usage.CompletionTokens += meta.Usage.OutputTokens + modelName = meta.Model + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + continue + } + if response == nil { + continue + } + + response.Id = id + response.Model = modelName + response.Created = createdTime + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } return nil, &usage } diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go index b816e0f4..ebe70c32 100644 --- a/relay/adaptor/baidu/main.go +++ b/relay/adaptor/baidu/main.go @@ -5,6 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + "sync" + "time" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/client" @@ -12,11 +19,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" - "sync" - "time" ) // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 @@ -137,59 +139,41 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var usage model.Usage scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // ignore blank line or wrong format - continue - } - data = data[6:] - dataChan <- data - } - stopChan <- true - }() + scanner.Split(bufio.ScanLines) + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var baiduResponse ChatStreamResponse - err := json.Unmarshal([]byte(data), &baiduResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if baiduResponse.Usage.TotalTokens != 0 { - usage.TotalTokens = baiduResponse.Usage.TotalTokens - usage.PromptTokens = baiduResponse.Usage.PromptTokens - usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens - } - response := streamResponseBaidu2OpenAI(&baiduResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { + continue } - }) + data = data[6:] + + var baiduResponse ChatStreamResponse + err := json.Unmarshal([]byte(data), &baiduResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + if baiduResponse.Usage.TotalTokens != 0 { + usage.TotalTokens = baiduResponse.Usage.TotalTokens + usage.PromptTokens = baiduResponse.Usage.PromptTokens + usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens + } + response := streamResponseBaidu2OpenAI(&baiduResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/adaptor/cloudflare/main.go b/relay/adaptor/cloudflare/main.go index f6d496f7..c76520a2 100644 --- a/relay/adaptor/cloudflare/main.go +++ b/relay/adaptor/cloudflare/main.go @@ -2,8 +2,8 @@ package cloudflare import ( "bufio" - "bytes" "encoding/json" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -17,21 +17,20 @@ import ( ) func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { - var promptBuilder strings.Builder - for _, message := range textRequest.Messages { - promptBuilder.WriteString(message.StringContent()) - promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息 - } + var promptBuilder strings.Builder + for _, message := range textRequest.Messages { + promptBuilder.WriteString(message.StringContent()) + promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息 + } - return &Request{ - MaxTokens: textRequest.MaxTokens, - Prompt: promptBuilder.String(), - Stream: textRequest.Stream, - Temperature: textRequest.Temperature, - } + return &Request{ + MaxTokens: textRequest.MaxTokens, + Prompt: promptBuilder.String(), + Stream: textRequest.Stream, + Temperature: textRequest.Temperature, + } } - func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse { choice := openai.TextResponseChoice{ Index: 0, @@ -63,67 +62,54 @@ func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, '\n'); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) + scanner.Split(bufio.ScanLines) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < len("data: ") { - continue - } - data = strings.TrimPrefix(data, "data: ") - dataChan <- data - } - stopChan <- true - }() common.SetEventStreamHeaders(c) id := helper.GetResponseID(c) responseModel := c.GetString("original_model") var responseText string - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var cloudflareResponse StreamResponse - err := json.Unmarshal([]byte(data), &cloudflareResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) - if response == nil { - return true - } - responseText += cloudflareResponse.Response - response.Id = id - response.Model = responseModel - jsonStr, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < len("data: ") { + continue } - }) - _ = resp.Body.Close() + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\r") + + var cloudflareResponse StreamResponse + err := json.Unmarshal([]byte(data), &cloudflareResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) + if response == nil { + continue + } + + responseText += cloudflareResponse.Response + response.Id = id + response.Model = responseModel + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens) return nil, usage } diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go index 4bc3fa8d..45db437b 100644 --- a/relay/adaptor/cohere/main.go +++ b/relay/adaptor/cohere/main.go @@ -2,9 +2,9 @@ package cohere import ( "bufio" - "bytes" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -134,66 +134,53 @@ func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { createdTime := helper.GetTimestamp() scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, '\n'); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) + scanner.Split(bufio.ScanLines) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - dataChan <- data - } - stopChan <- true - }() common.SetEventStreamHeaders(c) var usage model.Usage - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var cohereResponse StreamResponse - err := json.Unmarshal([]byte(data), &cohereResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response, meta := StreamResponseCohere2OpenAI(&cohereResponse) - if meta != nil { - usage.PromptTokens += meta.Meta.Tokens.InputTokens - usage.CompletionTokens += meta.Meta.Tokens.OutputTokens - return true - } - if response == nil { - return true - } - response.Id = fmt.Sprintf("chatcmpl-%d", createdTime) - response.Model = c.GetString("original_model") - response.Created = createdTime - jsonStr, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSuffix(data, "\r") + + var cohereResponse StreamResponse + err := json.Unmarshal([]byte(data), &cohereResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue } - }) - _ = resp.Body.Close() + + response, meta := StreamResponseCohere2OpenAI(&cohereResponse) + if meta != nil { + usage.PromptTokens += meta.Meta.Tokens.InputTokens + usage.CompletionTokens += meta.Meta.Tokens.OutputTokens + continue + } + if response == nil { + continue + } + + response.Id = fmt.Sprintf("chatcmpl-%d", createdTime) + response.Model = c.GetString("original_model") + response.Created = createdTime + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage } diff --git a/relay/adaptor/coze/main.go b/relay/adaptor/coze/main.go index 721c5d13..d0402a76 100644 --- a/relay/adaptor/coze/main.go +++ b/relay/adaptor/coze/main.go @@ -4,6 +4,11 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" @@ -12,9 +17,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" ) // https://www.coze.com/open @@ -109,69 +111,54 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC var responseText string createdTime := helper.GetTimestamp() scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { - continue - } - if !strings.HasPrefix(data, "data:") { - continue - } - data = strings.TrimPrefix(data, "data:") - dataChan <- data - } - stopChan <- true - }() + scanner.Split(bufio.ScanLines) + common.SetEventStreamHeaders(c) var modelName string - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var cozeResponse StreamResponse - err := json.Unmarshal([]byte(data), &cozeResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response, _ := StreamResponseCoze2OpenAI(&cozeResponse) - if response == nil { - return true - } - for _, choice := range response.Choices { - responseText += conv.AsString(choice.Delta.Content) - } - response.Model = modelName - response.Created = createdTime - jsonStr, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || !strings.HasPrefix(data, "data:") { + continue } - }) - _ = resp.Body.Close() + data = strings.TrimPrefix(data, "data:") + data = strings.TrimSuffix(data, "\r") + + var cozeResponse StreamResponse + err := json.Unmarshal([]byte(data), &cozeResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response, _ := StreamResponseCoze2OpenAI(&cozeResponse) + if response == nil { + continue + } + + for _, choice := range response.Choices { + responseText += conv.AsString(choice.Delta.Content) + } + response.Model = modelName + response.Created = createdTime + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &responseText } diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index 74a7d5d5..51fd6aa8 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -275,64 +276,50 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - data = strings.TrimSpace(data) - if !strings.HasPrefix(data, "data: ") { - continue - } - data = strings.TrimPrefix(data, "data: ") - data = strings.TrimSuffix(data, "\"") - dataChan <- data - } - stopChan <- true - }() + scanner.Split(bufio.ScanLines) + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var geminiResponse ChatResponse - err := json.Unmarshal([]byte(data), &geminiResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response := streamResponseGeminiChat2OpenAI(&geminiResponse) - if response == nil { - return true - } - responseText += response.Choices[0].Delta.StringContent() - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "data: ") { + continue } - }) + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\"") + + var geminiResponse ChatResponse + err := json.Unmarshal([]byte(data), &geminiResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response := streamResponseGeminiChat2OpenAI(&geminiResponse) + if response == nil { + continue + } + + responseText += response.Choices[0].Delta.StringContent() + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } + return nil, responseText } diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go index c5fe08e6..936a7e14 100644 --- a/relay/adaptor/ollama/main.go +++ b/relay/adaptor/ollama/main.go @@ -5,12 +5,14 @@ import ( "context" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/helper" - "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/random" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/image" @@ -105,54 +107,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC return 0, nil, nil } if i := strings.Index(string(data), "}\n"); i >= 0 { - return i + 2, data[0:i], nil + return i + 2, data[0 : i+1], nil } if atEOF { return len(data), data, nil } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := strings.TrimPrefix(scanner.Text(), "}") - dataChan <- data + "}" - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var ollamaResponse ChatResponse - err := json.Unmarshal([]byte(data), &ollamaResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if ollamaResponse.EvalCount != 0 { - usage.PromptTokens = ollamaResponse.PromptEvalCount - usage.CompletionTokens = ollamaResponse.EvalCount - usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount - } - response := streamResponseOllama2OpenAI(&ollamaResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := strings.TrimPrefix(scanner.Text(), "}") + data = data + "}" + + var ollamaResponse ChatResponse + err := json.Unmarshal([]byte(data), &ollamaResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue } - }) + + if ollamaResponse.EvalCount != 0 { + usage.PromptTokens = ollamaResponse.PromptEvalCount + usage.CompletionTokens = ollamaResponse.EvalCount + usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount + } + + response := streamResponseOllama2OpenAI(&ollamaResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + return nil, &usage } diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 07cb967f..1d534644 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/json" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -25,88 +26,68 @@ const ( func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { responseText := "" scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) + scanner.Split(bufio.ScanLines) var usage *model.Usage - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < dataPrefixLength { // ignore blank line or wrong format - continue - } - if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done { - continue - } - if strings.HasPrefix(data[dataPrefixLength:], done) { - dataChan <- data - continue - } - switch relayMode { - case relaymode.ChatCompletions: - var streamResponse ChatCompletionsStreamResponse - err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - dataChan <- data // if error happened, pass the data to client - continue // just ignore the error - } - if len(streamResponse.Choices) == 0 { - // but for empty choice, we should not pass it to client, this is for azure - continue // just ignore empty choice - } - dataChan <- data - for _, choice := range streamResponse.Choices { - responseText += conv.AsString(choice.Delta.Content) - } - if streamResponse.Usage != nil { - usage = streamResponse.Usage - } - case relaymode.Completions: - dataChan <- data - var streamResponse CompletionsStreamResponse - err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - continue - } - for _, choice := range streamResponse.Choices { - responseText += choice.Text - } - } - } - stopChan <- true - }() + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if strings.HasPrefix(data, "data: [DONE]") { - data = data[:12] - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - c.Render(-1, common.CustomEvent{Data: data}) - return true - case <-stopChan: - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < dataPrefixLength { // ignore blank line or wrong format + continue } - }) + if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done { + continue + } + if strings.HasPrefix(data[dataPrefixLength:], done) { + render.StringData(c, data) + continue + } + switch relayMode { + case relaymode.ChatCompletions: + var streamResponse ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + render.StringData(c, data) // if error happened, pass the data to client + continue // just ignore the error + } + if len(streamResponse.Choices) == 0 { + // but for empty choice, we should not pass it to client, this is for azure + continue // just ignore empty choice + } + render.StringData(c, data) + for _, choice := range streamResponse.Choices { + responseText += conv.AsString(choice.Delta.Content) + } + if streamResponse.Usage != nil { + usage = streamResponse.Usage + } + case relaymode.Completions: + render.StringData(c, data) + var streamResponse CompletionsStreamResponse + err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + for _, choice := range streamResponse.Choices { + responseText += choice.Text + } + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil } + return nil, responseText, usage } diff --git a/relay/adaptor/palm/palm.go b/relay/adaptor/palm/palm.go index 1e60e7cd..d31784ec 100644 --- a/relay/adaptor/palm/palm.go +++ b/relay/adaptor/palm/palm.go @@ -3,6 +3,10 @@ package palm import ( "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" @@ -11,8 +15,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body @@ -77,58 +79,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC responseText := "" responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID()) createdTime := helper.GetTimestamp() - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - logger.SysError("error reading stream response: " + err.Error()) - stopChan <- true - return - } - err = resp.Body.Close() - if err != nil { - logger.SysError("error closing stream response: " + err.Error()) - stopChan <- true - return - } - var palmResponse ChatResponse - err = json.Unmarshal(responseBody, &palmResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - stopChan <- true - return - } - fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) - fullTextResponse.Id = responseId - fullTextResponse.Created = createdTime - if len(palmResponse.Candidates) > 0 { - responseText = palmResponse.Candidates[0].Content - } - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - stopChan <- true - return - } - dataChan <- string(jsonResponse) - stopChan <- true - }() + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - c.Render(-1, common.CustomEvent{Data: "data: " + data}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + logger.SysError("error reading stream response: " + err.Error()) + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } - }) - err := resp.Body.Close() + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), "" + } + + err = resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } + + var palmResponse ChatResponse + err = json.Unmarshal(responseBody, &palmResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), "" + } + + fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) + fullTextResponse.Id = responseId + fullTextResponse.Created = createdTime + if len(palmResponse.Candidates) > 0 { + responseText = palmResponse.Candidates[0].Content + } + + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), "" + } + + err = render.ObjectData(c, string(jsonResponse)) + if err != nil { + logger.SysError(err.Error()) + } + + render.Done(c) + return nil, responseText } diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go index 0a57dcf7..365e33ae 100644 --- a/relay/adaptor/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -8,6 +8,13 @@ import ( "encoding/json" "errors" "fmt" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strconv" + "strings" + "time" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" @@ -17,11 +24,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strconv" - "strings" - "time" ) func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { @@ -87,64 +89,46 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { var responseText string scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() + scanner.Split(bufio.ScanLines) + common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var TencentResponse ChatResponse - err := json.Unmarshal([]byte(data), &TencentResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response := streamResponseTencent2OpenAI(&TencentResponse) - if len(response.Choices) != 0 { - responseText += conv.AsString(response.Choices[0].Delta.Content) - } - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 || !strings.HasPrefix(data, "data:") { + continue } - }) + data = strings.TrimPrefix(data, "data:") + + var tencentResponse ChatResponse + err := json.Unmarshal([]byte(data), &tencentResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + + response := streamResponseTencent2OpenAI(&tencentResponse) + if len(response.Choices) != 0 { + responseText += conv.AsString(response.Choices[0].Delta.Content) + } + + err = render.ObjectData(c, response) + if err != nil { + logger.SysError(err.Error()) + } + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } + return nil, responseText } diff --git a/relay/adaptor/zhipu/main.go b/relay/adaptor/zhipu/main.go index 74a1a05e..ab3a5678 100644 --- a/relay/adaptor/zhipu/main.go +++ b/relay/adaptor/zhipu/main.go @@ -3,6 +3,13 @@ package zhipu import ( "bufio" "encoding/json" + "github.com/songquanpeng/one-api/common/render" + "io" + "net/http" + "strings" + "sync" + "time" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" "github.com/songquanpeng/one-api/common" @@ -11,11 +18,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" - "sync" - "time" ) // https://open.bigmodel.cn/doc/api#chatglm_std @@ -155,66 +157,55 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) - dataChan := make(chan string) - metaChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - lines := strings.Split(data, "\n") - for i, line := range lines { - if len(line) < 5 { + + common.SetEventStreamHeaders(c) + + for scanner.Scan() { + data := scanner.Text() + lines := strings.Split(data, "\n") + for i, line := range lines { + if len(line) < 5 { + continue + } + if strings.HasPrefix(line, "data:") { + dataSegment := line[5:] + if i != len(lines)-1 { + dataSegment += "\n" + } + response := streamResponseZhipu2OpenAI(dataSegment) + err := render.ObjectData(c, response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + } + } else if strings.HasPrefix(line, "meta:") { + metaSegment := line[5:] + var zhipuResponse StreamMetaResponse + err := json.Unmarshal([]byte(metaSegment), &zhipuResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) continue } - if line[:5] == "data:" { - dataChan <- line[5:] - if i != len(lines)-1 { - dataChan <- "\n" - } - } else if line[:5] == "meta:" { - metaChan <- line[5:] + response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) + err = render.ObjectData(c, response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) } + usage = zhipuUsage } } - stopChan <- true - }() - common.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - response := streamResponseZhipu2OpenAI(data) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case data := <-metaChan: - var zhipuResponse StreamMetaResponse - err := json.Unmarshal([]byte(data), &zhipuResponse) - if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) - return true - } - usage = zhipuUsage - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) + } + + if err := scanner.Err(); err != nil { + logger.SysError("error reading stream: " + err.Error()) + } + + render.Done(c) + err := resp.Body.Close() if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + return nil, usage } From d0369b114f6b9a34a926979b309ac1fd052db698 Mon Sep 17 00:00:00 2001 From: lihangfu <280001404@qq.com> Date: Sun, 30 Jun 2024 19:37:07 +0800 Subject: [PATCH 08/11] feat: support spark4.0 ultra (#1569) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: 支持v3最新协议的腾讯混元(#1452) * feat: 支持Spark4.0 Ultra --------- Co-authored-by: lihangfu --- relay/adaptor/xunfei/constants.go | 1 + relay/adaptor/xunfei/main.go | 2 ++ relay/billing/ratio/model.go | 1 + web/air/src/pages/Channel/EditChannel.js | 2 +- web/berry/src/views/Channel/type/Config.js | 2 +- 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/relay/adaptor/xunfei/constants.go b/relay/adaptor/xunfei/constants.go index 31dcec71..12a56210 100644 --- a/relay/adaptor/xunfei/constants.go +++ b/relay/adaptor/xunfei/constants.go @@ -6,4 +6,5 @@ var ModelList = []string{ "SparkDesk-v2.1", "SparkDesk-v3.1", "SparkDesk-v3.5", + "SparkDesk-v4.0", } diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index 39b76e27..7cf413a4 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -290,6 +290,8 @@ func apiVersion2domain(apiVersion string) string { return "generalv3" case "v3.5": return "generalv3.5" + case "v4.0": + return "4.0Ultra" } return "general" + apiVersion } diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index b1a8a5b4..56d31e13 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -125,6 +125,7 @@ var ModelRatio = map[string]float64{ "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens diff --git a/web/air/src/pages/Channel/EditChannel.js b/web/air/src/pages/Channel/EditChannel.js index d63fa8fa..73fd2da2 100644 --- a/web/air/src/pages/Channel/EditChannel.js +++ b/web/air/src/pages/Channel/EditChannel.js @@ -78,7 +78,7 @@ const EditChannel = (props) => { localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; break; case 18: - localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5']; + localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']; break; case 19: localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js index 88e1ea92..51b7c6c4 100644 --- a/web/berry/src/views/Channel/type/Config.js +++ b/web/berry/src/views/Channel/type/Config.js @@ -91,7 +91,7 @@ const typeConfig = { other: '版本号' }, input: { - models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'] + models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'] }, prompt: { key: '按照如下格式输入:APPID|APISecret|APIKey', From c135d74f136813ff26731c2e78bcf2fc3dc3daed Mon Sep 17 00:00:00 2001 From: Shi Jilin <40982122+shijilin0116@users.noreply.github.com> Date: Sun, 30 Jun 2024 19:38:02 +0800 Subject: [PATCH 09/11] feat: support Spark4.0 Ultra (#1575) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: fix SparkDesk Function Call (修复 Spark Pro/Max函数调用只会返回普通对话回答而不是Function Call回答的问题 * feat: support Spark4.0 Ultra --- relay/adaptor/xunfei/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index 7cf413a4..ef6120e5 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -44,7 +44,7 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Payload.Message.Text = messages - if strings.HasPrefix(domain, "generalv3") { + if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" { functions := make([]model.Function, len(request.Tools)) for i, tool := range request.Tools { functions[i] = tool.Function From fecaece71b700b43ba11c161a3f8a971af204971 Mon Sep 17 00:00:00 2001 From: igophper <34326532+igophper@users.noreply.github.com> Date: Sun, 30 Jun 2024 19:52:33 +0800 Subject: [PATCH 10/11] fix: fix size not support during image generation (#1564) Fixes #1224, #1068 --- relay/controller/helper.go | 72 -------------------------------- relay/controller/image.go | 84 +++++++++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 78 deletions(-) diff --git a/relay/controller/helper.go b/relay/controller/helper.go index dccff486..c47cb558 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -40,78 +40,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener return textRequest, nil } -func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { - imageRequest := &relaymodel.ImageRequest{} - err := common.UnmarshalBodyReusable(c, imageRequest) - if err != nil { - return nil, err - } - if imageRequest.N == 0 { - imageRequest.N = 1 - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-2" - } - return imageRequest, nil -} - -func isValidImageSize(model string, size string) bool { - if model == "cogview-3" { - return true - } - _, ok := billingratio.ImageSizeRatios[model][size] - return ok -} - -func getImageSizeRatio(model string, size string) float64 { - ratio, ok := billingratio.ImageSizeRatios[model][size] - if !ok { - return 1 - } - return ratio -} - -func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { - // model validation - hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size) - if !hasValidSize { - return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) - } - // check prompt length - if imageRequest.Prompt == "" { - return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) - } - if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] { - return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) - } - // Number of generated images validation - if !isWithinRange(imageRequest.Model, imageRequest.N) { - // channel not azure - if meta.ChannelType != channeltype.Azure { - return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) - } - } - return nil -} - -func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { - if imageRequest == nil { - return 0, errors.New("imageRequest is nil") - } - imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) - if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { - if imageRequest.Size == "1024x1024" { - imageCostRatio *= 2 - } else { - imageCostRatio *= 1.5 - } - } - return imageCostRatio, nil -} - func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { case relaymode.ChatCompletions: diff --git a/relay/controller/image.go b/relay/controller/image.go index 691c7c0e..e6245226 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" @@ -20,13 +21,84 @@ import ( "net/http" ) -func isWithinRange(element string, value int) bool { - if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { - return false +func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { + imageRequest := &relaymodel.ImageRequest{} + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err } - min := billingratio.ImageGenerationAmounts[element][0] - max := billingratio.ImageGenerationAmounts[element][1] - return value >= min && value <= max + if imageRequest.N == 0 { + imageRequest.N = 1 + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" + } + return imageRequest, nil +} + +func isValidImageSize(model string, size string) bool { + if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil { + return true + } + _, ok := billingratio.ImageSizeRatios[model][size] + return ok +} + +func isValidImagePromptLength(model string, promptLength int) bool { + maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model] + return !ok || promptLength <= maxPromptLength +} + +func isWithinRange(element string, value int) bool { + amounts, ok := billingratio.ImageGenerationAmounts[element] + return !ok || (value >= amounts[0] && value <= amounts[1]) +} + +func getImageSizeRatio(model string, size string) float64 { + if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok { + return ratio + } + return 1 +} + +func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { + // check prompt length + if imageRequest.Prompt == "" { + return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) + } + + // model validation + if !isValidImageSize(imageRequest.Model, imageRequest.Size) { + return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + + if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) { + return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) + } + + // Number of generated images validation + if !isWithinRange(imageRequest.Model, imageRequest.N) { + return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } + return nil +} + +func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { + if imageRequest == nil { + return 0, errors.New("imageRequest is nil") + } + imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) + if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { + if imageRequest.Size == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + return imageCostRatio, nil } func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { From d936817de9866b3e8b6bf1a0f741a2a4eb6c3bd4 Mon Sep 17 00:00:00 2001 From: Darkside Date: Sun, 30 Jun 2024 19:57:30 +0800 Subject: [PATCH 11/11] docs: add related projects (#1562) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 成达 --- README.en.md | 10 ++++++---- README.md | 12 +++++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/README.en.md b/README.en.md index bce47353..db96a858 100644 --- a/README.en.md +++ b/README.en.md @@ -101,7 +101,7 @@ Nginx reference configuration: ``` server{ server_name openai.justsong.cn; # Modify your domain name accordingly - + location / { client_max_body_size 64m; proxy_http_version 1.1; @@ -132,12 +132,12 @@ The initial account username is `root` and password is `123456`. 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source: ```shell git clone https://github.com/songquanpeng/one-api.git - + # Build the frontend cd one-api/web/default npm install npm run build - + # Build the backend cd ../.. go mod download @@ -287,7 +287,9 @@ If the channel ID is not provided, load balancing will be used to distribute the + Double-check that your interface address and API Key are correct. ## Related Projects -[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM +* [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM +* [VChart](https://github.com/VisActor/VChart): More than just a cross-platform charting library, but also an expressive data storyteller. +* [VMind](https://github.com/VisActor/VMind): Not just automatic, but also fantastic. Open-source solution for intelligent visualization. ## Note This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. diff --git a/README.md b/README.md index 8f59a14a..b5168264 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 > [!NOTE] > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 -> +> > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 > [!WARNING] @@ -144,7 +144,7 @@ Nginx 的参考配置: ``` server{ server_name openai.justsong.cn; # 请根据实际情况修改你的域名 - + location / { client_max_body_size 64m; proxy_http_version 1.1; @@ -189,12 +189,12 @@ docker-compose ps 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: ```shell git clone https://github.com/songquanpeng/one-api.git - + # 构建前端 cd one-api/web/default npm install npm run build - + # 构建后端 cd ../.. go mod download @@ -321,7 +321,7 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashbo 例如对于 OpenAI 的官方库: ```bash OPENAI_API_KEY="sk-xxxxxx" -OPENAI_API_BASE="https://:/v1" +OPENAI_API_BASE="https://:/v1" ``` ```mermaid @@ -448,6 +448,8 @@ https://openai.justsong.cn ## 相关项目 * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用 +* [VChart](https://github.com/VisActor/VChart): 不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。 +* [VMind](https://github.com/VisActor/VMind): 不仅自动,还很智能。开源智能可视化解决方案。 ## 注意