Compare commits

...

17 Commits

Author SHA1 Message Date
Ian Li
43368e68c8 Merge 3462bd538c into c4fe57c165 2024-07-03 21:04:26 +08:00
LinZeliang
c4fe57c165 feat: support one or more log file (#1400)
Co-authored-by: Laisky.Cai <github@laisky.com>
2024-07-03 20:53:29 +08:00
igophper
274fcf3d76 refactor: init db (#1590)
Co-authored-by: 江杭辉 <jianghanghui@k.app>
2024-07-03 20:50:40 +08:00
Mikey
0fc07ea558 feat: add support for Claude 3 tool use (function calling) (#1587)
* feat: add tool support for AWS & Claude

* fix: add {} for openai compatibility in streaming tool_use
2024-07-02 00:12:01 +08:00
Leo Q
1ce1e529ee ci: skip archive, upload directly (#1586) 2024-07-02 00:05:47 +08:00
Darkside
d936817de9 docs: add related projects (#1562)
Co-authored-by: 成达 <chengda.615@bytedance.com>
2024-06-30 19:57:30 +08:00
igophper
fecaece71b fix: fix size not support during image generation (#1564)
Fixes #1224, #1068
2024-06-30 19:52:33 +08:00
Shi Jilin
c135d74f13 feat: support Spark4.0 Ultra (#1575)
* fix: fix SparkDesk Function Call (修复 Spark Pro/Max函数调用只会返回普通对话回答而不是Function Call回答的问题

* feat: support Spark4.0 Ultra
2024-06-30 19:38:02 +08:00
lihangfu
d0369b114f feat: support spark4.0 ultra (#1569)
* feat: 支持v3最新协议的腾讯混元(#1452)

* feat: 支持Spark4.0 Ultra

---------

Co-authored-by: lihangfu <hfli8@iflytek.com>
2024-06-30 19:37:07 +08:00
zijiren
b21b3b5b46 refactor: abusing goroutines and channel (#1561)
* refactor: abusing goroutines

* fix: trim data prefix

* refactor: move functions to render package

* refactor: add back trim & flush

---------

Co-authored-by: JustSong <quanpengsong@gmail.com>
2024-06-30 18:36:33 +08:00
shaoyun
ae1cd29f94 feat: added support for Claude Sonnet 3.5 (#1567) 2024-06-30 16:25:25 +08:00
dependabot[bot]
f25aaf7752 chore(deps): bump golang.org/x/image from 0.16.0 to 0.18.0 (#1568)
Bumps [golang.org/x/image](https://github.com/golang/image) from 0.16.0 to 0.18.0.
- [Commits](https://github.com/golang/image/compare/v0.16.0...v0.18.0)

---
updated-dependencies:
- dependency-name: golang.org/x/image
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-06-30 16:21:48 +08:00
JustSong
b70a07e814 fix: fix ci 2024-06-30 16:19:49 +08:00
igophper
34cb147a74 refactor: replace hardcoded string with ctxkey constant (#1579)
Co-authored-by: 江杭辉 <jianghanghui@k.app>
2024-06-30 16:13:43 +08:00
Leo Q
8cc1ee6360 ci: use codecov to upload coverage report (#1583) 2024-06-30 16:12:16 +08:00
Ghostz
5a58426859 fix minimax empty log (#1560) 2024-06-30 16:09:16 +08:00
Ian Li
3462bd538c feat: support LOGIN as SMTP authentication method. 2024-03-16 16:29:10 +08:00
43 changed files with 1115 additions and 965 deletions

View File

@@ -36,26 +36,12 @@ jobs:
# in the next step as well as the next job. # in the next step as well as the next job.
- name: Test - name: Test
run: go test -cover -coverprofile=coverage.txt ./... run: go test -cover -coverprofile=coverage.txt ./...
- uses: codecov/codecov-action@v4
- name: Archive code coverage results
uses: actions/upload-artifact@v4
with: with:
name: code-coverage token: ${{ secrets.CODECOV_TOKEN }}
path: coverage.txt # Make sure to use the same file name you chose for the "-coverprofile" in the "Test" step
code_coverage:
name: "Code coverage report"
if: github.event_name == 'pull_request' # Do not run when workflow is triggered by push to main branch
runs-on: ubuntu-latest
needs: unit_tests # Depends on the artifact uploaded by the "unit_tests" job
steps:
- uses: fgrosse/go-coverage-report@v1.0.2 # Consider using a Git revision for maximum security
with:
coverage-artifact-name: "code-coverage" # can be omitted if you used this default value
coverage-file-name: "coverage.txt" # can be omitted if you used this default value
commit_lint: commit_lint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: wagoid/commitlint-github-action@v6 - uses: wagoid/commitlint-github-action@v6

View File

@@ -101,7 +101,7 @@ Nginx reference configuration:
``` ```
server{ server{
server_name openai.justsong.cn; # Modify your domain name accordingly server_name openai.justsong.cn; # Modify your domain name accordingly
location / { location / {
client_max_body_size 64m; client_max_body_size 64m;
proxy_http_version 1.1; proxy_http_version 1.1;
@@ -132,12 +132,12 @@ The initial account username is `root` and password is `123456`.
1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source: 1. Download the executable file from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or compile from source:
```shell ```shell
git clone https://github.com/songquanpeng/one-api.git git clone https://github.com/songquanpeng/one-api.git
# Build the frontend # Build the frontend
cd one-api/web/default cd one-api/web/default
npm install npm install
npm run build npm run build
# Build the backend # Build the backend
cd ../.. cd ../..
go mod download go mod download
@@ -287,7 +287,9 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Double-check that your interface address and API Key are correct. + Double-check that your interface address and API Key are correct.
## Related Projects ## Related Projects
[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM * [FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
* [VChart](https://github.com/VisActor/VChart): More than just a cross-platform charting library, but also an expressive data storyteller.
* [VMind](https://github.com/VisActor/VMind): Not just automatic, but also fantastic. Open-source solution for intelligent visualization.
## Note ## Note
This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.

View File

@@ -53,7 +53,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
> [!NOTE] > [!NOTE]
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> >
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
> [!WARNING] > [!WARNING]
@@ -144,7 +144,7 @@ Nginx 的参考配置:
``` ```
server{ server{
server_name openai.justsong.cn; # 请根据实际情况修改你的域名 server_name openai.justsong.cn; # 请根据实际情况修改你的域名
location / { location / {
client_max_body_size 64m; client_max_body_size 64m;
proxy_http_version 1.1; proxy_http_version 1.1;
@@ -189,12 +189,12 @@ docker-compose ps
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
```shell ```shell
git clone https://github.com/songquanpeng/one-api.git git clone https://github.com/songquanpeng/one-api.git
# 构建前端 # 构建前端
cd one-api/web/default cd one-api/web/default
npm install npm install
npm run build npm run build
# 构建后端 # 构建后端
cd ../.. cd ../..
go mod download go mod download
@@ -321,7 +321,7 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库https://dashbo
例如对于 OpenAI 的官方库: 例如对于 OpenAI 的官方库:
```bash ```bash
OPENAI_API_KEY="sk-xxxxxx" OPENAI_API_KEY="sk-xxxxxx"
OPENAI_API_BASE="https://<HOST>:<PORT>/v1" OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
``` ```
```mermaid ```mermaid
@@ -448,6 +448,8 @@ https://openai.justsong.cn
## 相关项目 ## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
* [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用 * [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用
* [VChart](https://github.com/VisActor/VChart): 不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。
* [VMind](https://github.com/VisActor/VMind): 不仅自动,还很智能。开源智能可视化解决方案。
## 注意 ## 注意

View File

@@ -63,6 +63,7 @@ var SMTPPort = 587
var SMTPAccount = "" var SMTPAccount = ""
var SMTPFrom = "" var SMTPFrom = ""
var SMTPToken = "" var SMTPToken = ""
var SMTPAuthLoginEnabled = false
var GitHubClientId = "" var GitHubClientId = ""
var GitHubClientSecret = "" var GitHubClientSecret = ""
@@ -145,6 +146,9 @@ var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
var GeminiVersion = env.String("GEMINI_VERSION", "v1") var GeminiVersion = env.String("GEMINI_VERSION", "v1")
var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
var RelayProxy = env.String("RELAY_PROXY", "") var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)

View File

@@ -19,4 +19,5 @@ const (
TokenName = "token_name" TokenName = "token_name"
BaseURL = "base_url" BaseURL = "base_url"
AvailableModels = "available_models" AvailableModels = "available_models"
KeyRequestBody = "key_request_body"
) )

View File

@@ -4,14 +4,13 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"io" "io"
"strings" "strings"
) )
const KeyRequestBody = "key_request_body"
func GetRequestBody(c *gin.Context) ([]byte, error) { func GetRequestBody(c *gin.Context) ([]byte, error) {
requestBody, _ := c.Get(KeyRequestBody) requestBody, _ := c.Get(ctxkey.KeyRequestBody)
if requestBody != nil { if requestBody != nil {
return requestBody.([]byte), nil return requestBody.([]byte), nil
} }
@@ -20,7 +19,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) {
return nil, err return nil, err
} }
_ = c.Request.Body.Close() _ = c.Request.Body.Close()
c.Set(KeyRequestBody, requestBody) c.Set(ctxkey.KeyRequestBody, requestBody)
return requestBody.([]byte), nil return requestBody.([]byte), nil
} }

View File

@@ -27,7 +27,12 @@ var setupLogOnce sync.Once
func SetupLogger() { func SetupLogger() {
setupLogOnce.Do(func() { setupLogOnce.Do(func() {
if LogDir != "" { if LogDir != "" {
logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) var logPath string
if config.OnlyOneLogFile {
logPath = filepath.Join(LogDir, "oneapi.log")
} else {
logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
}
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
log.Fatal("failed to open log file") log.Fatal("failed to open log file")

View File

@@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"net/smtp" "net/smtp"
@@ -11,6 +12,32 @@ import (
"time" "time"
) )
type loginAuth struct {
username, password string
}
func LoginAuth(username, password string) smtp.Auth {
return &loginAuth{username, password}
}
func (a *loginAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte(a.username), nil
}
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch string(fromServer) {
case "Username:":
return []byte(a.username), nil
case "Password:":
return []byte(a.password), nil
default:
return nil, errors.New("unknown command from server during login auth")
}
}
return nil, nil
}
func SendEmail(subject string, receiver string, content string) error { func SendEmail(subject string, receiver string, content string) error {
if receiver == "" { if receiver == "" {
return fmt.Errorf("receiver is empty") return fmt.Errorf("receiver is empty")
@@ -41,7 +68,12 @@ func SendEmail(subject string, receiver string, content string) error {
"Date: %s\r\n"+ "Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) var auth smtp.Auth
if config.SMTPAuthLoginEnabled {
auth = LoginAuth(config.SMTPAccount, config.SMTPToken)
} else {
auth = smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
}
addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";") to := strings.Split(receiver, ";")

29
common/render/render.go Normal file
View File

@@ -0,0 +1,29 @@
package render
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"strings"
)
func StringData(c *gin.Context, str string) {
str = strings.TrimPrefix(str, "data: ")
str = strings.TrimSuffix(str, "\r")
c.Render(-1, common.CustomEvent{Data: "data: " + str})
c.Writer.Flush()
}
func ObjectData(c *gin.Context, object interface{}) error {
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
StringData(c, string(jsonData))
return nil
}
func Done(c *gin.Context) {
StringData(c, "[DONE]")
}

View File

@@ -48,7 +48,7 @@ func Relay(c *gin.Context) {
logger.Debugf(ctx, "request body: %s", string(requestBody)) logger.Debugf(ctx, "request body: %s", string(requestBody))
} }
channelId := c.GetInt(ctxkey.ChannelId) channelId := c.GetInt(ctxkey.ChannelId)
userId := c.GetInt("id") userId := c.GetInt(ctxkey.Id)
bizErr := relayHelper(c, relayMode) bizErr := relayHelper(c, relayMode)
if bizErr == nil { if bizErr == nil {
monitor.Emit(channelId, true) monitor.Emit(channelId, true)

4
go.mod
View File

@@ -24,7 +24,7 @@ require (
github.com/smartystreets/goconvey v1.8.1 github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.23.0 golang.org/x/crypto v0.23.0
golang.org/x/image v0.16.0 golang.org/x/image v0.18.0
gorm.io/driver/mysql v1.5.6 gorm.io/driver/mysql v1.5.6
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.5 gorm.io/driver/sqlite v1.5.5
@@ -80,7 +80,7 @@ require (
golang.org/x/net v0.25.0 // indirect golang.org/x/net v0.25.0 // indirect
golang.org/x/sync v0.7.0 // indirect golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.20.0 // indirect golang.org/x/sys v0.20.0 // indirect
golang.org/x/text v0.15.0 // indirect golang.org/x/text v0.16.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

8
go.sum
View File

@@ -154,8 +154,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw= golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs= golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
@@ -164,8 +164,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=

View File

@@ -330,6 +330,7 @@
"通常和邮箱地址保持一致": "Usually consistent with the email address", "通常和邮箱地址保持一致": "Usually consistent with the email address",
"SMTP 访问凭证": "SMTP Access Credential", "SMTP 访问凭证": "SMTP Access Credential",
"敏感信息不会发送到前端显示": "Sensitive information will not be displayed in the frontend", "敏感信息不会发送到前端显示": "Sensitive information will not be displayed in the frontend",
"使用 SMTP LOGIN 认证方式": "Use LOGIN as SMTP authentication method",
"保存 SMTP 设置": "Save SMTP Settings", "保存 SMTP 设置": "Save SMTP Settings",
"配置 GitHub OAuth App": "Configure GitHub OAuth App", "配置 GitHub OAuth App": "Configure GitHub OAuth App",
"用以支持通过 GitHub 进行登录注册": "To support login & registration via GitHub", "用以支持通过 GitHub 进行登录注册": "To support login & registration via GitHub",

22
main.go
View File

@@ -27,27 +27,19 @@ func main() {
common.Init() common.Init()
logger.SetupLogger() logger.SetupLogger()
logger.SysLogf("One API %s started", common.Version) logger.SysLogf("One API %s started", common.Version)
if os.Getenv("GIN_MODE") != "debug" {
if os.Getenv("GIN_MODE") != gin.DebugMode {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
if config.DebugEnabled { if config.DebugEnabled {
logger.SysLog("running in debug mode") logger.SysLog("running in debug mode")
} }
var err error
// Initialize SQL Database // Initialize SQL Database
model.DB, err = model.InitDB("SQL_DSN") model.InitDB()
if err != nil { model.InitLogDB()
logger.FatalLog("failed to initialize database: " + err.Error())
} var err error
if os.Getenv("LOG_SQL_DSN") != "" {
logger.SysLog("using secondary database for table logs")
model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
}
} else {
model.LOG_DB = model.DB
}
err = model.CreateRootAccountIfNeed() err = model.CreateRootAccountIfNeed()
if err != nil { if err != nil {
logger.FatalLog("database init error: " + err.Error()) logger.FatalLog("database init error: " + err.Error())

View File

@@ -1,6 +1,7 @@
package model package model
import ( import (
"database/sql"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
@@ -60,90 +61,156 @@ func CreateRootAccountIfNeed() error {
} }
func chooseDB(envName string) (*gorm.DB, error) { func chooseDB(envName string) (*gorm.DB, error) {
if os.Getenv(envName) != "" { dsn := os.Getenv(envName)
dsn := os.Getenv(envName)
if strings.HasPrefix(dsn, "postgres://") { switch {
// Use PostgreSQL case strings.HasPrefix(dsn, "postgres://"):
logger.SysLog("using PostgreSQL as database") // Use PostgreSQL
common.UsingPostgreSQL = true return openPostgreSQL(dsn)
return gorm.Open(postgres.New(postgres.Config{ case dsn != "":
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
}), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use MySQL // Use MySQL
logger.SysLog("using MySQL as database") return openMySQL(dsn)
common.UsingMySQL = true default:
return gorm.Open(mysql.Open(dsn), &gorm.Config{ // Use SQLite
PrepareStmt: true, // precompile SQL return openSQLite()
})
} }
// Use SQLite }
logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true func openPostgreSQL(dsn string) (*gorm.DB, error) {
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) logger.SysLog("using PostgreSQL as database")
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
}), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
func InitDB(envName string) (db *gorm.DB, err error) { func openMySQL(dsn string) (*gorm.DB, error) {
db, err = chooseDB(envName) logger.SysLog("using MySQL as database")
if err == nil { common.UsingMySQL = true
if config.DebugSQLEnabled { return gorm.Open(mysql.Open(dsn), &gorm.Config{
db = db.Debug() PrepareStmt: true, // precompile SQL
} })
sqlDB, err := db.DB() }
if err != nil {
return nil, err
}
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
if !config.IsMasterNode { func openSQLite() (*gorm.DB, error) {
return db, err logger.SysLog("SQL_DSN not set, using SQLite as database")
} common.UsingSQLite = true
if common.UsingMySQL { dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout)
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded return gorm.Open(sqlite.Open(dsn), &gorm.Config{
} PrepareStmt: true, // precompile SQL
logger.SysLog("database migration started") })
err = db.AutoMigrate(&Channel{}) }
if err != nil {
return nil, err func InitDB() {
} var err error
err = db.AutoMigrate(&Token{}) DB, err = chooseDB("SQL_DSN")
if err != nil { if err != nil {
return nil, err logger.FatalLog("failed to initialize database: " + err.Error())
} return
err = db.AutoMigrate(&User{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Option{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Redemption{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return nil, err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return nil, err
}
logger.SysLog("database migrated")
return db, err
} else {
logger.FatalLog(err)
} }
return db, err
sqlDB := setDBConns(DB)
if !config.IsMasterNode {
return
}
if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
}
logger.SysLog("database migration started")
if err = migrateDB(); err != nil {
logger.FatalLog("failed to migrate database: " + err.Error())
return
}
logger.SysLog("database migrated")
}
func migrateDB() error {
var err error
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Token{}); err != nil {
return err
}
if err = DB.AutoMigrate(&User{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Option{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Redemption{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Ability{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Log{}); err != nil {
return err
}
if err = DB.AutoMigrate(&Channel{}); err != nil {
return err
}
return nil
}
func InitLogDB() {
if os.Getenv("LOG_SQL_DSN") == "" {
LOG_DB = DB
return
}
logger.SysLog("using secondary database for table logs")
var err error
LOG_DB, err = chooseDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
return
}
setDBConns(LOG_DB)
if !config.IsMasterNode {
return
}
logger.SysLog("secondary database migration started")
err = migrateLOGDB()
if err != nil {
logger.FatalLog("failed to migrate secondary database: " + err.Error())
return
}
logger.SysLog("secondary database migrated")
}
func migrateLOGDB() error {
var err error
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
return err
}
return nil
}
func setDBConns(db *gorm.DB) *sql.DB {
if config.DebugSQLEnabled {
db = db.Debug()
}
sqlDB, err := db.DB()
if err != nil {
logger.FatalLog("failed to connect database: " + err.Error())
return nil
}
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
return sqlDB
} }
func closeDB(db *gorm.DB) error { func closeDB(db *gorm.DB) error {

View File

@@ -45,6 +45,7 @@ func InitOptionMap() {
config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort) config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
config.OptionMap["SMTPAccount"] = "" config.OptionMap["SMTPAccount"] = ""
config.OptionMap["SMTPToken"] = "" config.OptionMap["SMTPToken"] = ""
config.OptionMap["SMTPAuthLoginEnabled"] = strconv.FormatBool(config.SMTPAuthLoginEnabled)
config.OptionMap["Notice"] = "" config.OptionMap["Notice"] = ""
config.OptionMap["About"] = "" config.OptionMap["About"] = ""
config.OptionMap["HomePageContent"] = "" config.OptionMap["HomePageContent"] = ""
@@ -150,6 +151,8 @@ func updateOptionMap(key string, value string) (err error) {
config.DisplayInCurrencyEnabled = boolValue config.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled": case "DisplayTokenStatEnabled":
config.DisplayTokenStatEnabled = boolValue config.DisplayTokenStatEnabled = boolValue
case "SMTPAuthLoginEnabled":
config.SMTPAuthLoginEnabled = boolValue
} }
} }
switch key { switch key {

View File

@@ -4,6 +4,12 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@@ -12,10 +18,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
"strings"
) )
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
@@ -89,6 +91,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage var usage model.Usage
var documents []LibraryDocument
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
@@ -102,60 +105,48 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
var documents []LibraryDocument
c.Stream(func(w io.Writer) bool { for scanner.Scan() {
select { data := scanner.Text()
case data := <-dataChan: if len(data) < 5 || data[:5] != "data:" {
var AIProxyLibraryResponse LibraryStreamResponse continue
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
documents = AIProxyLibraryResponse.Documents
}
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
}) data = data[5:]
err := resp.Body.Close()
var AIProxyLibraryResponse LibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
if len(AIProxyLibraryResponse.Documents) != 0 {
documents = AIProxyLibraryResponse.Documents
}
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
response := documentsAIProxyLibrary(documents)
err := render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
render.Done(c)
err = resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, &usage return nil, &usage
} }

View File

@@ -3,15 +3,17 @@ package ali
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
) )
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
@@ -181,56 +183,43 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
//lastResponseText := ""
c.Stream(func(w io.Writer) bool { for scanner.Scan() {
select { data := scanner.Text()
case data := <-dataChan: if len(data) < 5 || data[:5] != "data:" {
var aliResponse ChatResponse continue
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
if response == nil {
return true
}
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
}) data = data[5:]
var aliResponse ChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
if response == nil {
continue
}
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -5,4 +5,5 @@ var ModelList = []string{
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
"claude-3-sonnet-20240229", "claude-3-sonnet-20240229",
"claude-3-opus-20240229", "claude-3-opus-20240229",
"claude-3-5-sonnet-20240620",
} }

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@@ -28,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string {
return "stop" return "stop"
case "max_tokens": case "max_tokens":
return "length" return "length"
case "tool_use":
return "tool_calls"
default: default:
return *reason return *reason
} }
} }
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeTools := make([]Tool, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTools = append(claudeTools, Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
InputSchema: InputSchema{
Type: params["type"].(string),
Properties: params["properties"],
Required: params["required"],
},
})
}
}
claudeRequest := Request{ claudeRequest := Request{
Model: textRequest.Model, Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens, MaxTokens: textRequest.MaxTokens,
@@ -41,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
TopP: textRequest.TopP, TopP: textRequest.TopP,
TopK: textRequest.TopK, TopK: textRequest.TopK,
Stream: textRequest.Stream, Stream: textRequest.Stream,
Tools: claudeTools,
}
if len(claudeTools) > 0 {
claudeToolChoice := struct {
Type string `json:"type"`
Name string `json:"name,omitempty"`
}{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output
if choice, ok := textRequest.ToolChoice.(map[string]any); ok {
if function, ok := choice["function"].(map[string]any); ok {
claudeToolChoice.Type = "tool"
claudeToolChoice.Name = function["name"].(string)
}
} else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok {
if toolChoiceType == "any" {
claudeToolChoice.Type = toolChoiceType
}
}
claudeRequest.ToolChoice = claudeToolChoice
} }
if claudeRequest.MaxTokens == 0 { if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096 claudeRequest.MaxTokens = 4096
@@ -63,7 +100,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
if message.IsStringContent() { if message.IsStringContent() {
content.Type = "text" content.Type = "text"
content.Text = message.StringContent() content.Text = message.StringContent()
if message.Role == "tool" {
claudeMessage.Role = "user"
content.Type = "tool_result"
content.Content = content.Text
content.Text = ""
content.ToolUseId = message.ToolCallId
}
claudeMessage.Content = append(claudeMessage.Content, content) claudeMessage.Content = append(claudeMessage.Content, content)
for i := range message.ToolCalls {
inputParam := make(map[string]any)
_ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam)
claudeMessage.Content = append(claudeMessage.Content, Content{
Type: "tool_use",
Id: message.ToolCalls[i].Id,
Name: message.ToolCalls[i].Function.Name,
Input: inputParam,
})
}
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
continue continue
} }
@@ -96,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
var response *Response var response *Response
var responseText string var responseText string
var stopReason string var stopReason string
tools := make([]model.Tool, 0)
switch claudeResponse.Type { switch claudeResponse.Type {
case "message_start": case "message_start":
return nil, claudeResponse.Message return nil, claudeResponse.Message
case "content_block_start": case "content_block_start":
if claudeResponse.ContentBlock != nil { if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text responseText = claudeResponse.ContentBlock.Text
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, model.Tool{
Id: claudeResponse.ContentBlock.Id,
Type: "function",
Function: model.Function{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
}
} }
case "content_block_delta": case "content_block_delta":
if claudeResponse.Delta != nil { if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text responseText = claudeResponse.Delta.Text
if claudeResponse.Delta.Type == "input_json_delta" {
tools = append(tools, model.Tool{
Function: model.Function{
Arguments: claudeResponse.Delta.PartialJson,
},
})
}
} }
case "message_delta": case "message_delta":
if claudeResponse.Usage != nil { if claudeResponse.Usage != nil {
@@ -119,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
} }
var choice openai.ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText choice.Delta.Content = responseText
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools
}
choice.Delta.Role = "assistant" choice.Delta.Role = "assistant"
finishReason := stopReasonClaude2OpenAI(&stopReason) finishReason := stopReasonClaude2OpenAI(&stopReason)
if finishReason != "null" { if finishReason != "null" {
@@ -135,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
if len(claudeResponse.Content) > 0 { if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text responseText = claudeResponse.Content[0].Text
} }
tools := make([]model.Tool, 0)
for _, v := range claudeResponse.Content {
if v.Type == "tool_use" {
args, _ := json.Marshal(v.Input)
tools = append(tools, model.Tool{
Id: v.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: model.Function{
Name: v.Name,
Arguments: string(args),
},
})
}
}
choice := openai.TextResponseChoice{ choice := openai.TextResponseChoice{
Index: 0, Index: 0,
Message: model.Message{ Message: model.Message{
Role: "assistant", Role: "assistant",
Content: responseText, Content: responseText,
Name: nil, Name: nil,
ToolCalls: tools,
}, },
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
} }
@@ -169,64 +261,77 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 {
continue
}
if !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
var usage model.Usage var usage model.Usage
var modelName string var modelName string
var id string var id string
c.Stream(func(w io.Writer) bool { var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
select {
case data := <-dataChan: for scanner.Scan() {
// some implementations may add \r at the end of data data := scanner.Text()
data = strings.TrimSpace(data) if len(data) < 6 || !strings.HasPrefix(data, "data:") {
var claudeResponse StreamResponse continue
err := json.Unmarshal([]byte(data), &claudeResponse) }
if err != nil { data = strings.TrimPrefix(data, "data:")
logger.SysError("error unmarshalling stream response: " + err.Error()) data = strings.TrimSpace(data)
return true
} var claudeResponse StreamResponse
response, meta := StreamResponseClaude2OpenAI(&claudeResponse) err := json.Unmarshal([]byte(data), &claudeResponse)
if meta != nil { if err != nil {
usage.PromptTokens += meta.Usage.InputTokens logger.SysError("error unmarshalling stream response: " + err.Error())
usage.CompletionTokens += meta.Usage.OutputTokens continue
}
response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
modelName = meta.Model modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id) id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true continue
} else { // finish_reason case
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
lastArgs.Arguments = "{}"
response.Choices[len(response.Choices)-1].Delta.Content = nil
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
}
}
} }
if response == nil {
return true
}
response.Id = id
response.Model = modelName
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
}) if response == nil {
_ = resp.Body.Close() continue
}
response.Id = id
response.Model = modelName
response.Created = createdTime
for _, choice := range response.Choices {
if len(choice.Delta.ToolCalls) > 0 {
lastToolCallChoice = choice
}
}
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage return nil, &usage
} }

View File

@@ -16,6 +16,12 @@ type Content struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
Source *ImageSource `json:"source,omitempty"` Source *ImageSource `json:"source,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content string `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
} }
type Message struct { type Message struct {
@@ -23,6 +29,18 @@ type Message struct {
Content []Content `json:"content"` Content []Content `json:"content"`
} }
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema InputSchema `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties any `json:"properties,omitempty"`
Required any `json:"required,omitempty"`
}
type Request struct { type Request struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
@@ -33,6 +51,8 @@ type Request struct {
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
//Metadata `json:"metadata,omitempty"` //Metadata `json:"metadata,omitempty"`
} }
@@ -61,6 +81,7 @@ type Response struct {
type Delta struct { type Delta struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text"` Text string `json:"text"`
PartialJson string `json:"partial_json,omitempty"`
StopReason *string `json:"stop_reason"` StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"` StopSequence *string `json:"stop_sequence"`
} }

View File

@@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"io" "io"
"net/http" "net/http"
@@ -33,12 +34,13 @@ func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var awsModelIDMap = map[string]string{ var awsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1", "claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2", "claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1", "claude-2.1": "anthropic.claude-v2:1",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
} }
func awsModelID(requestModel string) (string, error) { func awsModelID(requestModel string) (string, error) {
@@ -142,6 +144,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage var usage relaymodel.Usage
var id string var id string
var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events() event, ok := <-stream.Events()
if !ok { if !ok {
@@ -162,8 +166,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
if meta != nil { if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens usage.CompletionTokens += meta.Usage.OutputTokens
id = fmt.Sprintf("chatcmpl-%s", meta.Id) if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
return true id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
} else { // finish_reason case
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
lastArgs.Arguments = "{}"
response.Choices[len(response.Choices)-1].Delta.Content = nil
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
}
}
}
} }
if response == nil { if response == nil {
return true return true
@@ -171,6 +186,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
response.Id = id response.Id = id
response.Model = c.GetString(ctxkey.OriginalModel) response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime response.Created = createdTime
for _, choice := range response.Choices {
if len(choice.Delta.ToolCalls) > 0 {
lastToolCallChoice = choice
}
}
jsonStr, err := json.Marshal(response) jsonStr, err := json.Marshal(response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError("error marshalling stream response: " + err.Error())

View File

@@ -9,9 +9,12 @@ type Request struct {
// AnthropicVersion should be "bedrock-2023-05-31" // AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"` AnthropicVersion string `json:"anthropic_version"`
Messages []anthropic.Message `json:"messages"` Messages []anthropic.Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
} }

View File

@@ -5,6 +5,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/client"
@@ -12,11 +19,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"sync"
"time"
) )
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
@@ -137,59 +139,41 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage var usage model.Usage
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
data = data[6:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select { for scanner.Scan() {
case data := <-dataChan: data := scanner.Text()
var baiduResponse ChatStreamResponse if len(data) < 6 {
err := json.Unmarshal([]byte(data), &baiduResponse) continue
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
}) data = data[6:]
var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := streamResponseBaidu2OpenAI(&baiduResponse)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -2,8 +2,8 @@ package cloudflare
import ( import (
"bufio" "bufio"
"bytes"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@@ -17,21 +17,20 @@ import (
) )
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
var promptBuilder strings.Builder var promptBuilder strings.Builder
for _, message := range textRequest.Messages { for _, message := range textRequest.Messages {
promptBuilder.WriteString(message.StringContent()) promptBuilder.WriteString(message.StringContent())
promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息 promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息
} }
return &Request{ return &Request{
MaxTokens: textRequest.MaxTokens, MaxTokens: textRequest.MaxTokens,
Prompt: promptBuilder.String(), Prompt: promptBuilder.String(),
Stream: textRequest.Stream, Stream: textRequest.Stream,
Temperature: textRequest.Temperature, Temperature: textRequest.Temperature,
} }
} }
func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse { func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
choice := openai.TextResponseChoice{ choice := openai.TextResponseChoice{
Index: 0, Index: 0,
@@ -63,67 +62,54 @@ func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai
func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < len("data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
id := helper.GetResponseID(c) id := helper.GetResponseID(c)
responseModel := c.GetString("original_model") responseModel := c.GetString("original_model")
var responseText string var responseText string
c.Stream(func(w io.Writer) bool {
select { for scanner.Scan() {
case data := <-dataChan: data := scanner.Text()
// some implementations may add \r at the end of data if len(data) < len("data: ") {
data = strings.TrimSuffix(data, "\r") continue
var cloudflareResponse StreamResponse
err := json.Unmarshal([]byte(data), &cloudflareResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
if response == nil {
return true
}
responseText += cloudflareResponse.Response
response.Id = id
response.Model = responseModel
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
}) data = strings.TrimPrefix(data, "data: ")
_ = resp.Body.Close() data = strings.TrimSuffix(data, "\r")
var cloudflareResponse StreamResponse
err := json.Unmarshal([]byte(data), &cloudflareResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
if response == nil {
continue
}
responseText += cloudflareResponse.Response
response.Id = id
response.Model = responseModel
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens) usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens)
return nil, usage return nil, usage
} }

View File

@@ -2,9 +2,9 @@ package cohere
import ( import (
"bufio" "bufio"
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@@ -134,66 +134,53 @@ func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse {
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
var usage model.Usage var usage model.Usage
c.Stream(func(w io.Writer) bool {
select { for scanner.Scan() {
case data := <-dataChan: data := scanner.Text()
// some implementations may add \r at the end of data data = strings.TrimSuffix(data, "\r")
data = strings.TrimSuffix(data, "\r")
var cohereResponse StreamResponse var cohereResponse StreamResponse
err := json.Unmarshal([]byte(data), &cohereResponse) err := json.Unmarshal([]byte(data), &cohereResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true continue
}
response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
if meta != nil {
usage.PromptTokens += meta.Meta.Tokens.InputTokens
usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
return true
}
if response == nil {
return true
}
response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
response.Model = c.GetString("original_model")
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
})
_ = resp.Body.Close() response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
if meta != nil {
usage.PromptTokens += meta.Meta.Tokens.InputTokens
usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
continue
}
if response == nil {
continue
}
response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
response.Model = c.GetString("original_model")
response.Created = createdTime
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage return nil, &usage
} }

View File

@@ -4,6 +4,11 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
@@ -12,9 +17,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype" "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
) )
// https://www.coze.com/open // https://www.coze.com/open
@@ -109,69 +111,54 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
var responseText string var responseText string
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 {
continue
}
if !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
var modelName string var modelName string
c.Stream(func(w io.Writer) bool {
select { for scanner.Scan() {
case data := <-dataChan: data := scanner.Text()
// some implementations may add \r at the end of data if len(data) < 5 || !strings.HasPrefix(data, "data:") {
data = strings.TrimSuffix(data, "\r") continue
var cozeResponse StreamResponse
err := json.Unmarshal([]byte(data), &cozeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
if response == nil {
return true
}
for _, choice := range response.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
response.Model = modelName
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
}) data = strings.TrimPrefix(data, "data:")
_ = resp.Body.Close() data = strings.TrimSuffix(data, "\r")
var cozeResponse StreamResponse
err := json.Unmarshal([]byte(data), &cozeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
if response == nil {
continue
}
for _, choice := range response.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
response.Model = modelName
response.Created = createdTime
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &responseText return nil, &responseText
} }

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@@ -275,64 +276,50 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := "" responseText := ""
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select { for scanner.Scan() {
case data := <-dataChan: data := scanner.Text()
var geminiResponse ChatResponse data = strings.TrimSpace(data)
err := json.Unmarshal([]byte(data), &geminiResponse) if !strings.HasPrefix(data, "data: ") {
if err != nil { continue
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseGeminiChat2OpenAI(&geminiResponse)
if response == nil {
return true
}
responseText += response.Choices[0].Delta.StringContent()
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
}) data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\"")
var geminiResponse ChatResponse
err := json.Unmarshal([]byte(data), &geminiResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response := streamResponseGeminiChat2OpenAI(&geminiResponse)
if response == nil {
continue
}
responseText += response.Choices[0].Delta.StringContent()
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
} }
return nil, responseText return nil, responseText
} }

View File

@@ -5,12 +5,14 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/common/random"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/random"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/image"
@@ -105,54 +107,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return 0, nil, nil return 0, nil, nil
} }
if i := strings.Index(string(data), "}\n"); i >= 0 { if i := strings.Index(string(data), "}\n"); i >= 0 {
return i + 2, data[0:i], nil return i + 2, data[0 : i+1], nil
} }
if atEOF { if atEOF {
return len(data), data, nil return len(data), data, nil
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := strings.TrimPrefix(scanner.Text(), "}")
dataChan <- data + "}"
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select { for scanner.Scan() {
case data := <-dataChan: data := strings.TrimPrefix(scanner.Text(), "}")
var ollamaResponse ChatResponse data = data + "}"
err := json.Unmarshal([]byte(data), &ollamaResponse)
if err != nil { var ollamaResponse ChatResponse
logger.SysError("error unmarshalling stream response: " + err.Error()) err := json.Unmarshal([]byte(data), &ollamaResponse)
return true if err != nil {
} logger.SysError("error unmarshalling stream response: " + err.Error())
if ollamaResponse.EvalCount != 0 { continue
usage.PromptTokens = ollamaResponse.PromptEvalCount
usage.CompletionTokens = ollamaResponse.EvalCount
usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
}
response := streamResponseOllama2OpenAI(&ollamaResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
})
if ollamaResponse.EvalCount != 0 {
usage.PromptTokens = ollamaResponse.PromptEvalCount
usage.CompletionTokens = ollamaResponse.EvalCount
usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
}
response := streamResponseOllama2OpenAI(&ollamaResponse)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, &usage return nil, &usage
} }

View File

@@ -4,15 +4,17 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
) )
const ( const (
@@ -24,88 +26,68 @@ const (
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
responseText := "" responseText := ""
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
var usage *model.Usage var usage *model.Usage
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < dataPrefixLength { // ignore blank line or wrong format
continue
}
if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
continue
}
if strings.HasPrefix(data[dataPrefixLength:], done) {
dataChan <- data
continue
}
switch relayMode {
case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
dataChan <- data // if error happened, pass the data to client
continue // just ignore the error
}
if len(streamResponse.Choices) == 0 {
// but for empty choice, we should not pass it to client, this is for azure
continue // just ignore empty choice
}
dataChan <- data
for _, choice := range streamResponse.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
case relaymode.Completions:
dataChan <- data
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
for _, choice := range streamResponse.Choices {
responseText += choice.Text
}
}
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select { for scanner.Scan() {
case data := <-dataChan: data := scanner.Text()
if strings.HasPrefix(data, "data: [DONE]") { if len(data) < dataPrefixLength { // ignore blank line or wrong format
data = data[:12] continue
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
} }
}) if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
continue
}
if strings.HasPrefix(data[dataPrefixLength:], done) {
render.StringData(c, data)
continue
}
switch relayMode {
case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
render.StringData(c, data) // if error happened, pass the data to client
continue // just ignore the error
}
if len(streamResponse.Choices) == 0 {
// but for empty choice, we should not pass it to client, this is for azure
continue // just ignore empty choice
}
render.StringData(c, data)
for _, choice := range streamResponse.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
case relaymode.Completions:
render.StringData(c, data)
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
for _, choice := range streamResponse.Choices {
responseText += choice.Text
}
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
} }
return nil, responseText, usage return nil, responseText, usage
} }
@@ -149,7 +131,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
if textResponse.Usage.TotalTokens == 0 { if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
completionTokens := 0 completionTokens := 0
for _, choice := range textResponse.Choices { for _, choice := range textResponse.Choices {
completionTokens += CountTokenText(choice.Message.StringContent(), modelName) completionTokens += CountTokenText(choice.Message.StringContent(), modelName)

View File

@@ -3,6 +3,10 @@ package palm
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@@ -11,8 +15,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
) )
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
@@ -77,58 +79,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
responseText := "" responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
createdTime := helper.GetTimestamp() createdTime := helper.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
logger.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
err = resp.Body.Close()
if err != nil {
logger.SysError("error closing stream response: " + err.Error())
stopChan <- true
return
}
var palmResponse ChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
fullTextResponse.Id = responseId
fullTextResponse.Created = createdTime
if len(palmResponse.Candidates) > 0 {
responseText = palmResponse.Candidates[0].Content
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
dataChan <- string(jsonResponse)
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select { responseBody, err := io.ReadAll(resp.Body)
case data := <-dataChan: if err != nil {
c.Render(-1, common.CustomEvent{Data: "data: " + data}) logger.SysError("error reading stream response: " + err.Error())
return true err := resp.Body.Close()
case <-stopChan: if err != nil {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
return false
} }
}) return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), ""
err := resp.Body.Close() }
err = resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
} }
var palmResponse ChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), ""
}
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
fullTextResponse.Id = responseId
fullTextResponse.Created = createdTime
if len(palmResponse.Candidates) > 0 {
responseText = palmResponse.Candidates[0].Content
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
}
err = render.ObjectData(c, string(jsonResponse))
if err != nil {
logger.SysError(err.Error())
}
render.Done(c)
return nil, responseText return nil, responseText
} }

View File

@@ -8,6 +8,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
@@ -17,11 +24,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
"strings"
"time"
) )
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
@@ -87,64 +89,46 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
var responseText string var responseText string
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select { for scanner.Scan() {
case data := <-dataChan: data := scanner.Text()
var TencentResponse ChatResponse if len(data) < 5 || !strings.HasPrefix(data, "data:") {
err := json.Unmarshal([]byte(data), &TencentResponse) continue
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += conv.AsString(response.Choices[0].Delta.Content)
}
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
} }
}) data = strings.TrimPrefix(data, "data:")
var tencentResponse ChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response := streamResponseTencent2OpenAI(&tencentResponse)
if len(response.Choices) != 0 {
responseText += conv.AsString(response.Choices[0].Delta.Content)
}
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
} }
return nil, responseText return nil, responseText
} }

View File

@@ -6,4 +6,5 @@ var ModelList = []string{
"SparkDesk-v2.1", "SparkDesk-v2.1",
"SparkDesk-v3.1", "SparkDesk-v3.1",
"SparkDesk-v3.5", "SparkDesk-v3.5",
"SparkDesk-v4.0",
} }

View File

@@ -44,7 +44,7 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages xunfeiRequest.Payload.Message.Text = messages
if strings.HasPrefix(domain, "generalv3") { if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" {
functions := make([]model.Function, len(request.Tools)) functions := make([]model.Function, len(request.Tools))
for i, tool := range request.Tools { for i, tool := range request.Tools {
functions[i] = tool.Function functions[i] = tool.Function
@@ -290,6 +290,8 @@ func apiVersion2domain(apiVersion string) string {
return "generalv3" return "generalv3"
case "v3.5": case "v3.5":
return "generalv3.5" return "generalv3.5"
case "v4.0":
return "4.0Ultra"
} }
return "general" + apiVersion return "general" + apiVersion
} }

View File

@@ -3,6 +3,13 @@ package zhipu
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
@@ -11,11 +18,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"sync"
"time"
) )
// https://open.bigmodel.cn/doc/api#chatglm_std // https://open.bigmodel.cn/doc/api#chatglm_std
@@ -155,66 +157,55 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
return 0, nil, nil return 0, nil, nil
}) })
dataChan := make(chan string)
metaChan := make(chan string) common.SetEventStreamHeaders(c)
stopChan := make(chan bool)
go func() { for scanner.Scan() {
for scanner.Scan() { data := scanner.Text()
data := scanner.Text() lines := strings.Split(data, "\n")
lines := strings.Split(data, "\n") for i, line := range lines {
for i, line := range lines { if len(line) < 5 {
if len(line) < 5 { continue
}
if strings.HasPrefix(line, "data:") {
dataSegment := line[5:]
if i != len(lines)-1 {
dataSegment += "\n"
}
response := streamResponseZhipu2OpenAI(dataSegment)
err := render.ObjectData(c, response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
}
} else if strings.HasPrefix(line, "meta:") {
metaSegment := line[5:]
var zhipuResponse StreamMetaResponse
err := json.Unmarshal([]byte(metaSegment), &zhipuResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue continue
} }
if line[:5] == "data:" { response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
dataChan <- line[5:] err = render.ObjectData(c, response)
if i != len(lines)-1 { if err != nil {
dataChan <- "\n" logger.SysError("error marshalling stream response: " + err.Error())
}
} else if line[:5] == "meta:" {
metaChan <- line[5:]
} }
usage = zhipuUsage
} }
} }
stopChan <- true }
}()
common.SetEventStreamHeaders(c) if err := scanner.Err(); err != nil {
c.Stream(func(w io.Writer) bool { logger.SysError("error reading stream: " + err.Error())
select { }
case data := <-dataChan:
response := streamResponseZhipu2OpenAI(data) render.Done(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case data := <-metaChan:
var zhipuResponse StreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, usage return nil, usage
} }

View File

@@ -70,12 +70,13 @@ var ModelRatio = map[string]float64{
"dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image
"dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image
// https://www.anthropic.com/api#pricing // https://www.anthropic.com/api#pricing
"claude-instant-1.2": 0.8 / 1000 * USD, "claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD,
"claude-2.1": 8.0 / 1000 * USD, "claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240307": 0.25 / 1000 * USD, "claude-3-haiku-20240307": 0.25 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD, "claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD, "claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB, "ERNIE-3.5-8K": 0.012 * RMB,
@@ -124,6 +125,7 @@ var ModelRatio = map[string]float64{
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens

View File

@@ -40,78 +40,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
return textRequest, nil return textRequest, nil
} }
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
return imageRequest, nil
}
func isValidImageSize(model string, size string) bool {
if model == "cogview-3" {
return true
}
_, ok := billingratio.ImageSizeRatios[model][size]
return ok
}
func getImageSizeRatio(model string, size string) float64 {
ratio, ok := billingratio.ImageSizeRatios[model][size]
if !ok {
return 1
}
return ratio
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
// model validation
hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
if !hasValidSize {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
// check prompt length
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) {
// channel not azure
if meta.ChannelType != channeltype.Azure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
}
return nil
}
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
return imageCostRatio, nil
}
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode { switch relayMode {
case relaymode.ChatCompletions: case relaymode.ChatCompletions:

View File

@@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
@@ -20,13 +21,84 @@ import (
"net/http" "net/http"
) )
func isWithinRange(element string, value int) bool { func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { imageRequest := &relaymodel.ImageRequest{}
return false err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
} }
min := billingratio.ImageGenerationAmounts[element][0] if imageRequest.N == 0 {
max := billingratio.ImageGenerationAmounts[element][1] imageRequest.N = 1
return value >= min && value <= max }
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
return imageRequest, nil
}
func isValidImageSize(model string, size string) bool {
if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil {
return true
}
_, ok := billingratio.ImageSizeRatios[model][size]
return ok
}
func isValidImagePromptLength(model string, promptLength int) bool {
maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model]
return !ok || promptLength <= maxPromptLength
}
func isWithinRange(element string, value int) bool {
amounts, ok := billingratio.ImageGenerationAmounts[element]
return !ok || (value >= amounts[0] && value <= amounts[1])
}
func getImageSizeRatio(model string, size string) float64 {
if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok {
return ratio
}
return 1
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
// check prompt length
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
// model validation
if !isValidImageSize(imageRequest.Model, imageRequest.Size) {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
return nil
}
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
return imageCostRatio, nil
} }
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {

View File

@@ -1,10 +1,11 @@
package model package model
type Message struct { type Message struct {
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"` Content any `json:"content,omitempty"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
ToolCalls []Tool `json:"tool_calls,omitempty"` ToolCalls []Tool `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
} }
func (m Message) IsStringContent() bool { func (m Message) IsStringContent() bool {

View File

@@ -2,13 +2,13 @@ package model
type Tool struct { type Tool struct {
Id string `json:"id,omitempty"` Id string `json:"id,omitempty"`
Type string `json:"type"` Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty
Function Function `json:"function"` Function Function `json:"function"`
} }
type Function struct { type Function struct {
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Name string `json:"name"` Name string `json:"name,omitempty"` // when splicing claude tools stream messages, it is empty
Parameters any `json:"parameters,omitempty"` // request Parameters any `json:"parameters,omitempty"` // request
Arguments any `json:"arguments,omitempty"` // response Arguments any `json:"arguments,omitempty"` // response
} }

View File

@@ -63,7 +63,7 @@ const EditChannel = (props) => {
let localModels = []; let localModels = [];
switch (value) { switch (value) {
case 14: case 14:
localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"]; localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"];
break; break;
case 11: case 11:
localModels = ['PaLM-2']; localModels = ['PaLM-2'];
@@ -78,7 +78,7 @@ const EditChannel = (props) => {
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break; break;
case 18: case 18:
localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5']; localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
break; break;
case 19: case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];

View File

@@ -91,7 +91,7 @@ const typeConfig = {
other: '版本号' other: '版本号'
}, },
input: { input: {
models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'] models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']
}, },
prompt: { prompt: {
key: '按照如下格式输入APPID|APISecret|APIKey', key: '按照如下格式输入APPID|APISecret|APIKey',

View File

@@ -18,6 +18,7 @@ const SystemSetting = () => {
SMTPAccount: '', SMTPAccount: '',
SMTPFrom: '', SMTPFrom: '',
SMTPToken: '', SMTPToken: '',
SMTPAuthLoginEnabled: '',
ServerAddress: '', ServerAddress: '',
Footer: '', Footer: '',
WeChatAuthEnabled: '', WeChatAuthEnabled: '',
@@ -76,6 +77,7 @@ const SystemSetting = () => {
case 'TurnstileCheckEnabled': case 'TurnstileCheckEnabled':
case 'EmailDomainRestrictionEnabled': case 'EmailDomainRestrictionEnabled':
case 'RegisterEnabled': case 'RegisterEnabled':
case 'SMTPAuthLoginEnabled':
value = inputs[key] === 'true' ? 'false' : 'true'; value = inputs[key] === 'true' ? 'false' : 'true';
break; break;
default: default:
@@ -107,7 +109,7 @@ const SystemSetting = () => {
} }
if ( if (
name === 'Notice' || name === 'Notice' ||
name.startsWith('SMTP') || (name.startsWith('SMTP') && !name.endsWith('Enabled')) ||
name === 'ServerAddress' || name === 'ServerAddress' ||
name === 'GitHubClientId' || name === 'GitHubClientId' ||
name === 'GitHubClientSecret' || name === 'GitHubClientSecret' ||
@@ -444,6 +446,12 @@ const SystemSetting = () => {
checked={inputs.RegisterEnabled === 'true'} checked={inputs.RegisterEnabled === 'true'}
placeholder='敏感信息不会发送到前端显示' placeholder='敏感信息不会发送到前端显示'
/> />
<Form.Checkbox
checked={inputs.SMTPAuthLoginEnabled === 'true'}
label='使用 SMTP LOGIN 认证方式'
name='SMTPAuthLoginEnabled'
onChange={handleInputChange}
/>
</Form.Group> </Form.Group>
<Form.Button onClick={submitSMTP}>保存 SMTP 设置</Form.Button> <Form.Button onClick={submitSMTP}>保存 SMTP 设置</Form.Button>
<Divider /> <Divider />