diff --git a/README.en.md b/README.en.md
index b1c82544..db96a858 100644
--- a/README.en.md
+++ b/README.en.md
@@ -287,7 +287,9 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Double-check that your interface address and API Key are correct.
## Related Projects
-[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
+* [FastGPT](https://github.com/labring/FastGPT): Knowledge base question answering system built on LLMs
+* [VChart](https://github.com/VisActor/VChart): More than just a cross-platform charting library, but also an expressive data storyteller.
+* [VMind](https://github.com/VisActor/VMind): Not just automatic, but also fantastic. Open-source solution for intelligent visualization.
## Note
This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.
diff --git a/README.md b/README.md
index 3c83c7ed..78cd1353 100644
--- a/README.md
+++ b/README.md
@@ -4,8 +4,454 @@ docker image: `ppcelery/one-api:latest`
## New Features
- update token usage by API
- support gpt-vision
- support update user's remained quota
- support aws claude
- support openai images edits
+
+Deployment Tutorial · Usage · Feedback · Screenshots · Live Demo · FAQ · Related Projects · Sponsor
+
+> [!NOTE]
+> This is an open-source project. Use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**; it must not be used for illegal purposes.
+>
+> Per the [Interim Measures for the Management of Generative AI Services](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), do not provide any unregistered generative AI service to the public in mainland China.
+
+> [!WARNING]
+> The latest Docker image may be an `alpha` build; pin a specific version if you need stability.
+
+> [!WARNING]
+> After logging in as root for the first time, be sure to change the default password `123456`!
+
+## Features
+1. Supports multiple large models:
+    + [x] [OpenAI ChatGPT series](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (including the [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
+    + [x] [Anthropic Claude series](https://anthropic.com) (including AWS Claude)
+    + [x] [Google PaLM2/Gemini series](https://developers.generativeai.google)
+    + [x] [Mistral series](https://mistral.ai/)
+    + [x] [ByteDance Doubao models](https://console.volcengine.com/ark/region:ark+cn-beijing/model)
+    + [x] [Baidu Wenxin Yiyan (ERNIE) series](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+    + [x] [Alibaba Tongyi Qianwen (Qwen) series](https://help.aliyun.com/document_detail/2400395.html)
+    + [x] [iFlytek Spark series](https://www.xfyun.cn/doc/spark/Web.html)
+    + [x] [Zhipu ChatGLM series](https://bigmodel.cn)
+    + [x] [360 Zhinao](https://ai.360.cn)
+    + [x] [Tencent Hunyuan](https://cloud.tencent.com/document/product/1729)
+    + [x] [Moonshot AI](https://platform.moonshot.cn/)
+    + [x] [Baichuan models](https://platform.baichuan-ai.com)
+    + [x] [MINIMAX](https://api.minimax.chat/)
+    + [x] [Groq](https://wow.groq.com/)
+    + [x] [Ollama](https://github.com/ollama/ollama)
+    + [x] [Lingyi Wanwu (01.AI)](https://platform.lingyiwanwu.com/)
+    + [x] [StepFun](https://platform.stepfun.com/)
+    + [x] [Coze](https://www.coze.com/)
+    + [x] [Cohere](https://cohere.com/)
+    + [x] [DeepSeek](https://www.deepseek.com/)
+    + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/)
+    + [x] [DeepL](https://www.deepl.com/)
+    + [x] [together.ai](https://www.together.ai/)
+2. Supports configuring mirror endpoints and many [third-party proxy services](https://iamazing.cn/page/openai-api-third-party-services).
+3. Supports accessing multiple channels via **load balancing**.
+4. Supports **stream mode** for a typewriter effect through streaming responses.
+5. Supports **multi-machine deployment**; [see here](#multi-machine-deployment).
+6. Supports **token management**: set a token's expiration time, quota, allowed IP ranges, and allowed models.
+7. Supports **redemption code management**: generate and export redemption codes in bulk and use them to top up accounts.
+8. Supports **channel management** with bulk channel creation.
+9. Supports **user groups** and **channel groups**, with different multipliers per group.
+10. Supports per-channel **model lists**.
+11. Supports **viewing quota details**.
+12. Supports **user referral rewards**.
+13. Supports displaying quota in US dollars.
+14. Supports announcements, recharge links, and initial quota for new users.
+15. Supports model mapping to redirect the requested model. Do not set this unless necessary: mapping rebuilds the request body instead of passing it through verbatim, so fields that are not yet officially supported may fail to be forwarded.
+16. Supports automatic retry on failure.
+17. Supports image generation endpoints.
+18. Supports [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/); set the channel's proxy field to `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai`.
+19. Supports rich **customization**:
+    1. Custom system name, logo, and footer.
+    2. Custom home and about pages, written in HTML & Markdown or embedded as a standalone page via an iframe.
+20. Supports the management API via system access tokens, so One API can be **extended and customized without modifying its source code**; see the [API documentation](./docs/API.md).
+21. Supports Cloudflare Turnstile user verification.
+22. Supports user management and **multiple login/registration methods**:
+    + Email login/registration (with an optional registration email whitelist) and password reset via email.
+    + Login via Feishu (Lark) OAuth.
+    + [GitHub OAuth](https://github.com/settings/applications/new).
+    + WeChat Official Account login (requires deploying [WeChat Server](https://github.com/songquanpeng/wechat-server)).
+23. Supports theme switching via the `THEME` environment variable, default `default`; PRs for more themes are welcome, see [here](./web/README.md).
+24. Works with [Message Pusher](https://github.com/songquanpeng/message-pusher) to push alerts to a variety of apps.
+
+## Deployment
+### Deploying with Docker
+```shell
+# Deploy with SQLite:
+docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
+# Deploy with MySQL: add `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"` to the command above and adjust the
+# database connection parameters yourself; see the Environment Variables section below if unsure.
+# For example:
+docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
+```
+
+In `-p 3000:3000`, the first `3000` is the host port; change it as needed.
+
+Data and logs are stored in `/home/ubuntu/data/one-api` on the host; make sure that directory exists and is writable, or change it to a suitable one.
+
+If startup fails, add `--privileged=true`; see https://github.com/songquanpeng/one-api/issues/482 for details.
+
+If the image above cannot be pulled, try GitHub's registry by replacing `justsong/one-api` with `ghcr.io/songquanpeng/one-api`.
+
+If you expect high concurrency, **be sure** to set `SQL_DSN`; see [Environment Variables](#environment-variables) below.
+
+Update command: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR`
+
+Reference Nginx configuration:
+```
+server{
+   server_name openai.justsong.cn;  # replace with your own domain
+
+   location / {
+          client_max_body_size  64m;
+          proxy_http_version 1.1;
+          proxy_pass http://localhost:3000;  # replace with your own port
+          proxy_set_header Host $host;
+          proxy_set_header X-Forwarded-For $remote_addr;
+          proxy_cache_bypass $http_upgrade;
+          proxy_set_header Accept-Encoding gzip;
+          proxy_read_timeout 300s;  # GPT-4 needs a longer timeout; adjust as needed
+   }
+}
+```
+
+Then configure HTTPS with Let's Encrypt's certbot:
+```bash
+# Install certbot on Ubuntu:
+sudo snap install --classic certbot
+sudo ln -s /snap/bin/certbot /usr/bin/certbot
+# Generate the certificate & update the Nginx configuration
+sudo certbot --nginx
+# Follow the prompts
+# Restart Nginx
+sudo service nginx restart
+```
+
+The initial account username is `root` with password `123456`.
+
+
+### Deploying with Docker Compose
+
+> Only the startup method differs; the parameters are the same as in the Docker deployment section above.
+
+```shell
+# Currently supports MySQL startup; data is stored in the ./data/mysql folder
+docker-compose up -d
+
+# Check the deployment status
+docker-compose ps
+```
+
+### Manual Deployment
+1. Download the executable from [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) or build from source:
+   ```shell
+   git clone https://github.com/songquanpeng/one-api.git
+
+   # Build the frontend
+   cd one-api/web/default
+   npm install
+   npm run build
+
+   # Build the backend
+   cd ../..
+   go mod download
+   go build -ldflags "-s -w" -o one-api
+   ```
+2. Run:
+   ```shell
+   chmod u+x one-api
+   ./one-api --port 3000 --log-dir ./logs
+   ```
+3. Visit [http://localhost:3000/](http://localhost:3000/) and log in. The initial account username is `root` with password `123456`.
+
+A more detailed deployment tutorial is available [here](https://iamazing.cn/page/how-to-deploy-a-website).
+
+### Multi-Machine Deployment
+1. Set the same `SESSION_SECRET` on all servers.
+2. Set `SQL_DSN` and use MySQL instead of SQLite; all servers must connect to the same database.
+3. Set `NODE_TYPE` to `slave` on all slave servers; unset nodes default to master.
+4. With `SYNC_FREQUENCY` set, servers periodically sync configuration from the database. When using a remote database, setting it and enabling Redis is recommended on both master and slaves.
+5. Slave servers may set `FRONTEND_BASE_URL` to redirect page requests to the master server.
+6. Install Redis **separately on each slave** and set `REDIS_CONN_STRING`, so the database is not hit at all while the cache is still valid, reducing latency.
+7. If the master's database access latency is also high, enable Redis there too and set `SYNC_FREQUENCY` to sync configuration periodically.
+
+See [Environment Variables](#environment-variables) for details on these variables; a sketch of the resulting commands follows below.
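+
+For illustration only, the master/slave commands might look like this (the shared secret, database host, and Redis address are placeholders, not real values):
+
+```shell
+# Master
+docker run --name one-api -d -p 3000:3000 \
+  -e SESSION_SECRET=random_string \
+  -e SQL_DSN="root:123456@tcp(db.internal:3306)/oneapi" \
+  -e SYNC_FREQUENCY=60 \
+  justsong/one-api
+# Slave (with its own Redis)
+docker run --name one-api-slave -d -p 3000:3000 \
+  -e SESSION_SECRET=random_string \
+  -e SQL_DSN="root:123456@tcp(db.internal:3306)/oneapi" \
+  -e NODE_TYPE=slave \
+  -e SYNC_FREQUENCY=60 \
+  -e REDIS_CONN_STRING="redis://localhost:6379" \
+  justsong/one-api
+```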
+
+### Deploying via BT Panel (宝塔)
+
+See [#175](https://github.com/songquanpeng/one-api/issues/175).
+
+If you see a blank page after deployment, see [#97](https://github.com/songquanpeng/one-api/issues/97).
+
+### Deploying Third-Party Services Alongside One API
+> PRs adding more examples are welcome.
+
+#### ChatGPT Next Web
+Project homepage: https://github.com/Yidadaa/ChatGPT-Next-Web
+
+```bash
+docker run --name chat-next-web -d -p 3001:3000 yidadaa/chatgpt-next-web
+```
+
+Adjust the port as needed, then set the API endpoint (e.g. https://openai.justsong.cn/ ) and API Key on the settings page.
+
+#### ChatGPT Web
+Project homepage: https://github.com/Chanzhaoyu/chatgpt-web
+
+```bash
+docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://openai.justsong.cn -e OPENAI_API_KEY=sk-xxx chenzhaoyu94/chatgpt-web
+```
+
+Adjust the port, `OPENAI_API_BASE_URL`, and `OPENAI_API_KEY` as needed.
+
+#### QChatGPT - QQ Bot
+Project homepage: https://github.com/RockChinQ/QChatGPT
+
+After deploying per its docs, in `config.py` set `reverse_proxy` under `openai_config` to your One API backend address, set `api_key` to a key generated by One API, and set the `model` parameter under `completion_api_params` to a model name supported by One API.
+
+The [Switcher plugin](https://github.com/RockChinQ/Switcher) can be installed to switch models at runtime.
+
+### Deploying to Third-Party Platforms
+
+#### Deploy to Sealos
+
+> Sealos servers are located overseas, so no extra network workarounds are needed; it supports high concurrency & dynamic scaling.
+
+Click the button below for one-click deployment (if you get a 404 right after deploying, wait 3–5 minutes):
+
+[Deploy on Sealos](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
+
+#### Deploy to Zeabur
+
+> Zeabur's servers are overseas, sidestepping network issues, and the free tier is enough for personal use.
+
+[Deploy on Zeabur](https://zeabur.com/templates/7Q0KO3)
+
+1. First, fork the repository.
+2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and open the console.
+3. Create a new Project. In Service -> Add Service, choose Marketplace, pick MySQL, and note the connection parameters (username, password, host, port).
+4. Using those connection parameters, run ```create database `one-api` ``` to create the database.
+5. In Service -> Add Service, choose Git (authorize it on first use) and select your forked repository.
+6. A deploy starts automatically; cancel it first. Under Variable, add `PORT` with value `3000` and `SQL_DSN` with value `<username>:<password>@tcp(<host>:<port>)/one-api`, then save. Note that without `SQL_DSN` the data is not persisted and will be lost after a redeploy.
+7. Select Redeploy.
+8. Under Domains, pick a suitable domain prefix such as "my-one-api" to get "my-one-api.zeabur.app", or CNAME your own domain.
+9. Wait for the deployment to finish, then open the generated domain to reach One API.
+
+#### Deploy to Render
+
+> Render offers a free tier, and binding a card raises the limits further.
+
+Render can deploy the Docker image directly without forking the repository: https://dashboard.render.com
+
+
+## Configuration
+The system works out of the box.
+
+You can configure it via environment variables or command-line arguments.
+
+After the system starts, log in as the `root` user to complete further configuration.
+
+**Note**: if you are unsure what a configuration option means, temporarily clear its value to reveal the hint text.
+
+## Usage
+Add your API keys on the `Channels` page, then create an access token on the `Tokens` page.
+
+You can then use that token to access One API, with the same usage as the [OpenAI API](https://platform.openai.com/docs/api-reference/introduction).
+
+Wherever an OpenAI API is used, set the API Base to your One API deployment address, e.g. `https://openai.justsong.cn`, and the API Key to a token generated in One API.
+
+Note that the exact API Base format depends on the client you are using.
+
+For example, for the official OpenAI libraries:
+```bash
+OPENAI_API_KEY="sk-xxxxxx"
+OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
+```
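+
+For reference, the equivalent raw HTTP call with `curl` would look roughly like this (the host and token are the illustrative values from above):
+
+```bash
+curl https://openai.justsong.cn/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-xxxxxx" \
+  -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}'
+```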
+
+```mermaid
+graph LR
+    A(User)
+    A --->|Request with a One API key| B(One API)
+    B -->|Relay request| C(OpenAI)
+    B -->|Relay request| D(Azure)
+    B -->|Relay request| E(Other OpenAI-compatible downstream channels)
+    B -->|Relay with request/response body rewriting| F(Non-OpenAI-compatible downstream channels)
+```
+
+You can pin a request to a specific channel by appending the channel ID to the token, e.g.: `Authorization: Bearer ONE_API_KEY-CHANNEL_ID`.
+Note that the token must have been created by an admin user to specify a channel ID.
+
+Without a channel ID, requests are distributed across multiple channels via load balancing.
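+
+For example, assuming a hypothetical admin-created token `sk-xxxxxx` and channel ID `2`:
+
+```bash
+curl https://openai.justsong.cn/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-xxxxxx-2" \
+  -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "ping"}]}'
+```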
+
+### Environment Variables
+> One API can read environment variables from a `.env` file; see `.env.example` and rename it to `.env` to use it.
+1. `REDIS_CONN_STRING`: when set, Redis is used as the cache.
+   + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
+   + If database latency is already low, there is no need to enable Redis; doing so can actually introduce stale data.
+2. `SESSION_SECRET`: when set, a fixed session key is used so logged-in users' cookies remain valid across restarts.
+   + Example: `SESSION_SECRET=random_string`
+3. `SQL_DSN`: when set, the specified database is used instead of SQLite; use MySQL or PostgreSQL.
+   + Examples:
+     + MySQL: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
+     + PostgreSQL: `SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi` (support in progress; feedback welcome)
+   + Note that the `oneapi` database must be created in advance; tables are created automatically.
+   + For a local database: add `--network="host"` to the deploy command so the container can reach MySQL on the host.
+   + For a cloud database: if the server requires identity verification, append `?tls=skip-verify` to the connection parameters.
+   + Adjust the following to match your database configuration (or keep the defaults):
+     + `SQL_MAX_IDLE_CONNS`: maximum idle connections, default `100`.
+     + `SQL_MAX_OPEN_CONNS`: maximum open connections, default `1000`.
+       + If you hit `Error 1040: Too many connections`, lower this value.
+     + `SQL_CONN_MAX_LIFETIME`: maximum connection lifetime in minutes, default `60`.
+4. `LOG_SQL_DSN`: when set, a separate database is used for the `logs` table; use MySQL or PostgreSQL.
+5. `FRONTEND_BASE_URL`: when set, page requests are redirected to the given address; for slave servers only.
+   + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
+6. `MEMORY_CACHE_ENABLED`: enables the in-memory cache, which delays user quota updates somewhat; `true` or `false`, default `false`.
+   + Example: `MEMORY_CACHE_ENABLED=true`
+7. `SYNC_FREQUENCY`: how often to sync configuration from the database when caching is enabled, in seconds, default `600`.
+   + Example: `SYNC_FREQUENCY=60`
+8. `NODE_TYPE`: the node type, `master` or `slave`, default `master`.
+   + Example: `NODE_TYPE=slave`
+9. `CHANNEL_UPDATE_FREQUENCY`: when set, channel balances are refreshed periodically, in minutes; unset disables refreshing.
+   + Example: `CHANNEL_UPDATE_FREQUENCY=1440`
+10. `CHANNEL_TEST_FREQUENCY`: when set, channels are tested periodically, in minutes; unset disables testing.
+    + Example: `CHANNEL_TEST_FREQUENCY=1440`
+11. `POLLING_INTERVAL`: the interval between requests when batch-updating channel balances or testing availability, in seconds; default no interval.
+    + Example: `POLLING_INTERVAL=5`
+12. `BATCH_UPDATE_ENABLED`: enables aggregated batch database updates, which delays user quota updates somewhat; `true` or `false`, default `false`.
+    + Example: `BATCH_UPDATE_ENABLED=true`
+    + Try enabling this if you run into too many database connections.
+13. `BATCH_UPDATE_INTERVAL`: the aggregation interval for batch updates, in seconds, default `5`.
+    + Example: `BATCH_UPDATE_INTERVAL=5`
+14. Rate limiting:
+    + `GLOBAL_API_RATE_LIMIT`: global API rate limit (excluding relay requests), maximum requests per IP within three minutes, default `180`.
+    + `GLOBAL_WEB_RATE_LIMIT`: global web rate limit, maximum requests per IP within three minutes, default `60`.
+15. Tokenizer cache settings:
+    + `TIKTOKEN_CACHE_DIR`: by default the program downloads token encodings for common models such as `gpt-3.5-turbo` at startup; on unstable networks or offline machines this can break startup. Set this directory to cache the data; it can also be copied to offline environments.
+    + `DATA_GYM_CACHE_DIR`: currently behaves like `TIKTOKEN_CACHE_DIR` but with lower priority.
+16. `RELAY_TIMEOUT`: relay timeout in seconds; no timeout by default.
+17. `RELAY_PROXY`: when set, this proxy is used for API requests.
+18. `USER_CONTENT_REQUEST_TIMEOUT`: timeout for downloading user-uploaded content, in seconds.
+19. `USER_CONTENT_REQUEST_PROXY`: when set, this proxy is used to fetch user-uploaded content such as images.
+20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout in milliseconds, default `3000`.
+21. `GEMINI_SAFETY_SETTING`: Gemini safety setting, default `BLOCK_NONE`.
+22. `GEMINI_VERSION`: the Gemini API version One API uses, default `v1`.
+23. `THEME`: the system theme, default `default`; see [here](./web/README.md) for the available options.
+24. `ENABLE_METRIC`: whether to disable channels based on request success rate, default off; `true` or `false`.
+25. `METRIC_QUEUE_SIZE`: success-rate statistics queue size, default `10`.
+26. `METRIC_SUCCESS_RATE_THRESHOLD`: success-rate threshold, default `0.8`.
+27. `INITIAL_ROOT_TOKEN`: if set, a root user token with this value is created automatically at first startup.
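+
+Pulling a few of the examples above together, a minimal `.env` sketch for a single node using MySQL and Redis might look like this (all values are the illustrative ones from the list above, not recommendations):
+
+```bash
+SQL_DSN=root:123456@tcp(localhost:3306)/oneapi
+REDIS_CONN_STRING=redis://default:redispw@localhost:49153
+SESSION_SECRET=random_string
+MEMORY_CACHE_ENABLED=true
+SYNC_FREQUENCY=60
+```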
+
+### Command-Line Arguments
+1. `--port <port_number>`: the port the server listens on, default `3000`.
+   + Example: `--port 3000`
+2. `--log-dir <log_dir>`: the log directory; if unset, logs are saved in the `logs` folder under the working directory.
+   + Example: `--log-dir ./logs`
+3. `--version`: print the system version and exit.
+4. `--help`: show usage help and argument descriptions.
+
+## Demo
+### Live Demo
+Note that this demo site does not provide service to the public:
+https://openai.justsong.cn
+
+### Screenshots
+
+## FAQ
+1. What is quota? How is it calculated? Is One API's quota calculation wrong?
+   + Quota = group multiplier × model multiplier × (prompt tokens + completion tokens × completion multiplier)
+   + The completion multiplier is fixed at 1.33 for GPT-3.5 and 2 for GPT-4, matching the official pricing.
+   + For example (hypothetical numbers): with group multiplier 1, model multiplier 15, 1000 prompt tokens, and 500 completion tokens on a GPT-4-class model, the charge is 1 × 15 × (1000 + 500 × 2) = 30000 quota units.
+   + In non-stream mode the official API returns the total tokens consumed, but note that prompt and completion tokens are billed at different multipliers.
+   + Note that One API's default multipliers follow the official ones and are already adjusted.
+2. Why does it report insufficient quota when my account has enough?
+   + Check your token quota, which is separate from the account quota.
+   + Token quota only caps the maximum usage and can be set freely by the user.
+3. It says no available channels?
+   + Check your user group and channel group settings.
+   + Also check the channel's model settings.
+4. Channel test error: `invalid character '<' looking for beginning of value`
+   + The response was not valid JSON but an HTML page.
+   + Most likely your deployment's IP or the proxy's exit node is blocked by Cloudflare.
+5. ChatGPT Next Web error: `Failed to fetch`
+   + Do not set `BASE_URL` during deployment.
+   + Double-check that your interface address and API Key are correct.
+   + Check whether HTTPS is enabled; browsers block HTTP requests under HTTPS domains.
+6. Error: `当前分组负载已饱和,请稍后再试` (the current group's load is saturated, try again later)
+   + The upstream channel returned 429.
+7. Will my data be lost after an upgrade?
+   + Not with MySQL.
+   + With SQLite, you must mount a volume to persist the one-api.db file as shown in the deployment commands above; otherwise data is lost when the container restarts.
+8. Does the database need changes before an upgrade?
+   + Generally not; the system adjusts itself at initialization.
+   + If a change is ever needed, I will note it in the changelog and provide a script.
+9. Error after manually editing the database: `数据库一致性已被破坏,请联系管理员` (database consistency has been broken, contact the administrator)?
+   + This means some records in the ability table reference channel IDs that no longer exist, usually because rows were deleted from the channel table without cleaning up the corresponding ability records.
+   + For every channel, each model it supports needs a dedicated ability record stating that the channel supports that model.
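+
+A sketch of a query to locate such orphaned records (this assumes the default table names `abilities` and `channels`; verify against your actual schema before running anything):
+
+```sql
+SELECT * FROM abilities
+WHERE channel_id NOT IN (SELECT id FROM channels);
+```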
+
+## Related Projects
+* [FastGPT](https://github.com/labring/FastGPT): knowledge base question answering system built on LLMs
+* [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): one-click deployment of your own cross-platform ChatGPT application
+* [VChart](https://github.com/VisActor/VChart): more than just a cross-platform charting library, but also an expressive data storyteller.
+* [VMind](https://github.com/VisActor/VMind): not just automatic, but also fantastic. Open-source solution for intelligent visualization.
+
+## Note
+
+This project is open-sourced under the MIT license. **On top of that**, the attribution and a link back to this project must be kept at the bottom of the page; removing the attribution requires prior permission.
+
+The same applies to derivative projects based on this one.
+
+Per the MIT license, users bear the risks and responsibilities of using this project themselves; the developers of this open-source project are not liable.
diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go
index 2282c4ae..dbe271bc 100644
--- a/common/ctxkey/key.go
+++ b/common/ctxkey/key.go
@@ -23,4 +23,5 @@ const (
TokenName = "token_name"
BaseURL = "base_url"
AvailableModels = "available_models"
+ KeyRequestBody = "key_request_body"
)
diff --git a/common/gin.go b/common/gin.go
index 3b32f065..bd5773d7 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -6,12 +6,11 @@ import (
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
+ "github.com/songquanpeng/one-api/common/ctxkey"
)
-const KeyRequestBody = "key_request_body"
-
func GetRequestBody(c *gin.Context) ([]byte, error) {
- requestBody, _ := c.Get(KeyRequestBody)
+ requestBody, _ := c.Get(ctxkey.KeyRequestBody)
if requestBody != nil {
return requestBody.([]byte), nil
}
@@ -20,7 +19,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) {
return nil, errors.Wrap(err, "read request body failed")
}
_ = c.Request.Body.Close()
- c.Set(KeyRequestBody, requestBody)
+ c.Set(ctxkey.KeyRequestBody, requestBody)
return requestBody.([]byte), nil
}
diff --git a/common/render/render.go b/common/render/render.go
new file mode 100644
index 00000000..646b3777
--- /dev/null
+++ b/common/render/render.go
@@ -0,0 +1,29 @@
+package render
+
+import (
+	"encoding/json"
+	"fmt"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/common"
+)
+
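+// StringData writes str to the client as a single SSE "data:" event, normalizing any
+// existing "data: " prefix and trailing \r, and flushes immediately.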
+func StringData(c *gin.Context, str string) {
+ str = strings.TrimPrefix(str, "data: ")
+ str = strings.TrimSuffix(str, "\r")
+ c.Render(-1, common.CustomEvent{Data: "data: " + str})
+ c.Writer.Flush()
+}
+
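+// ObjectData marshals object to JSON and sends it as an SSE "data:" event via StringData.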
+func ObjectData(c *gin.Context, object interface{}) error {
+ jsonData, err := json.Marshal(object)
+ if err != nil {
+ return fmt.Errorf("error marshalling object: %w", err)
+ }
+ StringData(c, string(jsonData))
+ return nil
+}
+
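+// Done emits the final "data: [DONE]" event that terminates an OpenAI-style stream.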
+func Done(c *gin.Context) {
+ StringData(c, "[DONE]")
+}
diff --git a/controller/relay.go b/controller/relay.go
index 2cab8088..12a7d4bd 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -46,7 +46,7 @@ func Relay(c *gin.Context) {
ctx := c.Request.Context()
relayMode := relaymode.GetByPath(c.Request.URL.Path)
channelId := c.GetInt(ctxkey.ChannelId)
- userId := c.GetInt("id")
+ userId := c.GetInt(ctxkey.Id)
bizErr := relayHelper(c, relayMode)
if bizErr == nil {
monitor.Emit(channelId, true)
diff --git a/go.mod b/go.mod
index f69de2bc..a56fdf71 100644
--- a/go.mod
+++ b/go.mod
@@ -28,7 +28,7 @@ require (
github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.23.0
- golang.org/x/image v0.16.0
+ golang.org/x/image v0.18.0
gorm.io/driver/mysql v1.5.6
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.5
@@ -92,11 +92,11 @@ require (
golang.org/x/arch v0.8.0 // indirect
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect
golang.org/x/net v0.25.0 // indirect
- golang.org/x/sync v0.6.0 // indirect
+ golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/term v0.20.0 // indirect
- golang.org/x/text v0.15.0 // indirect
- golang.org/x/tools v0.7.0 // indirect
+ golang.org/x/text v0.16.0 // indirect
+ golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
google.golang.org/protobuf v1.34.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index a7efca8a..c627d43f 100644
--- a/go.sum
+++ b/go.sum
@@ -93,8 +93,8 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
-github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
-github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
+github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cpy v0.0.0-20211218193943-a9c933c06932 h1:5/4TSDzpDnHQ8rKEEQBjRlYx77mHOvXu08oGchxej7o=
github.com/google/go-cpy v0.0.0-20211218193943-a9c933c06932/go.mod h1:cC6EdPbj/17GFCPDK39NRarlMI+kt+O60S12cNB5J9Y=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -208,20 +208,20 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
-golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
-golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
+golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
+golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug=
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
-golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs=
-golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
+golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
-golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
+golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -231,13 +231,13 @@ golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
-golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
+golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
+golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4=
-golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s=
+golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
+golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
diff --git a/relay/adaptor/aiproxy/main.go b/relay/adaptor/aiproxy/main.go
index 961260de..03eeefa4 100644
--- a/relay/adaptor/aiproxy/main.go
+++ b/relay/adaptor/aiproxy/main.go
@@ -14,6 +14,7 @@ import (
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
+ "github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
@@ -90,6 +91,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
+ var documents []LibraryDocument
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -103,60 +105,48 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
return 0, nil, nil
})
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < 5 { // ignore blank line or wrong format
- continue
- }
- if data[:5] != "data:" {
- continue
- }
- data = data[5:]
- dataChan <- data
- }
- stopChan <- true
- }()
+
common.SetEventStreamHeaders(c)
- var documents []LibraryDocument
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- var AIProxyLibraryResponse LibraryStreamResponse
- err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- if len(AIProxyLibraryResponse.Documents) != 0 {
- documents = AIProxyLibraryResponse.Documents
- }
- response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case <-stopChan:
- response := documentsAIProxyLibrary(documents)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < 5 || data[:5] != "data:" {
+ continue
}
- })
- err := resp.Body.Close()
+ data = data[5:]
+
+ var AIProxyLibraryResponse LibraryStreamResponse
+ err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ continue
+ }
+ if len(AIProxyLibraryResponse.Documents) != 0 {
+ documents = AIProxyLibraryResponse.Documents
+ }
+ response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
+ err = render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ response := documentsAIProxyLibrary(documents)
+ err := render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+ render.Done(c)
+
+ err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
+
return nil, &usage
}
diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go
index 62c3900c..62cd0f08 100644
--- a/relay/adaptor/ali/main.go
+++ b/relay/adaptor/ali/main.go
@@ -11,6 +11,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -182,56 +183,43 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
return 0, nil, nil
})
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < 5 { // ignore blank line or wrong format
- continue
- }
- if data[:5] != "data:" {
- continue
- }
- data = data[5:]
- dataChan <- data
- }
- stopChan <- true
- }()
+
common.SetEventStreamHeaders(c)
- //lastResponseText := ""
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- var aliResponse ChatResponse
- err := json.Unmarshal([]byte(data), &aliResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- if aliResponse.Usage.OutputTokens != 0 {
- usage.PromptTokens = aliResponse.Usage.InputTokens
- usage.CompletionTokens = aliResponse.Usage.OutputTokens
- usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
- }
- response := streamResponseAli2OpenAI(&aliResponse)
- if response == nil {
- return true
- }
- //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
- //lastResponseText = aliResponse.Output.Text
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < 5 || data[:5] != "data:" {
+ continue
}
- })
+ data = data[5:]
+
+ var aliResponse ChatResponse
+ err := json.Unmarshal([]byte(data), &aliResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ continue
+ }
+ if aliResponse.Usage.OutputTokens != 0 {
+ usage.PromptTokens = aliResponse.Usage.InputTokens
+ usage.CompletionTokens = aliResponse.Usage.OutputTokens
+ usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
+ }
+ response := streamResponseAli2OpenAI(&aliResponse)
+ if response == nil {
+ continue
+ }
+ err = render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ render.Done(c)
+
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
diff --git a/relay/adaptor/anthropic/constants.go b/relay/adaptor/anthropic/constants.go
index cadcedc8..143d1efc 100644
--- a/relay/adaptor/anthropic/constants.go
+++ b/relay/adaptor/anthropic/constants.go
@@ -5,4 +5,5 @@ var ModelList = []string{
"claude-3-haiku-20240307",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
+ "claude-3-5-sonnet-20240620",
}
diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go
index 829985bc..8310ba56 100644
--- a/relay/adaptor/anthropic/main.go
+++ b/relay/adaptor/anthropic/main.go
@@ -4,6 +4,7 @@ import (
"bufio"
"encoding/json"
"fmt"
+ "github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@@ -169,65 +170,58 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
return 0, nil, nil
})
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < 6 {
- continue
- }
- if !strings.HasPrefix(data, "data:") {
- continue
- }
- data = strings.TrimPrefix(data, "data:")
- dataChan <- data
- }
-
- stopChan <- true
- }()
common.SetEventStreamHeaders(c)
+
var usage model.Usage
var modelName string
var id string
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- // some implementations may add \r at the end of data
- data = strings.TrimSpace(data)
- var claudeResponse StreamResponse
- err := json.Unmarshal([]byte(data), &claudeResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
- if meta != nil {
- usage.PromptTokens += meta.Usage.InputTokens
- usage.CompletionTokens += meta.Usage.OutputTokens
- modelName = meta.Model
- id = fmt.Sprintf("chatcmpl-%s", meta.Id)
- return true
- }
- if response == nil {
- return true
- }
- response.Id = id
- response.Model = modelName
- response.Created = createdTime
- jsonStr, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < 6 || !strings.HasPrefix(data, "data:") {
+ continue
}
- })
- _ = resp.Body.Close()
+ data = strings.TrimPrefix(data, "data:")
+ data = strings.TrimSpace(data)
+
+ var claudeResponse StreamResponse
+ err := json.Unmarshal([]byte(data), &claudeResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ continue
+ }
+
+ response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
+ if meta != nil {
+ usage.PromptTokens += meta.Usage.InputTokens
+ usage.CompletionTokens += meta.Usage.OutputTokens
+ modelName = meta.Model
+ id = fmt.Sprintf("chatcmpl-%s", meta.Id)
+ continue
+ }
+ if response == nil {
+ continue
+ }
+
+ response.Id = id
+ response.Model = modelName
+ response.Created = createdTime
+ err = render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ render.Done(c)
+
+ err := resp.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
return nil, &usage
}
diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go
index 68880f7a..3c852508 100644
--- a/relay/adaptor/aws/main.go
+++ b/relay/adaptor/aws/main.go
@@ -33,12 +33,13 @@ func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
var awsModelIDMap = map[string]string{
- "claude-instant-1.2": "anthropic.claude-instant-v1",
- "claude-2.0": "anthropic.claude-v2",
- "claude-2.1": "anthropic.claude-v2:1",
- "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
- "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
- "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
+ "claude-instant-1.2": "anthropic.claude-instant-v1",
+ "claude-2.0": "anthropic.claude-v2",
+ "claude-2.1": "anthropic.claude-v2:1",
+ "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
+ "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
+ "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
+ "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
}
func awsModelID(requestModel string) (string, error) {
diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go
index 52b3215e..716f4a9f 100644
--- a/relay/adaptor/baidu/main.go
+++ b/relay/adaptor/baidu/main.go
@@ -15,6 +15,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
@@ -138,59 +139,41 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- if i := strings.Index(string(data), "\n"); i >= 0 {
- return i + 1, data[0:i], nil
- }
- if atEOF {
- return len(data), data, nil
- }
- return 0, nil, nil
- })
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < 6 { // ignore blank line or wrong format
- continue
- }
- data = data[6:]
- dataChan <- data
- }
- stopChan <- true
- }()
+ scanner.Split(bufio.ScanLines)
+
common.SetEventStreamHeaders(c)
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- var baiduResponse ChatStreamResponse
- err := json.Unmarshal([]byte(data), &baiduResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- if baiduResponse.Usage.TotalTokens != 0 {
- usage.TotalTokens = baiduResponse.Usage.TotalTokens
- usage.PromptTokens = baiduResponse.Usage.PromptTokens
- usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
- }
- response := streamResponseBaidu2OpenAI(&baiduResponse)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < 6 {
+ continue
}
- })
+ data = data[6:]
+
+ var baiduResponse ChatStreamResponse
+ err := json.Unmarshal([]byte(data), &baiduResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ continue
+ }
+ if baiduResponse.Usage.TotalTokens != 0 {
+ usage.TotalTokens = baiduResponse.Usage.TotalTokens
+ usage.PromptTokens = baiduResponse.Usage.PromptTokens
+ usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
+ }
+ response := streamResponseBaidu2OpenAI(&baiduResponse)
+ err = render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ render.Done(c)
+
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
diff --git a/relay/adaptor/cloudflare/main.go b/relay/adaptor/cloudflare/main.go
index 45d86647..c76520a2 100644
--- a/relay/adaptor/cloudflare/main.go
+++ b/relay/adaptor/cloudflare/main.go
@@ -2,8 +2,8 @@ package cloudflare
import (
"bufio"
- "bytes"
"encoding/json"
+ "github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@@ -62,67 +62,54 @@ func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai
func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- if i := bytes.IndexByte(data, '\n'); i >= 0 {
- return i + 1, data[0:i], nil
- }
- if atEOF {
- return len(data), data, nil
- }
- return 0, nil, nil
- })
+ scanner.Split(bufio.ScanLines)
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < len("data: ") {
- continue
- }
- data = strings.TrimPrefix(data, "data: ")
- dataChan <- data
- }
- stopChan <- true
- }()
common.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
responseModel := c.GetString("original_model")
var responseText string
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- // some implementations may add \r at the end of data
- data = strings.TrimSuffix(data, "\r")
- var cloudflareResponse StreamResponse
- err := json.Unmarshal([]byte(data), &cloudflareResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
- if response == nil {
- return true
- }
- responseText += cloudflareResponse.Response
- response.Id = id
- response.Model = responseModel
- jsonStr, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < len("data: ") {
+ continue
}
- })
- _ = resp.Body.Close()
+ data = strings.TrimPrefix(data, "data: ")
+ data = strings.TrimSuffix(data, "\r")
+
+ var cloudflareResponse StreamResponse
+ err := json.Unmarshal([]byte(data), &cloudflareResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ continue
+ }
+
+ response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
+ if response == nil {
+ continue
+ }
+
+ responseText += cloudflareResponse.Response
+ response.Id = id
+ response.Model = responseModel
+
+ err = render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ render.Done(c)
+
+ err := resp.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+
usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens)
return nil, usage
}
diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go
index 4bc3fa8d..45db437b 100644
--- a/relay/adaptor/cohere/main.go
+++ b/relay/adaptor/cohere/main.go
@@ -2,9 +2,9 @@ package cohere
import (
"bufio"
- "bytes"
"encoding/json"
"fmt"
+ "github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@@ -134,66 +134,53 @@ func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse {
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- if i := bytes.IndexByte(data, '\n'); i >= 0 {
- return i + 1, data[0:i], nil
- }
- if atEOF {
- return len(data), data, nil
- }
- return 0, nil, nil
- })
+ scanner.Split(bufio.ScanLines)
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- dataChan <- data
- }
- stopChan <- true
- }()
common.SetEventStreamHeaders(c)
var usage model.Usage
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- // some implementations may add \r at the end of data
- data = strings.TrimSuffix(data, "\r")
- var cohereResponse StreamResponse
- err := json.Unmarshal([]byte(data), &cohereResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
- if meta != nil {
- usage.PromptTokens += meta.Meta.Tokens.InputTokens
- usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
- return true
- }
- if response == nil {
- return true
- }
- response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
- response.Model = c.GetString("original_model")
- response.Created = createdTime
- jsonStr, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ data = strings.TrimSuffix(data, "\r")
+
+ var cohereResponse StreamResponse
+ err := json.Unmarshal([]byte(data), &cohereResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ continue
}
- })
- _ = resp.Body.Close()
+
+ response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
+ if meta != nil {
+ usage.PromptTokens += meta.Meta.Tokens.InputTokens
+ usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
+ continue
+ }
+ if response == nil {
+ continue
+ }
+
+ response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
+ response.Model = c.GetString("original_model")
+ response.Created = createdTime
+
+ err = render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ render.Done(c)
+
+ err := resp.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+
return nil, &usage
}
diff --git a/relay/adaptor/coze/main.go b/relay/adaptor/coze/main.go
index 721c5d13..d0402a76 100644
--- a/relay/adaptor/coze/main.go
+++ b/relay/adaptor/coze/main.go
@@ -4,6 +4,11 @@ import (
"bufio"
"encoding/json"
"fmt"
+ "github.com/songquanpeng/one-api/common/render"
+ "io"
+ "net/http"
+ "strings"
+
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
@@ -12,9 +17,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
- "io"
- "net/http"
- "strings"
)
// https://www.coze.com/open
@@ -109,69 +111,54 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
var responseText string
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- if i := strings.Index(string(data), "\n"); i >= 0 {
- return i + 1, data[0:i], nil
- }
- if atEOF {
- return len(data), data, nil
- }
- return 0, nil, nil
- })
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < 5 {
- continue
- }
- if !strings.HasPrefix(data, "data:") {
- continue
- }
- data = strings.TrimPrefix(data, "data:")
- dataChan <- data
- }
- stopChan <- true
- }()
+ scanner.Split(bufio.ScanLines)
+
common.SetEventStreamHeaders(c)
var modelName string
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- // some implementations may add \r at the end of data
- data = strings.TrimSuffix(data, "\r")
- var cozeResponse StreamResponse
- err := json.Unmarshal([]byte(data), &cozeResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
- if response == nil {
- return true
- }
- for _, choice := range response.Choices {
- responseText += conv.AsString(choice.Delta.Content)
- }
- response.Model = modelName
- response.Created = createdTime
- jsonStr, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < 5 || !strings.HasPrefix(data, "data:") {
+ continue
}
- })
- _ = resp.Body.Close()
+ data = strings.TrimPrefix(data, "data:")
+ data = strings.TrimSuffix(data, "\r")
+
+ var cozeResponse StreamResponse
+ err := json.Unmarshal([]byte(data), &cozeResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ continue
+ }
+
+ response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
+ if response == nil {
+ continue
+ }
+
+ for _, choice := range response.Choices {
+ responseText += conv.AsString(choice.Delta.Content)
+ }
+ response.Model = modelName
+ response.Created = createdTime
+
+ err = render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ render.Done(c)
+
+ err := resp.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+
return nil, &responseText
}
diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go
index 3feb612f..8b35e5a4 100644
--- a/relay/adaptor/gemini/main.go
+++ b/relay/adaptor/gemini/main.go
@@ -4,6 +4,7 @@ import (
"bufio"
"encoding/json"
"fmt"
+ "github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@@ -274,64 +275,50 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- if i := strings.Index(string(data), "\n"); i >= 0 {
- return i + 1, data[0:i], nil
- }
- if atEOF {
- return len(data), data, nil
- }
- return 0, nil, nil
- })
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- data = strings.TrimSpace(data)
- if !strings.HasPrefix(data, "data: ") {
- continue
- }
- data = strings.TrimPrefix(data, "data: ")
- data = strings.TrimSuffix(data, "\"")
- dataChan <- data
- }
- stopChan <- true
- }()
+ scanner.Split(bufio.ScanLines)
+
common.SetEventStreamHeaders(c)
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- var geminiResponse ChatResponse
- err := json.Unmarshal([]byte(data), &geminiResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- response := streamResponseGeminiChat2OpenAI(&geminiResponse)
- if response == nil {
- return true
- }
- responseText += response.Choices[0].Delta.StringContent()
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ data = strings.TrimSpace(data)
+ if !strings.HasPrefix(data, "data: ") {
+ continue
}
- })
+ data = strings.TrimPrefix(data, "data: ")
+ data = strings.TrimSuffix(data, "\"")
+
+ var geminiResponse ChatResponse
+ err := json.Unmarshal([]byte(data), &geminiResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ continue
+ }
+
+ response := streamResponseGeminiChat2OpenAI(&geminiResponse)
+ if response == nil {
+ continue
+ }
+
+ responseText += response.Choices[0].Delta.StringContent()
+
+ err = render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ render.Done(c)
+
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
+
return nil, responseText
}
diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go
index f0e41599..cc0734ba 100644
--- a/relay/adaptor/ollama/main.go
+++ b/relay/adaptor/ollama/main.go
@@ -15,6 +15,7 @@ import (
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
+ "github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
@@ -105,54 +106,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return 0, nil, nil
}
if i := strings.Index(string(data), "}\n"); i >= 0 {
- return i + 2, data[0:i], nil
+ return i + 2, data[0 : i+1], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := strings.TrimPrefix(scanner.Text(), "}")
- dataChan <- data + "}"
- }
- stopChan <- true
- }()
+
common.SetEventStreamHeaders(c)
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- var ollamaResponse ChatResponse
- err := json.Unmarshal([]byte(data), &ollamaResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- if ollamaResponse.EvalCount != 0 {
- usage.PromptTokens = ollamaResponse.PromptEvalCount
- usage.CompletionTokens = ollamaResponse.EvalCount
- usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
- }
- response := streamResponseOllama2OpenAI(&ollamaResponse)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
+
+ for scanner.Scan() {
+		// the split function above now returns a complete JSON object, including the
+		// closing brace, so the token can be used as-is
+		data := scanner.Text()
+
+ var ollamaResponse ChatResponse
+ err := json.Unmarshal([]byte(data), &ollamaResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ continue
}
- })
+
+ if ollamaResponse.EvalCount != 0 {
+ usage.PromptTokens = ollamaResponse.PromptEvalCount
+ usage.CompletionTokens = ollamaResponse.EvalCount
+ usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
+ }
+
+ response := streamResponseOllama2OpenAI(&ollamaResponse)
+ err = render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ render.Done(c)
+
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
+
return nil, &usage
}
diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go
index 78842aab..ec538c75 100644
--- a/relay/adaptor/openai/main.go
+++ b/relay/adaptor/openai/main.go
@@ -12,6 +12,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
@@ -25,88 +26,68 @@ const (
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- if i := strings.Index(string(data), "\n"); i >= 0 {
- return i + 1, data[0:i], nil
- }
- if atEOF {
- return len(data), data, nil
- }
- return 0, nil, nil
- })
- dataChan := make(chan string)
- stopChan := make(chan bool)
+ scanner.Split(bufio.ScanLines)
var usage *model.Usage
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < dataPrefixLength { // ignore blank line or wrong format
- continue
- }
- if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
- continue
- }
- if strings.HasPrefix(data[dataPrefixLength:], done) {
- dataChan <- data
- continue
- }
- switch relayMode {
- case relaymode.ChatCompletions:
- var streamResponse ChatCompletionsStreamResponse
- err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- dataChan <- data // if error happened, pass the data to client
- continue // just ignore the error
- }
- if len(streamResponse.Choices) == 0 {
- // but for empty choice, we should not pass it to client, this is for azure
- continue // just ignore empty choice
- }
- dataChan <- data
- for _, choice := range streamResponse.Choices {
- responseText += conv.AsString(choice.Delta.Content)
- }
- if streamResponse.Usage != nil {
- usage = streamResponse.Usage
- }
- case relaymode.Completions:
- dataChan <- data
- var streamResponse CompletionsStreamResponse
- err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- continue
- }
- for _, choice := range streamResponse.Choices {
- responseText += choice.Text
- }
- }
- }
- stopChan <- true
- }()
+
common.SetEventStreamHeaders(c)
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- if strings.HasPrefix(data, "data: [DONE]") {
- data = data[:12]
- }
- // some implementations may add \r at the end of data
- data = strings.TrimSuffix(data, "\r")
- c.Render(-1, common.CustomEvent{Data: data})
- return true
- case <-stopChan:
- return false
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < dataPrefixLength { // ignore blank line or wrong format
+ continue
}
- })
+ if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
+ continue
+ }
+ if strings.HasPrefix(data[dataPrefixLength:], done) {
+ render.StringData(c, data)
+ continue
+ }
+ switch relayMode {
+ case relaymode.ChatCompletions:
+ var streamResponse ChatCompletionsStreamResponse
+ err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ render.StringData(c, data) // if error happened, pass the data to client
+ continue // just ignore the error
+ }
+ if len(streamResponse.Choices) == 0 {
+ // but for empty choice, we should not pass it to client, this is for azure
+ continue // just ignore empty choice
+ }
+ render.StringData(c, data)
+ for _, choice := range streamResponse.Choices {
+ responseText += conv.AsString(choice.Delta.Content)
+ }
+ if streamResponse.Usage != nil {
+ usage = streamResponse.Usage
+ }
+ case relaymode.Completions:
+ render.StringData(c, data)
+ var streamResponse CompletionsStreamResponse
+ err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ continue
+ }
+ for _, choice := range streamResponse.Choices {
+ responseText += choice.Text
+ }
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ render.Done(c)
+
err := resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
}
+
return nil, responseText, usage
}
@@ -150,7 +131,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
- if textResponse.Usage.TotalTokens == 0 {
+ if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range textResponse.Choices {
completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
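
The widened guard in `Handler` now also recomputes usage when an upstream reports a non-zero `total_tokens` but leaves both component counts at zero. A distilled sketch of that fallback, with `countTokens` standing in for the real `CountTokenText`:

```go
package main

import "fmt"

// Usage mirrors the relevant fields of model.Usage.
type Usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// normalizeUsage distills the fallback: recompute when totals are missing,
// or when a total was reported with both components left at zero.
// countTokens stands in for the real CountTokenText helper.
func normalizeUsage(u Usage, completions []string, promptTokens int, countTokens func(string) int) Usage {
	if u.TotalTokens == 0 || (u.PromptTokens == 0 && u.CompletionTokens == 0) {
		completionTokens := 0
		for _, text := range completions {
			completionTokens += countTokens(text)
		}
		u = Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		}
	}
	return u
}

func main() {
	words := func(s string) int { return len(s) / 4 } // crude stand-in tokenizer
	broken := Usage{TotalTokens: 42}                  // total set, components zeroed
	fixed := normalizeUsage(broken, []string{"hello world, streaming done"}, 10, words)
	fmt.Printf("%+v\n", fixed)
}
```
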
diff --git a/relay/adaptor/palm/palm.go b/relay/adaptor/palm/palm.go
index 1e60e7cd..d31784ec 100644
--- a/relay/adaptor/palm/palm.go
+++ b/relay/adaptor/palm/palm.go
@@ -3,6 +3,10 @@ package palm
import (
"encoding/json"
"fmt"
+ "github.com/songquanpeng/one-api/common/render"
+ "io"
+ "net/http"
+
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
@@ -11,8 +15,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
- "io"
- "net/http"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
@@ -77,58 +79,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
createdTime := helper.GetTimestamp()
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- logger.SysError("error reading stream response: " + err.Error())
- stopChan <- true
- return
- }
- err = resp.Body.Close()
- if err != nil {
- logger.SysError("error closing stream response: " + err.Error())
- stopChan <- true
- return
- }
- var palmResponse ChatResponse
- err = json.Unmarshal(responseBody, &palmResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- stopChan <- true
- return
- }
- fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
- fullTextResponse.Id = responseId
- fullTextResponse.Created = createdTime
- if len(palmResponse.Candidates) > 0 {
- responseText = palmResponse.Candidates[0].Content
- }
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- stopChan <- true
- return
- }
- dataChan <- string(jsonResponse)
- stopChan <- true
- }()
+
common.SetEventStreamHeaders(c)
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- c.Render(-1, common.CustomEvent{Data: "data: " + data})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
+
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ logger.SysError("error reading stream response: " + err.Error())
+ err := resp.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
- })
- err := resp.Body.Close()
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), ""
+ }
+
+ err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
+
+ var palmResponse ChatResponse
+ err = json.Unmarshal(responseBody, &palmResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), ""
+ }
+
+ fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
+ fullTextResponse.Id = responseId
+ fullTextResponse.Created = createdTime
+ if len(palmResponse.Candidates) > 0 {
+ responseText = palmResponse.Candidates[0].Content
+ }
+
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ logger.SysError("error marshalling stream response: " + err.Error())
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
+ }
+
+	// jsonResponse is already marshalled; StringData emits it verbatim,
+	// while ObjectData would marshal it a second time
+	err = render.StringData(c, string(jsonResponse))
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+
+ render.Done(c)
+
return nil, responseText
}
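
PaLM returns one complete body, which the handler replays as a single SSE event followed by the done sentinel. The payload here is already marshalled, which is why it goes through `render.StringData`: `render.ObjectData` marshals whatever it receives, and JSON text encoded a second time arrives as an escaped string. A quick illustration:

```go
// json.Marshal applied to JSON text encodes it a second time, so
// already-marshalled payloads must bypass ObjectData.
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	payload, _ := json.Marshal(map[string]string{"id": "chatcmpl-123"})
	fmt.Println(string(payload)) // {"id":"chatcmpl-123"}

	doubled, _ := json.Marshal(string(payload))
	fmt.Println(string(doubled)) // "{\"id\":\"chatcmpl-123\"}", an escaped string, not an object
}
```
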
diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go
index 801e11db..264d6a37 100644
--- a/relay/adaptor/tencent/main.go
+++ b/relay/adaptor/tencent/main.go
@@ -20,6 +20,7 @@ import (
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
+ "github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
@@ -88,64 +89,46 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- if i := strings.Index(string(data), "\n"); i >= 0 {
- return i + 1, data[0:i], nil
- }
- if atEOF {
- return len(data), data, nil
- }
- return 0, nil, nil
- })
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < 5 { // ignore blank line or wrong format
- continue
- }
- if data[:5] != "data:" {
- continue
- }
- data = data[5:]
- dataChan <- data
- }
- stopChan <- true
- }()
+ scanner.Split(bufio.ScanLines)
+
common.SetEventStreamHeaders(c)
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- var TencentResponse ChatResponse
- err := json.Unmarshal([]byte(data), &TencentResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- response := streamResponseTencent2OpenAI(&TencentResponse)
- if len(response.Choices) != 0 {
- responseText += conv.AsString(response.Choices[0].Delta.Content)
- }
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
+
+ for scanner.Scan() {
+ data := scanner.Text()
+		if !strings.HasPrefix(data, "data:") { // HasPrefix already rejects lines shorter than "data:"
+ continue
}
- })
+ data = strings.TrimPrefix(data, "data:")
+
+ var tencentResponse ChatResponse
+ err := json.Unmarshal([]byte(data), &tencentResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
+ continue
+ }
+
+ response := streamResponseTencent2OpenAI(&tencentResponse)
+ if len(response.Choices) != 0 {
+ responseText += conv.AsString(response.Choices[0].Delta.Content)
+ }
+
+ err = render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError(err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ render.Done(c)
+
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
+
return nil, responseText
}
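
On the wire, every handler in this patch now emits the same stream shape: `data: {json}` events terminated by `data: [DONE]`. A minimal consumer sketch, with a canned stream standing in for a live response body:

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n" +
		"data: {\"choices\":[{\"delta\":{\"content\":\"!\"}}]}\n\n" +
		"data: [DONE]\n\n"

	scanner := bufio.NewScanner(strings.NewReader(stream))
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // blank separator lines between events
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			break
		}
		fmt.Println("event:", payload)
	}
}
```
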
diff --git a/relay/adaptor/xunfei/constants.go b/relay/adaptor/xunfei/constants.go
index 31dcec71..12a56210 100644
--- a/relay/adaptor/xunfei/constants.go
+++ b/relay/adaptor/xunfei/constants.go
@@ -6,4 +6,5 @@ var ModelList = []string{
"SparkDesk-v2.1",
"SparkDesk-v3.1",
"SparkDesk-v3.5",
+ "SparkDesk-v4.0",
}
diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go
index 39b76e27..ef6120e5 100644
--- a/relay/adaptor/xunfei/main.go
+++ b/relay/adaptor/xunfei/main.go
@@ -44,7 +44,7 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages
- if strings.HasPrefix(domain, "generalv3") {
+ if strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra" {
functions := make([]model.Function, len(request.Tools))
for i, tool := range request.Tools {
functions[i] = tool.Function
@@ -290,6 +290,8 @@ func apiVersion2domain(apiVersion string) string {
return "generalv3"
case "v3.5":
return "generalv3.5"
+ case "v4.0":
+ return "4.0Ultra"
}
return "general" + apiVersion
}
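
With the new mapping, requesting `SparkDesk-v4.0` routes to the `4.0Ultra` domain, and the tools gate in `requestOpenAI2Xunfei` admits it alongside the `generalv3*` domains. The predicate, extracted for a quick check:

```go
// The function-calling gate after the change, extracted for illustration.
package main

import (
	"fmt"
	"strings"
)

func supportsTools(domain string) bool {
	return strings.HasPrefix(domain, "generalv3") || domain == "4.0Ultra"
}

func main() {
	for _, domain := range []string{"general", "generalv3", "generalv3.5", "4.0Ultra"} {
		fmt.Printf("%-12s tools: %v\n", domain, supportsTools(domain))
	}
}
```
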
diff --git a/relay/adaptor/zhipu/main.go b/relay/adaptor/zhipu/main.go
index 3880b205..0489136e 100644
--- a/relay/adaptor/zhipu/main.go
+++ b/relay/adaptor/zhipu/main.go
@@ -14,6 +14,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
@@ -156,66 +157,55 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
return 0, nil, nil
})
- dataChan := make(chan string)
- metaChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- lines := strings.Split(data, "\n")
- for i, line := range lines {
- if len(line) < 5 {
+
+ common.SetEventStreamHeaders(c)
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ lines := strings.Split(data, "\n")
+ for i, line := range lines {
+ if len(line) < 5 {
+ continue
+ }
+ if strings.HasPrefix(line, "data:") {
+ dataSegment := line[5:]
+ if i != len(lines)-1 {
+ dataSegment += "\n"
+ }
+ response := streamResponseZhipu2OpenAI(dataSegment)
+ err := render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError("error marshalling stream response: " + err.Error())
+ }
+ } else if strings.HasPrefix(line, "meta:") {
+ metaSegment := line[5:]
+ var zhipuResponse StreamMetaResponse
+ err := json.Unmarshal([]byte(metaSegment), &zhipuResponse)
+ if err != nil {
+ logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
- if line[:5] == "data:" {
- dataChan <- line[5:]
- if i != len(lines)-1 {
- dataChan <- "\n"
- }
- } else if line[:5] == "meta:" {
- metaChan <- line[5:]
+ response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
+ err = render.ObjectData(c, response)
+ if err != nil {
+ logger.SysError("error marshalling stream response: " + err.Error())
}
+ usage = zhipuUsage
}
}
- stopChan <- true
- }()
- common.SetEventStreamHeaders(c)
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- response := streamResponseZhipu2OpenAI(data)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case data := <-metaChan:
- var zhipuResponse StreamMetaResponse
- err := json.Unmarshal([]byte(data), &zhipuResponse)
- if err != nil {
- logger.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- logger.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- usage = zhipuUsage
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
- }
- })
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.SysError("error reading stream: " + err.Error())
+ }
+
+ render.Done(c)
+
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
+
return nil, usage
}
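
One subtlety the zhipu loop preserves from the old channel version: a scanner token can carry several `data:` lines, and every segment except the last gets its newline re-appended so line breaks inside the streamed content survive. Illustrated in isolation:

```go
// A scanner token may hold several "data:" lines; re-appending "\n" to all
// but the last segment keeps line breaks inside the streamed content.
package main

import (
	"fmt"
	"strings"
)

func main() {
	data := "data:hello\ndata:world"
	lines := strings.Split(data, "\n")
	for i, line := range lines {
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		segment := strings.TrimPrefix(line, "data:")
		if i != len(lines)-1 {
			segment += "\n" // restore the newline the split consumed
		}
		fmt.Printf("%q\n", segment) // "hello\n", then "world"
	}
}
```
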
diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go
index f24053cd..5f1ae323 100644
--- a/relay/billing/ratio/model.go
+++ b/relay/billing/ratio/model.go
@@ -70,12 +70,13 @@ var ModelRatio = map[string]float64{
"dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image
"dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image
// https://www.anthropic.com/api#pricing
- "claude-instant-1.2": 0.8 / 1000 * USD,
- "claude-2.0": 8.0 / 1000 * USD,
- "claude-2.1": 8.0 / 1000 * USD,
- "claude-3-haiku-20240307": 0.25 / 1000 * USD,
- "claude-3-sonnet-20240229": 3.0 / 1000 * USD,
- "claude-3-opus-20240229": 15.0 / 1000 * USD,
+ "claude-instant-1.2": 0.8 / 1000 * USD,
+ "claude-2.0": 8.0 / 1000 * USD,
+ "claude-2.1": 8.0 / 1000 * USD,
+ "claude-3-haiku-20240307": 0.25 / 1000 * USD,
+ "claude-3-sonnet-20240229": 3.0 / 1000 * USD,
+ "claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
+ "claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
@@ -124,6 +125,7 @@ var ModelRatio = map[string]float64{
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
+ "SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
diff --git a/relay/controller/helper.go b/relay/controller/helper.go
index bb38e164..b2016030 100644
--- a/relay/controller/helper.go
+++ b/relay/controller/helper.go
@@ -40,78 +40,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
return textRequest, nil
}
-func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
- imageRequest := &relaymodel.ImageRequest{}
- err := common.UnmarshalBodyReusable(c, imageRequest)
- if err != nil {
- return nil, err
- }
- if imageRequest.N == 0 {
- imageRequest.N = 1
- }
- if imageRequest.Size == "" {
- imageRequest.Size = "1024x1024"
- }
- if imageRequest.Model == "" {
- imageRequest.Model = "dall-e-2"
- }
- return imageRequest, nil
-}
-
-func isValidImageSize(model string, size string) bool {
- if model == "cogview-3" {
- return true
- }
- _, ok := billingratio.ImageSizeRatios[model][size]
- return ok
-}
-
-func getImageSizeRatio(model string, size string) float64 {
- ratio, ok := billingratio.ImageSizeRatios[model][size]
- if !ok {
- return 1
- }
- return ratio
-}
-
-func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
- // model validation
- hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
- if !hasValidSize {
- return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
- }
- // check prompt length
- if imageRequest.Prompt == "" {
- return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
- }
- if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
- return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
- }
- // Number of generated images validation
- if !isWithinRange(imageRequest.Model, imageRequest.N) {
- // channel not azure
- if meta.ChannelType != channeltype.Azure {
- return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
- }
- }
- return nil
-}
-
-func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
- if imageRequest == nil {
- return 0, errors.New("imageRequest is nil")
- }
- imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
- if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
- if imageRequest.Size == "1024x1024" {
- imageCostRatio *= 2
- } else {
- imageCostRatio *= 1.5
- }
- }
- return imageCostRatio, nil
-}
-
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode {
case relaymode.ChatCompletions:
diff --git a/relay/controller/image.go b/relay/controller/image.go
index f6a7dc31..2c6900a5 100644
--- a/relay/controller/image.go
+++ b/relay/controller/image.go
@@ -11,6 +11,7 @@ import (
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
@@ -22,13 +23,84 @@ import (
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
-func isWithinRange(element string, value int) bool {
- if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
- return false
+func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
+ imageRequest := &relaymodel.ImageRequest{}
+ err := common.UnmarshalBodyReusable(c, imageRequest)
+ if err != nil {
+ return nil, err
}
- min := billingratio.ImageGenerationAmounts[element][0]
- max := billingratio.ImageGenerationAmounts[element][1]
- return value >= min && value <= max
+ if imageRequest.N == 0 {
+ imageRequest.N = 1
+ }
+ if imageRequest.Size == "" {
+ imageRequest.Size = "1024x1024"
+ }
+ if imageRequest.Model == "" {
+ imageRequest.Model = "dall-e-2"
+ }
+ return imageRequest, nil
+}
+
+func isValidImageSize(model string, size string) bool {
+ if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil {
+ return true
+ }
+ _, ok := billingratio.ImageSizeRatios[model][size]
+ return ok
+}
+
+func isValidImagePromptLength(model string, promptLength int) bool {
+ maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model]
+ return !ok || promptLength <= maxPromptLength
+}
+
+func isWithinRange(element string, value int) bool {
+ amounts, ok := billingratio.ImageGenerationAmounts[element]
+ return !ok || (value >= amounts[0] && value <= amounts[1])
+}
+
+func getImageSizeRatio(model string, size string) float64 {
+ if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok {
+ return ratio
+ }
+ return 1
+}
+
+func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
+ // check prompt length
+ if imageRequest.Prompt == "" {
+ return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
+ }
+
+ // model validation
+ if !isValidImageSize(imageRequest.Model, imageRequest.Size) {
+ return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
+ }
+
+ if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) {
+ return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
+ }
+
+ // Number of generated images validation
+ if !isWithinRange(imageRequest.Model, imageRequest.N) {
+ return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
+ }
+ return nil
+}
+
+func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
+ if imageRequest == nil {
+ return 0, errors.New("imageRequest is nil")
+ }
+ imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
+ if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
+ if imageRequest.Size == "1024x1024" {
+ imageCostRatio *= 2
+ } else {
+ imageCostRatio *= 1.5
+ }
+ }
+ return imageCostRatio, nil
}
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
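
The relocated `getImageCostRatio` stacks the dall-e-3 HD multiplier on top of the per-size ratio. A worked example with stand-in size ratios; the real values live in `billingratio.ImageSizeRatios` and may differ:

```go
package main

import "fmt"

// Illustrative stand-in for billingratio.ImageSizeRatios.
var imageSizeRatios = map[string]map[string]float64{
	"dall-e-3": {"1024x1024": 1, "1024x1792": 2, "1792x1024": 2},
}

// getImageCostRatio mirrors the relocated helper: per-size ratio first,
// then the HD-quality adjustment for dall-e-3.
func getImageCostRatio(model, size, quality string) float64 {
	ratio, ok := imageSizeRatios[model][size]
	if !ok {
		ratio = 1
	}
	if quality == "hd" && model == "dall-e-3" {
		if size == "1024x1024" {
			ratio *= 2
		} else {
			ratio *= 1.5
		}
	}
	return ratio
}

func main() {
	fmt.Println(getImageCostRatio("dall-e-3", "1024x1024", "hd")) // 2
	fmt.Println(getImageCostRatio("dall-e-3", "1024x1792", "hd")) // 3
}
```
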
diff --git a/web/air/src/pages/Channel/EditChannel.js b/web/air/src/pages/Channel/EditChannel.js
index efb2cee8..73fd2da2 100644
--- a/web/air/src/pages/Channel/EditChannel.js
+++ b/web/air/src/pages/Channel/EditChannel.js
@@ -63,7 +63,7 @@ const EditChannel = (props) => {
let localModels = [];
switch (value) {
case 14:
- localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"];
+ localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"];
break;
case 11:
localModels = ['PaLM-2'];
@@ -78,7 +78,7 @@ const EditChannel = (props) => {
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break;
case 18:
- localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5'];
+ localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
break;
case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js
index 88e1ea92..51b7c6c4 100644
--- a/web/berry/src/views/Channel/type/Config.js
+++ b/web/berry/src/views/Channel/type/Config.js
@@ -91,7 +91,7 @@ const typeConfig = {
other: '版本号'
},
input: {
- models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5']
+ models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']
},
prompt: {
key: '按照如下格式输入:APPID|APISecret|APIKey',