mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-25 19:03:43 +08:00
Compare commits
63 Commits
v0.5.4-alp
...
v0.5.7-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
824444244b | ||
|
|
fbe9985f57 | ||
|
|
a27a5bcc06 | ||
|
|
e28d4b1741 | ||
|
|
f073592d39 | ||
|
|
fa41ca9805 | ||
|
|
e338de45b6 | ||
|
|
114587b46f | ||
|
|
b4b4acc288 | ||
|
|
d663de3e3a | ||
|
|
a85ecace2e | ||
|
|
fbdea91ea1 | ||
|
|
8d34b7a77e | ||
|
|
cbd62011b8 | ||
|
|
4701897e2e | ||
|
|
0f6c132a80 | ||
|
|
3cac45dc85 | ||
|
|
47c08c72ce | ||
|
|
53b2cace0b | ||
|
|
f0fc991b44 | ||
|
|
594f06e7b0 | ||
|
|
197d1d7a9d | ||
|
|
f9b748c2ca | ||
|
|
fd98463611 | ||
|
|
f5a1cd3463 | ||
|
|
8651451e53 | ||
|
|
1c5bb97a42 | ||
|
|
de868e4e4e | ||
|
|
1d258cc898 | ||
|
|
37e09d764c | ||
|
|
159b9e3369 | ||
|
|
92001986db | ||
|
|
a5647b1ea7 | ||
|
|
215e54fc96 | ||
|
|
ecf8a6d875 | ||
|
|
24df3e5f62 | ||
|
|
12ef9679a7 | ||
|
|
328aa68255 | ||
|
|
4335f005a6 | ||
|
|
fe26a1448d | ||
|
|
42451d9d02 | ||
|
|
25c4c111ab | ||
|
|
0d50ad4b2b | ||
|
|
959bcdef88 | ||
|
|
39ae8075e4 | ||
|
|
b57a0eca16 | ||
|
|
1b4cc78890 | ||
|
|
420c375140 | ||
|
|
01863d3e44 | ||
|
|
d0a0e871e1 | ||
|
|
bd6fe1e93c | ||
|
|
c55bb67818 | ||
|
|
0f949c3782 | ||
|
|
a721a5b6f9 | ||
|
|
276163affd | ||
|
|
621eb91b46 | ||
|
|
7e575abb95 | ||
|
|
9db93316c4 | ||
|
|
c3dc315e75 | ||
|
|
04acdb1ccb | ||
|
|
f0d5e102a3 | ||
|
|
abbf2fded0 | ||
|
|
ef2c5abb5b |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,4 +4,5 @@ upload
|
|||||||
*.exe
|
*.exe
|
||||||
*.db
|
*.db
|
||||||
build
|
build
|
||||||
*.db-journal
|
*.db-journal
|
||||||
|
logs
|
||||||
94
README.md
94
README.md
@@ -59,6 +59,9 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
|||||||
> **Warning**
|
> **Warning**
|
||||||
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
|
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
|
||||||
|
|
||||||
|
> **Warning**
|
||||||
|
> 使用 root 用户初次登录系统后,务必修改默认密码 `123456`!
|
||||||
|
|
||||||
## 功能
|
## 功能
|
||||||
1. 支持多种大模型:
|
1. 支持多种大模型:
|
||||||
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
||||||
@@ -69,12 +72,13 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
|||||||
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
|
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
|
||||||
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
||||||
+ [x] [360 智脑](https://ai.360.cn)
|
+ [x] [360 智脑](https://ai.360.cn)
|
||||||
|
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
|
||||||
2. 支持配置镜像以及众多第三方代理服务:
|
2. 支持配置镜像以及众多第三方代理服务:
|
||||||
+ [x] [OpenAI-SB](https://openai-sb.com)
|
+ [x] [OpenAI-SB](https://openai-sb.com)
|
||||||
|
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
|
||||||
+ [x] [API2D](https://api2d.com/r/197971)
|
+ [x] [API2D](https://api2d.com/r/197971)
|
||||||
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
|
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
|
||||||
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
|
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
|
||||||
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
|
|
||||||
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
|
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
|
||||||
3. 支持通过**负载均衡**的方式访问多个渠道。
|
3. 支持通过**负载均衡**的方式访问多个渠道。
|
||||||
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
||||||
@@ -91,23 +95,32 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
|||||||
15. 支持模型映射,重定向用户的请求模型。
|
15. 支持模型映射,重定向用户的请求模型。
|
||||||
16. 支持失败自动重试。
|
16. 支持失败自动重试。
|
||||||
17. 支持绘图接口。
|
17. 支持绘图接口。
|
||||||
18. 支持丰富的**自定义**设置,
|
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
|
||||||
|
19. 支持丰富的**自定义**设置,
|
||||||
1. 支持自定义系统名称,logo 以及页脚。
|
1. 支持自定义系统名称,logo 以及页脚。
|
||||||
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
||||||
19. 支持通过系统访问令牌访问管理 API。
|
20. 支持通过系统访问令牌访问管理 API。
|
||||||
20. 支持 Cloudflare Turnstile 用户校验。
|
21. 支持 Cloudflare Turnstile 用户校验。
|
||||||
21. 支持用户管理,支持**多种用户登录注册方式**:
|
22. 支持用户管理,支持**多种用户登录注册方式**:
|
||||||
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
|
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
|
||||||
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
||||||
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
||||||
|
|
||||||
## 部署
|
## 部署
|
||||||
### 基于 Docker 进行部署
|
### 基于 Docker 进行部署
|
||||||
部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api`
|
```shell
|
||||||
|
# 使用 SQLite 的部署命令:
|
||||||
|
docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
|
||||||
|
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。
|
||||||
|
# 例如:
|
||||||
|
docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
|
||||||
|
```
|
||||||
|
|
||||||
其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
|
其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
|
||||||
|
|
||||||
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
|
数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
|
||||||
|
|
||||||
|
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
|
||||||
|
|
||||||
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
|
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
|
||||||
|
|
||||||
@@ -209,6 +222,13 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
|||||||
|
|
||||||
注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。
|
注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。
|
||||||
|
|
||||||
|
#### QChatGPT - QQ机器人
|
||||||
|
项目主页:https://github.com/RockChinQ/QChatGPT
|
||||||
|
|
||||||
|
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
|
||||||
|
|
||||||
|
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
|
||||||
|
|
||||||
### 部署到第三方平台
|
### 部署到第三方平台
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>部署到 Sealos </strong></summary>
|
<summary><strong>部署到 Sealos </strong></summary>
|
||||||
@@ -227,7 +247,7 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
|||||||
<summary><strong>部署到 Zeabur</strong></summary>
|
<summary><strong>部署到 Zeabur</strong></summary>
|
||||||
<div>
|
<div>
|
||||||
|
|
||||||
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用。
|
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用
|
||||||
|
|
||||||
1. 首先 fork 一份代码。
|
1. 首先 fork 一份代码。
|
||||||
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。
|
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。
|
||||||
@@ -242,6 +262,17 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
|||||||
</div>
|
</div>
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>部署到 Render</strong></summary>
|
||||||
|
<div>
|
||||||
|
|
||||||
|
> Render 提供免费额度,绑卡后可以进一步提升额度
|
||||||
|
|
||||||
|
Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashboard.render.com
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
## 配置
|
## 配置
|
||||||
系统本身开箱即用。
|
系统本身开箱即用。
|
||||||
|
|
||||||
@@ -260,13 +291,20 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
|||||||
|
|
||||||
注意,具体的 API Base 的格式取决于你所使用的客户端。
|
注意,具体的 API Base 的格式取决于你所使用的客户端。
|
||||||
|
|
||||||
|
例如对于 OpenAI 的官方库:
|
||||||
|
```bash
|
||||||
|
OPENAI_API_KEY="sk-xxxxxx"
|
||||||
|
OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
|
||||||
|
```
|
||||||
|
|
||||||
```mermaid
|
```mermaid
|
||||||
graph LR
|
graph LR
|
||||||
A(用户)
|
A(用户)
|
||||||
A --->|请求| B(One API)
|
A --->|使用 One API 分发的 key 进行请求| B(One API)
|
||||||
B -->|中继请求| C(OpenAI)
|
B -->|中继请求| C(OpenAI)
|
||||||
B -->|中继请求| D(Azure)
|
B -->|中继请求| D(Azure)
|
||||||
B -->|中继请求| E(其他下游渠道)
|
B -->|中继请求| E(其他 OpenAI API 格式下游渠道)
|
||||||
|
B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道)
|
||||||
```
|
```
|
||||||
|
|
||||||
可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。
|
可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。
|
||||||
@@ -275,8 +313,9 @@ graph LR
|
|||||||
不加的话将会使用负载均衡的方式使用多个渠道。
|
不加的话将会使用负载均衡的方式使用多个渠道。
|
||||||
|
|
||||||
### 环境变量
|
### 环境变量
|
||||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储。
|
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
|
||||||
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
||||||
|
+ 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。
|
||||||
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
|
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
|
||||||
+ 例子:`SESSION_SECRET=random_string`
|
+ 例子:`SESSION_SECRET=random_string`
|
||||||
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。
|
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。
|
||||||
@@ -293,21 +332,31 @@ graph LR
|
|||||||
+ `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
|
+ `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
|
||||||
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
|
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
|
||||||
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
|
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
|
||||||
5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。
|
5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||||
|
+ 例子:`MEMORY_CACHE_ENABLED=true`
|
||||||
|
6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
|
||||||
+ 例子:`SYNC_FREQUENCY=60`
|
+ 例子:`SYNC_FREQUENCY=60`
|
||||||
6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
|
7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
|
||||||
+ 例子:`NODE_TYPE=slave`
|
+ 例子:`NODE_TYPE=slave`
|
||||||
7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
|
8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
|
||||||
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
|
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
|
||||||
8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
|
9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
|
||||||
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
|
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
|
||||||
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
|
10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
|
||||||
+ 例子:`POLLING_INTERVAL=5`
|
+ 例子:`POLLING_INTERVAL=5`
|
||||||
|
11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||||
|
+ 例子:`BATCH_UPDATE_ENABLED=true`
|
||||||
|
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
|
||||||
|
12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
|
||||||
|
+ 例子:`BATCH_UPDATE_INTERVAL=5`
|
||||||
|
13. 请求频率限制:
|
||||||
|
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
|
||||||
|
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
|
||||||
|
|
||||||
### 命令行参数
|
### 命令行参数
|
||||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
||||||
+ 例子:`--port 3000`
|
+ 例子:`--port 3000`
|
||||||
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存。
|
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。
|
||||||
+ 例子:`--log-dir ./logs`
|
+ 例子:`--log-dir ./logs`
|
||||||
3. `--version`: 打印系统版本号并退出。
|
3. `--version`: 打印系统版本号并退出。
|
||||||
4. `--help`: 查看命令的使用帮助和参数说明。
|
4. `--help`: 查看命令的使用帮助和参数说明。
|
||||||
@@ -339,8 +388,15 @@ https://openai.justsong.cn
|
|||||||
5. ChatGPT Next Web 报错:`Failed to fetch`
|
5. ChatGPT Next Web 报错:`Failed to fetch`
|
||||||
+ 部署的时候不要设置 `BASE_URL`。
|
+ 部署的时候不要设置 `BASE_URL`。
|
||||||
+ 检查你的接口地址和 API Key 有没有填对。
|
+ 检查你的接口地址和 API Key 有没有填对。
|
||||||
|
+ 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
|
||||||
6. 报错:`当前分组负载已饱和,请稍后再试`
|
6. 报错:`当前分组负载已饱和,请稍后再试`
|
||||||
+ 上游通道 429 了。
|
+ 上游通道 429 了。
|
||||||
|
7. 升级之后我的数据会丢失吗?
|
||||||
|
+ 如果使用 MySQL,不会。
|
||||||
|
+ 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
|
||||||
|
8. 升级之前数据库需要做变更吗?
|
||||||
|
+ 一般情况下不需要,系统将在初始化的时候自动调整。
|
||||||
|
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
|
||||||
|
|
||||||
## 相关项目
|
## 相关项目
|
||||||
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
|
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
|
||||||
@@ -352,4 +408,4 @@ https://openai.justsong.cn
|
|||||||
|
|
||||||
同样适用于基于本项目的二开项目。
|
同样适用于基于本项目的二开项目。
|
||||||
|
|
||||||
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。
|
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ var EmailDomainWhitelist = []string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
var DebugEnabled = os.Getenv("DEBUG") == "true"
|
var DebugEnabled = os.Getenv("DEBUG") == "true"
|
||||||
|
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
|
||||||
|
|
||||||
var LogConsumeEnabled = true
|
var LogConsumeEnabled = true
|
||||||
|
|
||||||
@@ -92,7 +93,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
|||||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||||
|
|
||||||
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
|
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
|
||||||
|
|
||||||
|
var BatchUpdateEnabled = false
|
||||||
|
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RequestIdKey = "X-Oneapi-Request-Id"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RoleGuestUser = 0
|
RoleGuestUser = 0
|
||||||
@@ -111,10 +119,10 @@ var (
|
|||||||
// All duration's unit is seconds
|
// All duration's unit is seconds
|
||||||
// Shouldn't larger then RateLimitKeyExpirationDuration
|
// Shouldn't larger then RateLimitKeyExpirationDuration
|
||||||
var (
|
var (
|
||||||
GlobalApiRateLimitNum = 180
|
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||||
GlobalApiRateLimitDuration int64 = 3 * 60
|
GlobalApiRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
GlobalWebRateLimitNum = 60
|
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||||
GlobalWebRateLimitDuration int64 = 3 * 60
|
GlobalWebRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
UploadRateLimitNum = 10
|
UploadRateLimitNum = 10
|
||||||
@@ -148,55 +156,62 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ChannelStatusUnknown = 0
|
ChannelStatusUnknown = 0
|
||||||
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
|
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||||
ChannelStatusDisabled = 2 // also don't use 0
|
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||||
|
ChannelStatusAutoDisabled = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ChannelTypeUnknown = 0
|
ChannelTypeUnknown = 0
|
||||||
ChannelTypeOpenAI = 1
|
ChannelTypeOpenAI = 1
|
||||||
ChannelTypeAPI2D = 2
|
ChannelTypeAPI2D = 2
|
||||||
ChannelTypeAzure = 3
|
ChannelTypeAzure = 3
|
||||||
ChannelTypeCloseAI = 4
|
ChannelTypeCloseAI = 4
|
||||||
ChannelTypeOpenAISB = 5
|
ChannelTypeOpenAISB = 5
|
||||||
ChannelTypeOpenAIMax = 6
|
ChannelTypeOpenAIMax = 6
|
||||||
ChannelTypeOhMyGPT = 7
|
ChannelTypeOhMyGPT = 7
|
||||||
ChannelTypeCustom = 8
|
ChannelTypeCustom = 8
|
||||||
ChannelTypeAILS = 9
|
ChannelTypeAILS = 9
|
||||||
ChannelTypeAIProxy = 10
|
ChannelTypeAIProxy = 10
|
||||||
ChannelTypePaLM = 11
|
ChannelTypePaLM = 11
|
||||||
ChannelTypeAPI2GPT = 12
|
ChannelTypeAPI2GPT = 12
|
||||||
ChannelTypeAIGC2D = 13
|
ChannelTypeAIGC2D = 13
|
||||||
ChannelTypeAnthropic = 14
|
ChannelTypeAnthropic = 14
|
||||||
ChannelTypeBaidu = 15
|
ChannelTypeBaidu = 15
|
||||||
ChannelTypeZhipu = 16
|
ChannelTypeZhipu = 16
|
||||||
ChannelTypeAli = 17
|
ChannelTypeAli = 17
|
||||||
ChannelTypeXunfei = 18
|
ChannelTypeXunfei = 18
|
||||||
ChannelType360 = 19
|
ChannelType360 = 19
|
||||||
ChannelTypeOpenRouter = 20
|
ChannelTypeOpenRouter = 20
|
||||||
|
ChannelTypeAIProxyLibrary = 21
|
||||||
|
ChannelTypeFastGPT = 22
|
||||||
|
ChannelTypeTencent = 23
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
"", // 0
|
"", // 0
|
||||||
"https://api.openai.com", // 1
|
"https://api.openai.com", // 1
|
||||||
"https://oa.api2d.net", // 2
|
"https://oa.api2d.net", // 2
|
||||||
"", // 3
|
"", // 3
|
||||||
"https://api.closeai-proxy.xyz", // 4
|
"https://api.closeai-proxy.xyz", // 4
|
||||||
"https://api.openai-sb.com", // 5
|
"https://api.openai-sb.com", // 5
|
||||||
"https://api.openaimax.com", // 6
|
"https://api.openaimax.com", // 6
|
||||||
"https://api.ohmygpt.com", // 7
|
"https://api.ohmygpt.com", // 7
|
||||||
"", // 8
|
"", // 8
|
||||||
"https://api.caipacity.com", // 9
|
"https://api.caipacity.com", // 9
|
||||||
"https://api.aiproxy.io", // 10
|
"https://api.aiproxy.io", // 10
|
||||||
"", // 11
|
"", // 11
|
||||||
"https://api.api2gpt.com", // 12
|
"https://api.api2gpt.com", // 12
|
||||||
"https://api.aigc2d.com", // 13
|
"https://api.aigc2d.com", // 13
|
||||||
"https://api.anthropic.com", // 14
|
"https://api.anthropic.com", // 14
|
||||||
"https://aip.baidubce.com", // 15
|
"https://aip.baidubce.com", // 15
|
||||||
"https://open.bigmodel.cn", // 16
|
"https://open.bigmodel.cn", // 16
|
||||||
"https://dashscope.aliyuncs.com", // 17
|
"https://dashscope.aliyuncs.com", // 17
|
||||||
"", // 18
|
"", // 18
|
||||||
"https://ai.360.cn", // 19
|
"https://ai.360.cn", // 19
|
||||||
"https://openrouter.ai/api", // 20
|
"https://openrouter.ai/api", // 20
|
||||||
|
"https://api.aiproxy.io", // 21
|
||||||
|
"https://fastgpt.run/api/openapi", // 22
|
||||||
|
"https://hunyuan.cloud.tencent.com", //23
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ var (
|
|||||||
Port = flag.Int("port", 3000, "the listening port")
|
Port = flag.Int("port", 3000, "the listening port")
|
||||||
PrintVersion = flag.Bool("version", false, "print version and exit")
|
PrintVersion = flag.Bool("version", false, "print version and exit")
|
||||||
PrintHelp = flag.Bool("help", false, "print help and exit")
|
PrintHelp = flag.Bool("help", false, "print help and exit")
|
||||||
LogDir = flag.String("log-dir", "", "specify the log directory")
|
LogDir = flag.String("log-dir", "./logs", "specify the log directory")
|
||||||
)
|
)
|
||||||
|
|
||||||
func printHelp() {
|
func printHelp() {
|
||||||
|
|||||||
@@ -1,29 +1,47 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupGinLog() {
|
const (
|
||||||
|
loggerINFO = "INFO"
|
||||||
|
loggerWarn = "WARN"
|
||||||
|
loggerError = "ERR"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxLogCount = 1000000
|
||||||
|
|
||||||
|
var logCount int
|
||||||
|
var setupLogLock sync.Mutex
|
||||||
|
var setupLogWorking bool
|
||||||
|
|
||||||
|
func SetupLogger() {
|
||||||
if *LogDir != "" {
|
if *LogDir != "" {
|
||||||
commonLogPath := filepath.Join(*LogDir, "common.log")
|
ok := setupLogLock.TryLock()
|
||||||
errorLogPath := filepath.Join(*LogDir, "error.log")
|
if !ok {
|
||||||
commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
log.Println("setup log is already working")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
setupLogLock.Unlock()
|
||||||
|
setupLogWorking = false
|
||||||
|
}()
|
||||||
|
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
|
||||||
|
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("failed to open log file")
|
log.Fatal("failed to open log file")
|
||||||
}
|
}
|
||||||
errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
|
||||||
if err != nil {
|
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
|
||||||
log.Fatal("failed to open log file")
|
|
||||||
}
|
|
||||||
gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
|
|
||||||
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,6 +55,36 @@ func SysError(s string) {
|
|||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LogInfo(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerINFO, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogWarn(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerWarn, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogError(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerError, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func logHelper(ctx context.Context, level string, msg string) {
|
||||||
|
writer := gin.DefaultErrorWriter
|
||||||
|
if level == loggerINFO {
|
||||||
|
writer = gin.DefaultWriter
|
||||||
|
}
|
||||||
|
id := ctx.Value(RequestIdKey)
|
||||||
|
now := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||||
|
logCount++ // we don't need accurate count, so no lock here
|
||||||
|
if logCount > maxLogCount && !setupLogWorking {
|
||||||
|
logCount = 0
|
||||||
|
setupLogWorking = true
|
||||||
|
go func() {
|
||||||
|
SetupLogger()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func FatalLog(v ...any) {
|
func FatalLog(v ...any) {
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ var ModelRatio = map[string]float64{
|
|||||||
"gpt-3.5-turbo-0613": 0.75,
|
"gpt-3.5-turbo-0613": 0.75,
|
||||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||||
"gpt-3.5-turbo-16k-0613": 1.5,
|
"gpt-3.5-turbo-16k-0613": 1.5,
|
||||||
|
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
||||||
"text-ada-001": 0.2,
|
"text-ada-001": 0.2,
|
||||||
"text-babbage-001": 0.25,
|
"text-babbage-001": 0.25,
|
||||||
"text-curie-001": 1,
|
"text-curie-001": 1,
|
||||||
@@ -50,14 +51,15 @@ var ModelRatio = map[string]float64{
|
|||||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||||
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
||||||
"qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
|
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
|
||||||
"qwen-plus-v1": 0.5715, // Same as above
|
"qwen-plus": 10, // ¥0.14 / 1k tokens
|
||||||
"SparkDesk": 0.8572, // TBD
|
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
||||||
|
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
|
||||||
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
||||||
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
"360GPT_S2_V9.4": 0.8572, // ¥0.012 / 1k tokens
|
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelRatio2JSONString() string {
|
func ModelRatio2JSONString() string {
|
||||||
|
|||||||
@@ -171,6 +171,11 @@ func GetTimestamp() int64 {
|
|||||||
return time.Now().Unix()
|
return time.Now().Unix()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetTimeString() string {
|
||||||
|
now := time.Now()
|
||||||
|
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
|
||||||
|
}
|
||||||
|
|
||||||
func Max(a int, b int) int {
|
func Max(a int, b int) int {
|
||||||
if a >= b {
|
if a >= b {
|
||||||
return a
|
return a
|
||||||
@@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
|
|||||||
}
|
}
|
||||||
return num
|
return num
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MessageWithRequestId(message string, id string) string {
|
||||||
|
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||||
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
openAIError := OpenAIError{
|
openAIError := OpenAIError{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
Type: "one_api_error",
|
Type: "upstream_error",
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"error": openAIError,
|
"error": openAIError,
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
|||||||
}
|
}
|
||||||
|
|
||||||
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
|
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
|
||||||
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
|
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
|
|||||||
|
|
||||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||||
if channel.BaseURL == "" {
|
if channel.GetBaseURL() == "" {
|
||||||
channel.BaseURL = baseURL
|
channel.BaseURL = &baseURL
|
||||||
}
|
}
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeOpenAI:
|
case common.ChannelTypeOpenAI:
|
||||||
if channel.BaseURL != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.BaseURL
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
case common.ChannelTypeCustom:
|
case common.ChannelTypeCustom:
|
||||||
baseURL = channel.BaseURL
|
baseURL = channel.GetBaseURL()
|
||||||
case common.ChannelTypeCloseAI:
|
case common.ChannelTypeCloseAI:
|
||||||
return updateChannelCloseAIBalance(channel)
|
return updateChannelCloseAIBalance(channel)
|
||||||
case common.ChannelTypeOpenAISB:
|
case common.ChannelTypeOpenAISB:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
|
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypePaLM:
|
case common.ChannelTypePaLM:
|
||||||
fallthrough
|
fallthrough
|
||||||
@@ -32,15 +32,20 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
|
|||||||
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
request.Model = "gpt-35-turbo"
|
request.Model = "gpt-35-turbo"
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
|
||||||
|
}
|
||||||
|
}()
|
||||||
default:
|
default:
|
||||||
request.Model = "gpt-3.5-turbo"
|
request.Model = "gpt-3.5-turbo"
|
||||||
}
|
}
|
||||||
requestURL := common.ChannelBaseURLs[channel.Type]
|
requestURL := common.ChannelBaseURLs[channel.Type]
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
if channel.Type == common.ChannelTypeAzure {
|
||||||
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
|
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
|
||||||
} else {
|
} else {
|
||||||
if channel.BaseURL != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
requestURL = channel.BaseURL
|
requestURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
requestURL += "/v1/chat/completions"
|
requestURL += "/v1/chat/completions"
|
||||||
}
|
}
|
||||||
@@ -136,7 +141,7 @@ func disableChannel(channelId int, channelName string, reason string) {
|
|||||||
if common.RootUserEmail == "" {
|
if common.RootUserEmail == "" {
|
||||||
common.RootUserEmail = model.GetRootUserEmail()
|
common.RootUserEmail = model.GetRootUserEmail()
|
||||||
}
|
}
|
||||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
|
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||||
err := common.SendEmail(subject, common.RootUserEmail, content)
|
err := common.SendEmail(subject, common.RootUserEmail, content)
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
channel.CreatedTime = common.GetTimestamp()
|
channel.CreatedTime = common.GetTimestamp()
|
||||||
keys := strings.Split(channel.Key, "\n")
|
keys := strings.Split(channel.Key, "\n")
|
||||||
channels := make([]model.Channel, 0)
|
channels := make([]model.Channel, 0, len(keys))
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
if key == "" {
|
if key == "" {
|
||||||
continue
|
continue
|
||||||
@@ -127,6 +127,23 @@ func DeleteChannel(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteDisabledChannel(c *gin.Context) {
|
||||||
|
rows, err := model.DeleteDisabledChannel()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": rows,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func UpdateChannel(c *gin.Context) {
|
func UpdateChannel(c *gin.Context) {
|
||||||
channel := model.Channel{}
|
channel := model.Channel{}
|
||||||
err := c.ShouldBindJSON(&channel)
|
err := c.ShouldBindJSON(&channel)
|
||||||
|
|||||||
@@ -79,6 +79,14 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
|||||||
|
|
||||||
func GitHubOAuth(c *gin.Context) {
|
func GitHubOAuth(c *gin.Context) {
|
||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
|
state := c.Query("state")
|
||||||
|
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "state is empty or not same",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
username := session.Get("username")
|
username := session.Get("username")
|
||||||
if username != nil {
|
if username != nil {
|
||||||
GitHubBind(c)
|
GitHubBind(c)
|
||||||
@@ -205,3 +213,22 @@ func GitHubBind(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenerateOAuthCode(c *gin.Context) {
|
||||||
|
session := sessions.Default(c)
|
||||||
|
state := common.GetRandomString(12)
|
||||||
|
session.Set("oauth_state", state)
|
||||||
|
err := session.Save()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": state,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -18,19 +19,21 @@ func GetAllLogs(c *gin.Context) {
|
|||||||
username := c.Query("username")
|
username := c.Query("username")
|
||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||||
|
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserLogs(c *gin.Context) {
|
func GetUserLogs(c *gin.Context) {
|
||||||
@@ -46,34 +49,36 @@ func GetUserLogs(c *gin.Context) {
|
|||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchAllLogs(c *gin.Context) {
|
func SearchAllLogs(c *gin.Context) {
|
||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
logs, err := model.SearchAllLogs(keyword)
|
logs, err := model.SearchAllLogs(keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUserLogs(c *gin.Context) {
|
func SearchUserLogs(c *gin.Context) {
|
||||||
@@ -81,17 +86,18 @@ func SearchUserLogs(c *gin.Context) {
|
|||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
logs, err := model.SearchUserLogs(userId, keyword)
|
logs, err := model.SearchUserLogs(userId, keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetLogsStat(c *gin.Context) {
|
func GetLogsStat(c *gin.Context) {
|
||||||
@@ -101,9 +107,10 @@ func GetLogsStat(c *gin.Context) {
|
|||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
username := c.Query("username")
|
username := c.Query("username")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||||
|
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": gin.H{
|
"data": gin.H{
|
||||||
@@ -111,6 +118,7 @@ func GetLogsStat(c *gin.Context) {
|
|||||||
//"token": tokenNum,
|
//"token": tokenNum,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetLogsSelfStat(c *gin.Context) {
|
func GetLogsSelfStat(c *gin.Context) {
|
||||||
@@ -120,9 +128,10 @@ func GetLogsSelfStat(c *gin.Context) {
|
|||||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||||
|
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": gin.H{
|
"data": gin.H{
|
||||||
@@ -130,4 +139,30 @@ func GetLogsSelfStat(c *gin.Context) {
|
|||||||
//"token": tokenNum,
|
//"token": tokenNum,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteHistoryLogs(c *gin.Context) {
|
||||||
|
targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
|
||||||
|
if targetTimestamp == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "target timestamp is required",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
count, err := model.DeleteOldLog(targetTimestamp)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": count,
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -117,6 +117,15 @@ func init() {
|
|||||||
Root: "gpt-3.5-turbo-16k-0613",
|
Root: "gpt-3.5-turbo-16k-0613",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "gpt-3.5-turbo-instruct",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "gpt-3.5-turbo-instruct",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Id: "gpt-4",
|
Id: "gpt-4",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -343,21 +352,30 @@ func init() {
|
|||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: "qwen-v1",
|
Id: "qwen-turbo",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1677649963,
|
Created: 1677649963,
|
||||||
OwnedBy: "ali",
|
OwnedBy: "ali",
|
||||||
Permission: permission,
|
Permission: permission,
|
||||||
Root: "qwen-v1",
|
Root: "qwen-turbo",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: "qwen-plus-v1",
|
Id: "qwen-plus",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1677649963,
|
Created: 1677649963,
|
||||||
OwnedBy: "ali",
|
OwnedBy: "ali",
|
||||||
Permission: permission,
|
Permission: permission,
|
||||||
Root: "qwen-plus-v1",
|
Root: "qwen-plus",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "text-embedding-v1",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "ali",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "text-embedding-v1",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -406,12 +424,12 @@ func init() {
|
|||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: "360GPT_S2_V9.4",
|
Id: "hunyuan",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1677649963,
|
Created: 1677649963,
|
||||||
OwnedBy: "360",
|
OwnedBy: "tencent",
|
||||||
Permission: permission,
|
Permission: permission,
|
||||||
Root: "360GPT_S2_V9.4",
|
Root: "hunyuan",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
if option.Value == "true" && common.GitHubClientId == "" {
|
if option.Value == "true" && common.GitHubClientId == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!",
|
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
220
controller/relay-aiproxy.go
Normal file
220
controller/relay-aiproxy.go
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
|
||||||
|
|
||||||
|
type AIProxyLibraryRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Query string `json:"query"`
|
||||||
|
LibraryId string `json:"libraryId"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryError struct {
|
||||||
|
ErrCode int `json:"errCode"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryDocument struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Answer string `json:"answer"`
|
||||||
|
Documents []AIProxyLibraryDocument `json:"documents"`
|
||||||
|
AIProxyLibraryError
|
||||||
|
}
|
||||||
|
|
||||||
|
type AIProxyLibraryStreamResponse struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Finish bool `json:"finish"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Documents []AIProxyLibraryDocument `json:"documents"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
||||||
|
query := ""
|
||||||
|
if len(request.Messages) != 0 {
|
||||||
|
query = request.Messages[len(request.Messages)-1].Content
|
||||||
|
}
|
||||||
|
return &AIProxyLibraryRequest{
|
||||||
|
Model: request.Model,
|
||||||
|
Stream: request.Stream,
|
||||||
|
Query: query,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
|
||||||
|
if len(documents) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
content := "\n\n参考文档:\n"
|
||||||
|
for i, document := range documents {
|
||||||
|
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
|
||||||
|
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
|
||||||
|
choice := OpenAITextResponseChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: content,
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
}
|
||||||
|
fullTextResponse := OpenAITextResponse{
|
||||||
|
Id: common.GetUUID(),
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Choices: []OpenAITextResponseChoice{choice},
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
|
||||||
|
choice.FinishReason = &stopFinishReason
|
||||||
|
return &ChatCompletionsStreamResponse{
|
||||||
|
Id: common.GetUUID(),
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: "",
|
||||||
|
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = response.Content
|
||||||
|
return &ChatCompletionsStreamResponse{
|
||||||
|
Id: common.GetUUID(),
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: response.Model,
|
||||||
|
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var usage Usage
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 5 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if data[:5] != "data:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = data[5:]
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
setEventStreamHeaders(c)
|
||||||
|
var documents []AIProxyLibraryDocument
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(AIProxyLibraryResponse.Documents) != 0 {
|
||||||
|
documents = AIProxyLibraryResponse.Documents
|
||||||
|
}
|
||||||
|
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
response := documentsAIProxyLibrary(documents)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var AIProxyLibraryResponse AIProxyLibraryResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if AIProxyLibraryResponse.ErrCode != 0 {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: AIProxyLibraryResponse.Message,
|
||||||
|
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
|
||||||
|
Code: AIProxyLibraryResponse.ErrCode,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
@@ -35,6 +35,29 @@ type AliChatRequest struct {
|
|||||||
Parameters AliParameters `json:"parameters,omitempty"`
|
Parameters AliParameters `json:"parameters,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AliEmbeddingRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input struct {
|
||||||
|
Texts []string `json:"texts"`
|
||||||
|
} `json:"input"`
|
||||||
|
Parameters *struct {
|
||||||
|
TextType string `json:"text_type,omitempty"`
|
||||||
|
} `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliEmbedding struct {
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
TextIndex int `json:"text_index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AliEmbeddingResponse struct {
|
||||||
|
Output struct {
|
||||||
|
Embeddings []AliEmbedding `json:"embeddings"`
|
||||||
|
} `json:"output"`
|
||||||
|
Usage AliUsage `json:"usage"`
|
||||||
|
AliError
|
||||||
|
}
|
||||||
|
|
||||||
type AliError struct {
|
type AliError struct {
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
@@ -44,6 +67,7 @@ type AliError struct {
|
|||||||
type AliUsage struct {
|
type AliUsage struct {
|
||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliOutput struct {
|
type AliOutput struct {
|
||||||
@@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
|
||||||
|
return &AliEmbeddingRequest{
|
||||||
|
Model: "text-embedding-v1",
|
||||||
|
Input: struct {
|
||||||
|
Texts []string `json:"texts"`
|
||||||
|
}{
|
||||||
|
Texts: request.ParseInput(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var aliResponse AliEmbeddingResponse
|
||||||
|
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if aliResponse.Code != "" {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: aliResponse.Message,
|
||||||
|
Type: aliResponse.Code,
|
||||||
|
Param: aliResponse.RequestId,
|
||||||
|
Code: aliResponse.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||||
|
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
||||||
|
Object: "list",
|
||||||
|
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
||||||
|
Model: "text-embedding-v1",
|
||||||
|
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range response.Output.Embeddings {
|
||||||
|
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
||||||
|
Object: `embedding`,
|
||||||
|
Index: item.TextIndex,
|
||||||
|
Embedding: item.Embedding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &openAIEmbeddingResponse
|
||||||
|
}
|
||||||
|
|
||||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
|
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
|
||||||
choice := OpenAITextResponseChoice{
|
choice := OpenAITextResponseChoice{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -17,6 +19,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
|
|
||||||
@@ -29,6 +32,9 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
if userQuota-preConsumedQuota < 0 {
|
||||||
|
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
|
}
|
||||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||||
@@ -91,7 +97,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
}
|
}
|
||||||
var audioResponse AudioResponse
|
var audioResponse AudioResponse
|
||||||
|
|
||||||
defer func() {
|
defer func(ctx context.Context) {
|
||||||
go func() {
|
go func() {
|
||||||
quota := countTokenText(audioResponse.Text, audioModel)
|
quota := countTokenText(audioResponse.Text, audioModel)
|
||||||
quotaDelta := quota - preConsumedQuota
|
quotaDelta := quota - preConsumedQuota
|
||||||
@@ -106,13 +112,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}()
|
}(c.Request.Context())
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
|||||||
@@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
|
|||||||
}
|
}
|
||||||
|
|
||||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
||||||
baiduEmbeddingRequest := BaiduEmbeddingRequest{
|
return &BaiduEmbeddingRequest{
|
||||||
Input: nil,
|
Input: request.ParseInput(),
|
||||||
}
|
}
|
||||||
switch request.Input.(type) {
|
|
||||||
case string:
|
|
||||||
baiduEmbeddingRequest.Input = []string{request.Input.(string)}
|
|
||||||
case []any:
|
|
||||||
for _, item := range request.Input.([]any) {
|
|
||||||
if str, ok := item.(string); ok {
|
|
||||||
baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &baiduEmbeddingRequest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
|
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -18,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
consumeQuota := c.GetBool("consume_quota")
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
@@ -97,7 +99,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
quota := int(ratio*sizeRatio*1000) * imageRequest.N
|
quota := int(ratio*sizeRatio*1000) * imageRequest.N
|
||||||
|
|
||||||
if consumeQuota && userQuota-quota < 0 {
|
if consumeQuota && userQuota-quota < 0 {
|
||||||
return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden)
|
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
@@ -124,7 +126,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
}
|
}
|
||||||
var textResponse ImageResponse
|
var textResponse ImageResponse
|
||||||
|
|
||||||
defer func() {
|
defer func(ctx context.Context) {
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -137,13 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}(c.Request.Context())
|
||||||
|
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
|||||||
287
controller/relay-tencent.go
Normal file
287
controller/relay-tencent.go
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://cloud.tencent.com/document/product/1729/97732
|
||||||
|
|
||||||
|
type TencentMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentChatRequest struct {
|
||||||
|
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
||||||
|
SecretId string `json:"secret_id"` // 官网 SecretId
|
||||||
|
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
||||||
|
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
||||||
|
Timestamp int64 `json:"timestamp"`
|
||||||
|
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
||||||
|
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
||||||
|
Expired int64 `json:"expired"`
|
||||||
|
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
||||||
|
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
||||||
|
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
||||||
|
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
||||||
|
Temperature float64 `json:"temperature"`
|
||||||
|
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
||||||
|
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
||||||
|
// 建议该参数和 temperature 只设置1个,不要同时更改
|
||||||
|
TopP float64 `json:"top_p"`
|
||||||
|
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
||||||
|
// 同步请求超时:60s,如果内容较长建议使用流式
|
||||||
|
Stream int `json:"stream"`
|
||||||
|
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
||||||
|
// 输入 content 总数最大支持 3000 token。
|
||||||
|
Messages []TencentMessage `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentUsage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentResponseChoices struct {
|
||||||
|
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||||
|
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||||
|
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentChatResponse struct {
|
||||||
|
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
||||||
|
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||||
|
Id string `json:"id,omitempty"` // 会话 id
|
||||||
|
Usage Usage `json:"usage,omitempty"` // token 数量
|
||||||
|
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||||
|
Note string `json:"note,omitempty"` // 注释
|
||||||
|
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
||||||
|
messages := make([]TencentMessage, 0, len(request.Messages))
|
||||||
|
for i := 0; i < len(request.Messages); i++ {
|
||||||
|
message := request.Messages[i]
|
||||||
|
if message.Role == "system" {
|
||||||
|
messages = append(messages, TencentMessage{
|
||||||
|
Role: "user",
|
||||||
|
Content: message.Content,
|
||||||
|
})
|
||||||
|
messages = append(messages, TencentMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "Okay",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
messages = append(messages, TencentMessage{
|
||||||
|
Content: message.Content,
|
||||||
|
Role: message.Role,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
stream := 0
|
||||||
|
if request.Stream {
|
||||||
|
stream = 1
|
||||||
|
}
|
||||||
|
return &TencentChatRequest{
|
||||||
|
Timestamp: common.GetTimestamp(),
|
||||||
|
Expired: common.GetTimestamp() + 24*60*60,
|
||||||
|
QueryID: common.GetUUID(),
|
||||||
|
Temperature: request.Temperature,
|
||||||
|
TopP: request.TopP,
|
||||||
|
Stream: stream,
|
||||||
|
Messages: messages,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
|
||||||
|
fullTextResponse := OpenAITextResponse{
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Usage: response.Usage,
|
||||||
|
}
|
||||||
|
if len(response.Choices) > 0 {
|
||||||
|
choice := OpenAITextResponseChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: response.Choices[0].Messages.Content,
|
||||||
|
},
|
||||||
|
FinishReason: response.Choices[0].FinishReason,
|
||||||
|
}
|
||||||
|
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
|
||||||
|
response := ChatCompletionsStreamResponse{
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: "tencent-hunyuan",
|
||||||
|
}
|
||||||
|
if len(TencentResponse.Choices) > 0 {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
||||||
|
if TencentResponse.Choices[0].FinishReason == "stop" {
|
||||||
|
choice.FinishReason = &stopFinishReason
|
||||||
|
}
|
||||||
|
response.Choices = append(response.Choices, choice)
|
||||||
|
}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||||
|
var responseText string
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 5 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if data[:5] != "data:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = data[5:]
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
setEventStreamHeaders(c)
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
var TencentResponse TencentChatResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &TencentResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
response := streamResponseTencent2OpenAI(&TencentResponse)
|
||||||
|
if len(response.Choices) != 0 {
|
||||||
|
responseText += response.Choices[0].Delta.Content
|
||||||
|
}
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||||
|
}
|
||||||
|
return nil, responseText
|
||||||
|
}
|
||||||
|
|
||||||
|
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var TencentResponse TencentChatResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &TencentResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if TencentResponse.Error.Code != 0 {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: TencentResponse.Error.Message,
|
||||||
|
Code: TencentResponse.Error.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
||||||
|
parts := strings.Split(config, "|")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
err = errors.New("invalid tencent config")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
appId, err = strconv.ParseInt(parts[0], 10, 64)
|
||||||
|
secretId = parts[1]
|
||||||
|
secretKey = parts[2]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTencentSign(req TencentChatRequest, secretKey string) string {
|
||||||
|
params := make([]string, 0)
|
||||||
|
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
||||||
|
params = append(params, "secret_id="+req.SecretId)
|
||||||
|
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
||||||
|
params = append(params, "query_id="+req.QueryID)
|
||||||
|
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
||||||
|
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
||||||
|
params = append(params, "stream="+strconv.Itoa(req.Stream))
|
||||||
|
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
||||||
|
|
||||||
|
var messageStr string
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
||||||
|
}
|
||||||
|
messageStr = strings.TrimSuffix(messageStr, ",")
|
||||||
|
params = append(params, "messages=["+messageStr+"]")
|
||||||
|
|
||||||
|
sort.Sort(sort.StringSlice(params))
|
||||||
|
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
||||||
|
mac := hmac.New(sha1.New, []byte(secretKey))
|
||||||
|
signURL := url
|
||||||
|
mac.Write([]byte(signURL))
|
||||||
|
sign := mac.Sum([]byte(nil))
|
||||||
|
return base64.StdEncoding.EncodeToString(sign)
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -22,6 +23,8 @@ const (
|
|||||||
APITypeZhipu
|
APITypeZhipu
|
||||||
APITypeAli
|
APITypeAli
|
||||||
APITypeXunfei
|
APITypeXunfei
|
||||||
|
APITypeAIProxyLibrary
|
||||||
|
APITypeTencent
|
||||||
)
|
)
|
||||||
|
|
||||||
var httpClient *http.Client
|
var httpClient *http.Client
|
||||||
@@ -36,6 +39,7 @@ func init() {
|
|||||||
|
|
||||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
consumeQuota := c.GetBool("consume_quota")
|
||||||
@@ -104,6 +108,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
apiType = APITypeAli
|
apiType = APITypeAli
|
||||||
case common.ChannelTypeXunfei:
|
case common.ChannelTypeXunfei:
|
||||||
apiType = APITypeXunfei
|
apiType = APITypeXunfei
|
||||||
|
case common.ChannelTypeAIProxyLibrary:
|
||||||
|
apiType = APITypeAIProxyLibrary
|
||||||
|
case common.ChannelTypeTencent:
|
||||||
|
apiType = APITypeTencent
|
||||||
}
|
}
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
requestURL := c.Request.URL.String()
|
requestURL := c.Request.URL.String()
|
||||||
@@ -111,6 +119,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
}
|
}
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
if channelType == common.ChannelTypeOpenAI {
|
||||||
|
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||||
|
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
||||||
|
}
|
||||||
|
}
|
||||||
switch apiType {
|
switch apiType {
|
||||||
case APITypeOpenAI:
|
case APITypeOpenAI:
|
||||||
if channelType == common.ChannelTypeAzure {
|
if channelType == common.ChannelTypeAzure {
|
||||||
@@ -171,6 +184,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
||||||
case APITypeAli:
|
case APITypeAli:
|
||||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||||
|
if relayMode == RelayModeEmbeddings {
|
||||||
|
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||||
|
}
|
||||||
|
case APITypeTencent:
|
||||||
|
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
|
||||||
|
case APITypeAIProxyLibrary:
|
||||||
|
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
|
||||||
}
|
}
|
||||||
var promptTokens int
|
var promptTokens int
|
||||||
var completionTokens int
|
var completionTokens int
|
||||||
@@ -194,6 +214,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
if userQuota-preConsumedQuota < 0 {
|
||||||
|
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
|
}
|
||||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||||
@@ -202,6 +225,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
// in this case, we do not pre-consume quota
|
// in this case, we do not pre-consume quota
|
||||||
// because the user has enough quota
|
// because the user has enough quota
|
||||||
preConsumedQuota = 0
|
preConsumedQuota = 0
|
||||||
|
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||||
}
|
}
|
||||||
if consumeQuota && preConsumedQuota > 0 {
|
if consumeQuota && preConsumedQuota > 0 {
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||||
@@ -257,8 +281,41 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
case APITypeAli:
|
case APITypeAli:
|
||||||
aliRequest := requestOpenAI2Ali(textRequest)
|
var jsonStr []byte
|
||||||
jsonStr, err := json.Marshal(aliRequest)
|
var err error
|
||||||
|
switch relayMode {
|
||||||
|
case RelayModeEmbeddings:
|
||||||
|
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
|
||||||
|
jsonStr, err = json.Marshal(aliEmbeddingRequest)
|
||||||
|
default:
|
||||||
|
aliRequest := requestOpenAI2Ali(textRequest)
|
||||||
|
jsonStr, err = json.Marshal(aliRequest)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
case APITypeTencent:
|
||||||
|
apiKey := c.Request.Header.Get("Authorization")
|
||||||
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||||
|
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
tencentRequest := requestOpenAI2Tencent(textRequest)
|
||||||
|
tencentRequest.AppId = appId
|
||||||
|
tencentRequest.SecretId = secretId
|
||||||
|
jsonStr, err := json.Marshal(tencentRequest)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
sign := getTencentSign(*tencentRequest, secretKey)
|
||||||
|
c.Request.Header.Set("Authorization", sign)
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
case APITypeAIProxyLibrary:
|
||||||
|
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
|
||||||
|
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
|
||||||
|
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -302,6 +359,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if textRequest.Stream {
|
if textRequest.Stream {
|
||||||
req.Header.Set("X-DashScope-SSE", "enable")
|
req.Header.Set("X-DashScope-SSE", "enable")
|
||||||
}
|
}
|
||||||
|
case APITypeTencent:
|
||||||
|
req.Header.Set("Authorization", apiKey)
|
||||||
|
default:
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
@@ -321,15 +382,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
if preConsumedQuota != 0 {
|
||||||
|
go func(ctx context.Context) {
|
||||||
|
// return pre-consumed quota
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||||
|
}
|
||||||
|
}(c.Request.Context())
|
||||||
|
}
|
||||||
return relayErrorHandler(resp)
|
return relayErrorHandler(resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var textResponse TextResponse
|
var textResponse TextResponse
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
|
|
||||||
defer func() {
|
defer func(ctx context.Context) {
|
||||||
// c.Writer.Flush()
|
// c.Writer.Flush()
|
||||||
go func() {
|
go func() {
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
@@ -352,22 +421,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
quotaDelta := quota - preConsumedQuota
|
quotaDelta := quota - preConsumedQuota
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||||
}
|
}
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
err = model.CacheUpdateUserQuota(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error update user quota cache: " + err.Error())
|
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||||
}
|
}
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
|
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}()
|
}(c.Request.Context())
|
||||||
switch apiType {
|
switch apiType {
|
||||||
case APITypeOpenAI:
|
case APITypeOpenAI:
|
||||||
if isStream {
|
if isStream {
|
||||||
@@ -488,7 +556,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
err, usage := aliHandler(c, resp)
|
var err *OpenAIErrorWithStatusCode
|
||||||
|
var usage *Usage
|
||||||
|
switch relayMode {
|
||||||
|
case RelayModeEmbeddings:
|
||||||
|
err, usage = aliEmbeddingHandler(c, resp)
|
||||||
|
default:
|
||||||
|
err, usage = aliHandler(c, resp)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -498,14 +573,29 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
case APITypeXunfei:
|
case APITypeXunfei:
|
||||||
|
auth := c.Request.Header.Get("Authorization")
|
||||||
|
auth = strings.TrimPrefix(auth, "Bearer ")
|
||||||
|
splits := strings.Split(auth, "|")
|
||||||
|
if len(splits) != 3 {
|
||||||
|
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
var err *OpenAIErrorWithStatusCode
|
||||||
|
var usage *Usage
|
||||||
if isStream {
|
if isStream {
|
||||||
auth := c.Request.Header.Get("Authorization")
|
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
} else {
|
||||||
splits := strings.Split(auth, "|")
|
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||||
if len(splits) != 3 {
|
}
|
||||||
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
if err != nil {
|
||||||
}
|
return err
|
||||||
err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case APITypeAIProxyLibrary:
|
||||||
|
if isStream {
|
||||||
|
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -514,7 +604,33 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
|
err, usage := aiProxyLibraryHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case APITypeTencent:
|
||||||
|
if isStream {
|
||||||
|
err, responseText := tencentStreamHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
textResponse.Usage.PromptTokens = promptTokens
|
||||||
|
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
err, usage := tencentHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
||||||
|
|||||||
@@ -9,44 +9,53 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var stopFinishReason = "stop"
|
var stopFinishReason = "stop"
|
||||||
|
|
||||||
|
// tokenEncoderMap won't grow after initialization
|
||||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||||
|
var defaultTokenEncoder *tiktoken.Tiktoken
|
||||||
|
|
||||||
func InitTokenEncoders() {
|
func InitTokenEncoders() {
|
||||||
common.SysLog("initializing token encoders")
|
common.SysLog("initializing token encoders")
|
||||||
fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
|
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
||||||
|
}
|
||||||
|
defaultTokenEncoder = gpt35TokenEncoder
|
||||||
|
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
|
||||||
|
if err != nil {
|
||||||
|
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||||
}
|
}
|
||||||
for model, _ := range common.ModelRatio {
|
for model, _ := range common.ModelRatio {
|
||||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
if strings.HasPrefix(model, "gpt-3.5") {
|
||||||
if err != nil {
|
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||||
common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
|
} else if strings.HasPrefix(model, "gpt-4") {
|
||||||
tokenEncoderMap[model] = fallbackTokenEncoder
|
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||||
continue
|
} else {
|
||||||
|
tokenEncoderMap[model] = nil
|
||||||
}
|
}
|
||||||
tokenEncoderMap[model] = tokenEncoder
|
|
||||||
}
|
}
|
||||||
common.SysLog("token encoders initialized")
|
common.SysLog("token encoders initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||||
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
|
tokenEncoder, ok := tokenEncoderMap[model]
|
||||||
|
if ok && tokenEncoder != nil {
|
||||||
return tokenEncoder
|
return tokenEncoder
|
||||||
}
|
}
|
||||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
if ok {
|
||||||
if err != nil {
|
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||||
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
|
||||||
tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
|
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||||
|
tokenEncoder = defaultTokenEncoder
|
||||||
}
|
}
|
||||||
|
tokenEncoderMap[model] = tokenEncoder
|
||||||
|
return tokenEncoder
|
||||||
}
|
}
|
||||||
tokenEncoderMap[model] = tokenEncoder
|
return defaultTokenEncoder
|
||||||
return tokenEncoder
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||||
@@ -146,7 +155,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
|
|||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
OpenAIError: OpenAIError{
|
OpenAIError: OpenAIError{
|
||||||
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
||||||
Type: "one_api_error",
|
Type: "upstream_error",
|
||||||
Code: "bad_response_status_code",
|
Code: "bad_response_status_code",
|
||||||
Param: strconv.Itoa(resp.StatusCode),
|
Param: strconv.Itoa(resp.StatusCode),
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
|
|||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: response.Payload.Choices.Text[0].Content,
|
Content: response.Payload.Choices.Text[0].Content,
|
||||||
},
|
},
|
||||||
|
FinishReason: stopFinishReason,
|
||||||
}
|
}
|
||||||
fullTextResponse := OpenAITextResponse{
|
fullTextResponse := OpenAITextResponse{
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
@@ -177,33 +178,85 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||||
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
setEventStreamHeaders(c)
|
||||||
var usage Usage
|
var usage Usage
|
||||||
query := c.Request.URL.Query()
|
c.Stream(func(w io.Writer) bool {
|
||||||
apiVersion := query.Get("api-version")
|
select {
|
||||||
if apiVersion == "" {
|
case xunfeiResponse := <-dataChan:
|
||||||
apiVersion = c.GetString("api_version")
|
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||||
|
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||||
|
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||||
|
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||||
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
if apiVersion == "" {
|
var usage Usage
|
||||||
apiVersion = "v1.1"
|
var content string
|
||||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
var xunfeiResponse XunfeiChatResponse
|
||||||
|
stop := false
|
||||||
|
for !stop {
|
||||||
|
select {
|
||||||
|
case xunfeiResponse = <-dataChan:
|
||||||
|
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||||
|
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||||
|
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||||
|
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||||
|
case stop = <-stopChan:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
domain := "general"
|
|
||||||
if apiVersion == "v2.1" {
|
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||||
domain = "generalv2"
|
|
||||||
|
response := responseXunfei2OpenAI(&xunfeiResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
||||||
d := websocket.Dialer{
|
d := websocket.Dialer{
|
||||||
HandshakeTimeout: 5 * time.Second,
|
HandshakeTimeout: 5 * time.Second,
|
||||||
}
|
}
|
||||||
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
|
conn, resp, err := d.Dial(authUrl, nil)
|
||||||
if err != nil || resp.StatusCode != 101 {
|
if err != nil || resp.StatusCode != 101 {
|
||||||
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
||||||
err = conn.WriteJSON(data)
|
err = conn.WriteJSON(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dataChan := make(chan XunfeiChatResponse)
|
dataChan := make(chan XunfeiChatResponse)
|
||||||
stopChan := make(chan bool)
|
stopChan := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -230,61 +283,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
|
|||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
return dataChan, stopChan, nil
|
||||||
select {
|
|
||||||
case xunfeiResponse := <-dataChan:
|
|
||||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
|
||||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
|
||||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
|
||||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return nil, &usage
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
|
||||||
var xunfeiResponse XunfeiChatResponse
|
query := c.Request.URL.Query()
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
apiVersion := query.Get("api-version")
|
||||||
if err != nil {
|
if apiVersion == "" {
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
apiVersion = c.GetString("api_version")
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
if apiVersion == "" {
|
||||||
if err != nil {
|
apiVersion = "v1.1"
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
common.SysLog("api_version not found, use default: " + apiVersion)
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(responseBody, &xunfeiResponse)
|
domain := "general"
|
||||||
if err != nil {
|
if apiVersion == "v2.1" {
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
domain = "generalv2"
|
||||||
}
|
}
|
||||||
if xunfeiResponse.Header.Code != 0 {
|
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
||||||
return &OpenAIErrorWithStatusCode{
|
return domain, authUrl
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: xunfeiResponse.Header.Message,
|
|
||||||
Type: "xunfei_error",
|
|
||||||
Param: "",
|
|
||||||
Code: xunfeiResponse.Header.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,6 +44,25 @@ type GeneralOpenAIRequest struct {
|
|||||||
Functions any `json:"functions,omitempty"`
|
Functions any `json:"functions,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||||
|
if r.Input == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var input []string
|
||||||
|
switch r.Input.(type) {
|
||||||
|
case string:
|
||||||
|
input = []string{r.Input.(string)}
|
||||||
|
case []any:
|
||||||
|
input = make([]string, 0, len(r.Input.([]any)))
|
||||||
|
for _, item := range r.Input.([]any) {
|
||||||
|
if str, ok := item.(string); ok {
|
||||||
|
input = append(input, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return input
|
||||||
|
}
|
||||||
|
|
||||||
type ChatRequest struct {
|
type ChatRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
@@ -177,6 +196,7 @@ func Relay(c *gin.Context) {
|
|||||||
err = relayTextHelper(c, relayMode)
|
err = relayTextHelper(c, relayMode)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
retryTimesStr := c.Query("retry")
|
retryTimesStr := c.Query("retry")
|
||||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
||||||
if retryTimesStr == "" {
|
if retryTimesStr == "" {
|
||||||
@@ -188,12 +208,13 @@ func Relay(c *gin.Context) {
|
|||||||
if err.StatusCode == http.StatusTooManyRequests {
|
if err.StatusCode == http.StatusTooManyRequests {
|
||||||
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||||
}
|
}
|
||||||
|
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||||
c.JSON(err.StatusCode, gin.H{
|
c.JSON(err.StatusCode, gin.H{
|
||||||
"error": err.OpenAIError,
|
"error": err.OpenAIError,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
||||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||||
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
|||||||
10
go.mod
10
go.mod
@@ -15,8 +15,9 @@ require (
|
|||||||
github.com/google/uuid v1.3.0
|
github.com/google/uuid v1.3.0
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
github.com/pkoukk/tiktoken-go v0.1.5
|
github.com/pkoukk/tiktoken-go v0.1.5
|
||||||
golang.org/x/crypto v0.9.0
|
golang.org/x/crypto v0.14.0
|
||||||
gorm.io/driver/mysql v1.4.3
|
gorm.io/driver/mysql v1.4.3
|
||||||
|
gorm.io/driver/postgres v1.5.2
|
||||||
gorm.io/driver/sqlite v1.4.3
|
gorm.io/driver/sqlite v1.4.3
|
||||||
gorm.io/gorm v1.25.0
|
gorm.io/gorm v1.25.0
|
||||||
)
|
)
|
||||||
@@ -52,10 +53,9 @@ require (
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||||
golang.org/x/arch v0.3.0 // indirect
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
golang.org/x/net v0.10.0 // indirect
|
golang.org/x/net v0.17.0 // indirect
|
||||||
golang.org/x/sys v0.8.0 // indirect
|
golang.org/x/sys v0.13.0 // indirect
|
||||||
golang.org/x/text v0.9.0 // indirect
|
golang.org/x/text v0.13.0 // indirect
|
||||||
google.golang.org/protobuf v1.30.0 // indirect
|
google.golang.org/protobuf v1.30.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
gorm.io/driver/postgres v1.5.2 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
17
go.sum
17
go.sum
@@ -150,11 +150,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
|||||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
|
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
|
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
@@ -162,14 +162,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
@@ -198,7 +198,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp
|
|||||||
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
|
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
|
||||||
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
|
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
|
||||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
||||||
gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74=
|
|
||||||
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
|
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
|
||||||
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
|
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
|
||||||
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||||
|
|||||||
@@ -523,5 +523,6 @@
|
|||||||
"按照如下格式输入:": "Enter in the following format:",
|
"按照如下格式输入:": "Enter in the following format:",
|
||||||
"模型版本": "Model version",
|
"模型版本": "Model version",
|
||||||
"请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
|
"请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
|
||||||
"点击查看": "click to view"
|
"点击查看": "click to view",
|
||||||
|
"请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
|
||||||
}
|
}
|
||||||
|
|||||||
34
main.go
34
main.go
@@ -2,6 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
|
"fmt"
|
||||||
"github.com/gin-contrib/sessions"
|
"github.com/gin-contrib/sessions"
|
||||||
"github.com/gin-contrib/sessions/cookie"
|
"github.com/gin-contrib/sessions/cookie"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -21,7 +22,7 @@ var buildFS embed.FS
|
|||||||
var indexPage []byte
|
var indexPage []byte
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
common.SetupGinLog()
|
common.SetupLogger()
|
||||||
common.SysLog("One API " + common.Version + " started")
|
common.SysLog("One API " + common.Version + " started")
|
||||||
if os.Getenv("GIN_MODE") != "debug" {
|
if os.Getenv("GIN_MODE") != "debug" {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
@@ -50,18 +51,17 @@ func main() {
|
|||||||
// Initialize options
|
// Initialize options
|
||||||
model.InitOptionMap()
|
model.InitOptionMap()
|
||||||
if common.RedisEnabled {
|
if common.RedisEnabled {
|
||||||
|
// for compatibility with old versions
|
||||||
|
common.MemoryCacheEnabled = true
|
||||||
|
}
|
||||||
|
if common.MemoryCacheEnabled {
|
||||||
|
common.SysLog("memory cache enabled")
|
||||||
|
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
||||||
model.InitChannelCache()
|
model.InitChannelCache()
|
||||||
}
|
}
|
||||||
if os.Getenv("SYNC_FREQUENCY") != "" {
|
if common.MemoryCacheEnabled {
|
||||||
frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
|
go model.SyncOptions(common.SyncFrequency)
|
||||||
if err != nil {
|
go model.SyncChannelCache(common.SyncFrequency)
|
||||||
common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
|
|
||||||
}
|
|
||||||
common.SyncFrequency = frequency
|
|
||||||
go model.SyncOptions(frequency)
|
|
||||||
if common.RedisEnabled {
|
|
||||||
go model.SyncChannelCache(frequency)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
|
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
|
||||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
|
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
|
||||||
@@ -77,14 +77,20 @@ func main() {
|
|||||||
}
|
}
|
||||||
go controller.AutomaticallyTestChannels(frequency)
|
go controller.AutomaticallyTestChannels(frequency)
|
||||||
}
|
}
|
||||||
|
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
||||||
|
common.BatchUpdateEnabled = true
|
||||||
|
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||||
|
model.InitBatchUpdater()
|
||||||
|
}
|
||||||
controller.InitTokenEncoders()
|
controller.InitTokenEncoders()
|
||||||
|
|
||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
server := gin.Default()
|
server := gin.New()
|
||||||
|
server.Use(gin.Recovery())
|
||||||
// This will cause SSE not to work!!!
|
// This will cause SSE not to work!!!
|
||||||
//server.Use(gzip.Gzip(gzip.DefaultCompression))
|
//server.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||||
server.Use(middleware.CORS())
|
server.Use(middleware.RequestId())
|
||||||
|
middleware.SetUpLogger(server)
|
||||||
// Initialize session store
|
// Initialize session store
|
||||||
store := cookie.NewStore([]byte(common.SessionSecret))
|
store := cookie.NewStore([]byte(common.SessionSecret))
|
||||||
server.Use(sessions.Sessions("session", store))
|
server.Use(sessions.Sessions("session", store))
|
||||||
|
|||||||
@@ -91,23 +91,16 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
key = parts[0]
|
key = parts[0]
|
||||||
token, err := model.ValidateUserToken(key)
|
token, err := model.ValidateUserToken(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{
|
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
"error": gin.H{
|
|
||||||
"message": err.Error(),
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !model.CacheIsUserEnabled(token.UserId) {
|
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
if err != nil {
|
||||||
"error": gin.H{
|
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
"message": "用户已被封禁",
|
return
|
||||||
"type": "one_api_error",
|
}
|
||||||
},
|
if !userEnabled {
|
||||||
})
|
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("id", token.UserId)
|
c.Set("id", token.UserId)
|
||||||
@@ -123,13 +116,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set("channelId", parts[1])
|
c.Set("channelId", parts[1])
|
||||||
} else {
|
} else {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
"error": gin.H{
|
|
||||||
"message": "普通用户不支持指定渠道",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) {
|
|||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||||
"error": gin.H{
|
|
||||||
"message": "无效的渠道 ID",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err = model.GetChannelById(id, true)
|
channel, err = model.GetChannelById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||||
"error": gin.H{
|
|
||||||
"message": "无效的渠道 ID",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||||
"error": gin.H{
|
|
||||||
"message": "该渠道已被禁用",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
||||||
"error": gin.H{
|
|
||||||
"message": "无效的请求",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||||
@@ -99,24 +75,23 @@ func Distribute() func(c *gin.Context) {
|
|||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
message = "数据库一致性已被破坏,请联系管理员"
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
||||||
"error": gin.H{
|
|
||||||
"message": message,
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Set("channel", channel.Type)
|
c.Set("channel", channel.Type)
|
||||||
c.Set("channel_id", channel.Id)
|
c.Set("channel_id", channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
c.Set("channel_name", channel.Name)
|
||||||
c.Set("model_mapping", channel.ModelMapping)
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
c.Set("base_url", channel.BaseURL)
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
|
switch channel.Type {
|
||||||
|
case common.ChannelTypeAzure:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeXunfei:
|
||||||
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeAIProxyLibrary:
|
||||||
|
c.Set("library_id", channel.Other)
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|||||||
25
middleware/logger.go
Normal file
25
middleware/logger.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetUpLogger(server *gin.Engine) {
|
||||||
|
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||||
|
var requestID string
|
||||||
|
if param.Keys != nil {
|
||||||
|
requestID = param.Keys[common.RequestIdKey].(string)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
|
||||||
|
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||||
|
requestID,
|
||||||
|
param.StatusCode,
|
||||||
|
param.Latency,
|
||||||
|
param.ClientIP,
|
||||||
|
param.Method,
|
||||||
|
param.Path,
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
}
|
||||||
18
middleware/request-id.go
Normal file
18
middleware/request-id.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RequestId() func(c *gin.Context) {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
id := common.GetTimeString() + common.GetRandomString(8)
|
||||||
|
c.Set(common.RequestIdKey, id)
|
||||||
|
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
c.Header(common.RequestIdKey, id)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
17
middleware/utils.go
Normal file
17
middleware/utils.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||||
|
c.JSON(statusCode, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||||
|
"type": "one_api_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
common.LogError(c.Request.Context(), message)
|
||||||
|
}
|
||||||
@@ -10,15 +10,18 @@ type Ability struct {
|
|||||||
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
|
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
|
||||||
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||||
ability := Ability{}
|
ability := Ability{}
|
||||||
var err error = nil
|
var err error = nil
|
||||||
|
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
|
||||||
|
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
|
||||||
if common.UsingSQLite {
|
if common.UsingSQLite {
|
||||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
|
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
||||||
} else {
|
} else {
|
||||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
|
err = channelQuery.Order("RAND()").First(&ability).Error
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -40,6 +43,7 @@ func (channel *Channel) AddAbilities() error {
|
|||||||
Model: model,
|
Model: model,
|
||||||
ChannelId: channel.Id,
|
ChannelId: channel.Id,
|
||||||
Enabled: channel.Status == common.ChannelStatusEnabled,
|
Enabled: channel.Status == common.ChannelStatusEnabled,
|
||||||
|
Priority: channel.Priority,
|
||||||
}
|
}
|
||||||
abilities = append(abilities, ability)
|
abilities = append(abilities, ability)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -103,23 +104,28 @@ func CacheDecreaseUserQuota(id int, quota int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheIsUserEnabled(userId int) bool {
|
func CacheIsUserEnabled(userId int) (bool, error) {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return IsUserEnabled(userId)
|
return IsUserEnabled(userId)
|
||||||
}
|
}
|
||||||
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
||||||
if err != nil {
|
if err == nil {
|
||||||
status := common.UserStatusDisabled
|
return enabled == "1", nil
|
||||||
if IsUserEnabled(userId) {
|
|
||||||
status = common.UserStatusEnabled
|
|
||||||
}
|
|
||||||
enabled = fmt.Sprintf("%d", status)
|
|
||||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("Redis set user enabled error: " + err.Error())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return enabled == "1"
|
|
||||||
|
userEnabled, err := IsUserEnabled(userId)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
enabled = "0"
|
||||||
|
if userEnabled {
|
||||||
|
enabled = "1"
|
||||||
|
}
|
||||||
|
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("Redis set user enabled error: " + err.Error())
|
||||||
|
}
|
||||||
|
return userEnabled, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var group2model2channels map[string]map[string][]*Channel
|
var group2model2channels map[string]map[string][]*Channel
|
||||||
@@ -154,6 +160,17 @@ func InitChannelCache() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sort by priority
|
||||||
|
for group, model2channels := range newGroup2model2channels {
|
||||||
|
for model, channels := range model2channels {
|
||||||
|
sort.Slice(channels, func(i, j int) bool {
|
||||||
|
return channels[i].GetPriority() > channels[j].GetPriority()
|
||||||
|
})
|
||||||
|
newGroup2model2channels[group][model] = channels
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
channelSyncLock.Lock()
|
channelSyncLock.Lock()
|
||||||
group2model2channels = newGroup2model2channels
|
group2model2channels = newGroup2model2channels
|
||||||
channelSyncLock.Unlock()
|
channelSyncLock.Unlock()
|
||||||
@@ -169,7 +186,7 @@ func SyncChannelCache(frequency int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||||
if !common.RedisEnabled {
|
if !common.MemoryCacheEnabled {
|
||||||
return GetRandomSatisfiedChannel(group, model)
|
return GetRandomSatisfiedChannel(group, model)
|
||||||
}
|
}
|
||||||
channelSyncLock.RLock()
|
channelSyncLock.RLock()
|
||||||
@@ -178,6 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
|||||||
if len(channels) == 0 {
|
if len(channels) == 0 {
|
||||||
return nil, errors.New("channel not found")
|
return nil, errors.New("channel not found")
|
||||||
}
|
}
|
||||||
idx := rand.Intn(len(channels))
|
endIdx := len(channels)
|
||||||
|
// choose by priority
|
||||||
|
firstChannel := channels[0]
|
||||||
|
if firstChannel.GetPriority() > 0 {
|
||||||
|
for i := range channels {
|
||||||
|
if channels[i].GetPriority() != firstChannel.GetPriority() {
|
||||||
|
endIdx = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
idx := rand.Intn(endIdx)
|
||||||
return channels[idx], nil
|
return channels[idx], nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,18 +11,19 @@ type Channel struct {
|
|||||||
Key string `json:"key" gorm:"not null;index"`
|
Key string `json:"key" gorm:"not null;index"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" gorm:"index"`
|
Name string `json:"name" gorm:"index"`
|
||||||
Weight int `json:"weight"`
|
Weight *uint `json:"weight" gorm:"default:0"`
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
TestTime int64 `json:"test_time" gorm:"bigint"`
|
TestTime int64 `json:"test_time" gorm:"bigint"`
|
||||||
ResponseTime int `json:"response_time"` // in milliseconds
|
ResponseTime int `json:"response_time"` // in milliseconds
|
||||||
BaseURL string `json:"base_url" gorm:"column:base_url"`
|
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
|
||||||
Other string `json:"other"`
|
Other string `json:"other"`
|
||||||
Balance float64 `json:"balance"` // in USD
|
Balance float64 `json:"balance"` // in USD
|
||||||
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
||||||
Models string `json:"models"`
|
Models string `json:"models"`
|
||||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||||
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||||
|
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||||
@@ -78,6 +79,27 @@ func BatchInsertChannels(channels []Channel) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetPriority() int64 {
|
||||||
|
if channel.Priority == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *channel.Priority
|
||||||
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetBaseURL() string {
|
||||||
|
if channel.BaseURL == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *channel.BaseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetModelMapping() string {
|
||||||
|
if channel.ModelMapping == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *channel.ModelMapping
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) Insert() error {
|
func (channel *Channel) Insert() error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Create(channel).Error
|
err = DB.Create(channel).Error
|
||||||
@@ -141,8 +163,26 @@ func UpdateChannelStatusById(id int, status int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannelUsedQuota(id int, quota int) {
|
func UpdateChannelUsedQuota(id int, quota int) {
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updateChannelUsedQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateChannelUsedQuota(id int, quota int) {
|
||||||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update channel used quota: " + err.Error())
|
common.SysError("failed to update channel used quota: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteChannelByStatus(status int64) (int64, error) {
|
||||||
|
result := DB.Where("status = ?", status).Delete(&Channel{})
|
||||||
|
return result.RowsAffected, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteDisabledChannel() (int64, error) {
|
||||||
|
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
|
||||||
|
return result.RowsAffected, result.Error
|
||||||
|
}
|
||||||
|
|||||||
40
model/log.go
40
model/log.go
@@ -1,22 +1,25 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Log struct {
|
type Log struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id;index:idx_created_at_id,priority:1"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id" gorm:"index"`
|
||||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index"`
|
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
|
||||||
Type int `json:"type" gorm:"index"`
|
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Username string `json:"username" gorm:"index;default:''"`
|
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
|
||||||
TokenName string `json:"token_name" gorm:"index;default:''"`
|
TokenName string `json:"token_name" gorm:"index;default:''"`
|
||||||
ModelName string `json:"model_name" gorm:"index;default:''"`
|
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
||||||
Quota int `json:"quota" gorm:"default:0"`
|
Quota int `json:"quota" gorm:"default:0"`
|
||||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||||
|
ChannelId int `json:"channel" gorm:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -44,7 +47,8 @@ func RecordLog(userId int, logType int, content string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||||
if !common.LogConsumeEnabled {
|
if !common.LogConsumeEnabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -59,14 +63,15 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
|
|||||||
TokenName: tokenName,
|
TokenName: tokenName,
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
Quota: quota,
|
Quota: quota,
|
||||||
|
ChannelId: channelId,
|
||||||
}
|
}
|
||||||
err := DB.Create(log).Error
|
err := DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to record log: " + err.Error())
|
common.LogError(ctx, "failed to record log: "+err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
|
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
|
||||||
var tx *gorm.DB
|
var tx *gorm.DB
|
||||||
if logType == LogTypeUnknown {
|
if logType == LogTypeUnknown {
|
||||||
tx = DB
|
tx = DB
|
||||||
@@ -88,6 +93,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|||||||
if endTimestamp != 0 {
|
if endTimestamp != 0 {
|
||||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||||
}
|
}
|
||||||
|
if channel != 0 {
|
||||||
|
tx = tx.Where("channel = ?", channel)
|
||||||
|
}
|
||||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||||
return logs, err
|
return logs, err
|
||||||
}
|
}
|
||||||
@@ -125,8 +133,8 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
|||||||
return logs, err
|
return logs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
|
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
|
||||||
tx := DB.Table("logs").Select("sum(quota)")
|
tx := DB.Table("logs").Select("ifnull(sum(quota),0)")
|
||||||
if username != "" {
|
if username != "" {
|
||||||
tx = tx.Where("username = ?", username)
|
tx = tx.Where("username = ?", username)
|
||||||
}
|
}
|
||||||
@@ -142,12 +150,15 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name = ?", modelName)
|
tx = tx.Where("model_name = ?", modelName)
|
||||||
}
|
}
|
||||||
|
if channel != 0 {
|
||||||
|
tx = tx.Where("channel = ?", channel)
|
||||||
|
}
|
||||||
tx.Where("type = ?", LogTypeConsume).Scan("a)
|
tx.Where("type = ?", LogTypeConsume).Scan("a)
|
||||||
return quota
|
return quota
|
||||||
}
|
}
|
||||||
|
|
||||||
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
|
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
|
||||||
tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)")
|
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
|
||||||
if username != "" {
|
if username != "" {
|
||||||
tx = tx.Where("username = ?", username)
|
tx = tx.Where("username = ?", username)
|
||||||
}
|
}
|
||||||
@@ -166,3 +177,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
tx.Where("type = ?", LogTypeConsume).Scan(&token)
|
tx.Where("type = ?", LogTypeConsume).Scan(&token)
|
||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
||||||
|
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
||||||
|
return result.RowsAffected, result.Error
|
||||||
|
}
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ func InitDB() (err error) {
|
|||||||
if !common.IsMasterNode {
|
if !common.IsMasterNode {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
common.SysLog("database migration started")
|
||||||
err = db.AutoMigrate(&Channel{})
|
err = db.AutoMigrate(&Channel{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
}
|
}
|
||||||
token, err = CacheGetTokenByKey(key)
|
token, err = CacheGetTokenByKey(key)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
if token.Status == common.TokenStatusExhausted {
|
||||||
|
return nil, errors.New("该令牌额度已用尽")
|
||||||
|
} else if token.Status == common.TokenStatusExpired {
|
||||||
|
return nil, errors.New("该令牌已过期")
|
||||||
|
}
|
||||||
if token.Status != common.TokenStatusEnabled {
|
if token.Status != common.TokenStatusEnabled {
|
||||||
return nil, errors.New("该令牌状态不可用")
|
return nil, errors.New("该令牌状态不可用")
|
||||||
}
|
}
|
||||||
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
||||||
token.Status = common.TokenStatusExpired
|
if !common.RedisEnabled {
|
||||||
err := token.SelectUpdate()
|
token.Status = common.TokenStatusExpired
|
||||||
if err != nil {
|
err := token.SelectUpdate()
|
||||||
common.SysError("failed to update token status" + err.Error())
|
if err != nil {
|
||||||
|
common.SysError("failed to update token status" + err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("该令牌已过期")
|
return nil, errors.New("该令牌已过期")
|
||||||
}
|
}
|
||||||
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
||||||
token.Status = common.TokenStatusExhausted
|
if !common.RedisEnabled {
|
||||||
err := token.SelectUpdate()
|
// in this case, we can make sure the token is exhausted
|
||||||
if err != nil {
|
token.Status = common.TokenStatusExhausted
|
||||||
common.SysError("failed to update token status" + err.Error())
|
err := token.SelectUpdate()
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to update token status" + err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("该令牌额度已用尽")
|
return nil, errors.New("该令牌额度已用尽")
|
||||||
}
|
}
|
||||||
go func() {
|
|
||||||
token.AccessedTime = common.GetTimestamp()
|
|
||||||
err := token.SelectUpdate()
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("failed to update token" + err.Error())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("无效的令牌")
|
return nil, errors.New("无效的令牌")
|
||||||
@@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return increaseTokenQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func increaseTokenQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
||||||
"used_quota": gorm.Expr("used_quota - ?", quota),
|
"used_quota": gorm.Expr("used_quota - ?", quota),
|
||||||
|
"accessed_time": common.GetTimestamp(),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
return err
|
return err
|
||||||
@@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return decreaseTokenQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decreaseTokenQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
||||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
|
"accessed_time": common.GetTimestamp(),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -226,17 +226,16 @@ func IsAdmin(userId int) bool {
|
|||||||
return user.Role >= common.RoleAdminUser
|
return user.Role >= common.RoleAdminUser
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsUserEnabled(userId int) bool {
|
func IsUserEnabled(userId int) (bool, error) {
|
||||||
if userId == 0 {
|
if userId == 0 {
|
||||||
return false
|
return false, errors.New("user id is empty")
|
||||||
}
|
}
|
||||||
var user User
|
var user User
|
||||||
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
|
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("no such user " + err.Error())
|
return false, err
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
return user.Status == common.UserStatusEnabled
|
return user.Status == common.UserStatusEnabled, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateAccessToken(token string) (user *User) {
|
func ValidateAccessToken(token string) (user *User) {
|
||||||
@@ -275,6 +274,14 @@ func IncreaseUserQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return increaseUserQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func increaseUserQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -283,6 +290,14 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
|||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return decreaseUserQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decreaseUserQuota(id int, quota int) (err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -293,10 +308,19 @@ func GetRootUserEmail() (email string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
||||||
|
if common.BatchUpdateEnabled {
|
||||||
|
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
|
||||||
|
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updateUserUsedQuotaAndRequestCount(id, quota, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
||||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
"request_count": gorm.Expr("request_count + ?", 1),
|
"request_count": gorm.Expr("request_count + ?", count),
|
||||||
},
|
},
|
||||||
).Error
|
).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -304,6 +328,24 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func updateUserUsedQuota(id int, quota int) {
|
||||||
|
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||||
|
map[string]interface{}{
|
||||||
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
|
},
|
||||||
|
).Error
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to update user used quota: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateUserRequestCount(id int, count int) {
|
||||||
|
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to update user request count: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func GetUsernameById(id int) (username string) {
|
func GetUsernameById(id int) (username string) {
|
||||||
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
|
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
|
||||||
return username
|
return username
|
||||||
|
|||||||
77
model/utils.go
Normal file
77
model/utils.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/common"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
BatchUpdateTypeUserQuota = iota
|
||||||
|
BatchUpdateTypeTokenQuota
|
||||||
|
BatchUpdateTypeUsedQuota
|
||||||
|
BatchUpdateTypeChannelUsedQuota
|
||||||
|
BatchUpdateTypeRequestCount
|
||||||
|
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
|
||||||
|
)
|
||||||
|
|
||||||
|
var batchUpdateStores []map[int]int
|
||||||
|
var batchUpdateLocks []sync.Mutex
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
|
batchUpdateStores = append(batchUpdateStores, make(map[int]int))
|
||||||
|
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitBatchUpdater() {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
|
||||||
|
batchUpdate()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func addNewRecord(type_ int, id int, value int) {
|
||||||
|
batchUpdateLocks[type_].Lock()
|
||||||
|
defer batchUpdateLocks[type_].Unlock()
|
||||||
|
if _, ok := batchUpdateStores[type_][id]; !ok {
|
||||||
|
batchUpdateStores[type_][id] = value
|
||||||
|
} else {
|
||||||
|
batchUpdateStores[type_][id] += value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func batchUpdate() {
|
||||||
|
common.SysLog("batch update started")
|
||||||
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
|
batchUpdateLocks[i].Lock()
|
||||||
|
store := batchUpdateStores[i]
|
||||||
|
batchUpdateStores[i] = make(map[int]int)
|
||||||
|
batchUpdateLocks[i].Unlock()
|
||||||
|
// TODO: maybe we can combine updates with same key?
|
||||||
|
for key, value := range store {
|
||||||
|
switch i {
|
||||||
|
case BatchUpdateTypeUserQuota:
|
||||||
|
err := increaseUserQuota(key, value)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to batch update user quota: " + err.Error())
|
||||||
|
}
|
||||||
|
case BatchUpdateTypeTokenQuota:
|
||||||
|
err := increaseTokenQuota(key, value)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to batch update token quota: " + err.Error())
|
||||||
|
}
|
||||||
|
case BatchUpdateTypeUsedQuota:
|
||||||
|
updateUserUsedQuota(key, value)
|
||||||
|
case BatchUpdateTypeRequestCount:
|
||||||
|
updateUserRequestCount(key, value)
|
||||||
|
case BatchUpdateTypeChannelUsedQuota:
|
||||||
|
updateChannelUsedQuota(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
common.SysLog("batch update finished")
|
||||||
|
}
|
||||||
@@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||||
|
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
|
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
|
||||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
|
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
|
||||||
@@ -73,6 +74,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
|
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
|
||||||
channelRoute.POST("/", controller.AddChannel)
|
channelRoute.POST("/", controller.AddChannel)
|
||||||
channelRoute.PUT("/", controller.UpdateChannel)
|
channelRoute.PUT("/", controller.UpdateChannel)
|
||||||
|
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
|
||||||
channelRoute.DELETE("/:id", controller.DeleteChannel)
|
channelRoute.DELETE("/:id", controller.DeleteChannel)
|
||||||
}
|
}
|
||||||
tokenRoute := apiRouter.Group("/token")
|
tokenRoute := apiRouter.Group("/token")
|
||||||
@@ -97,6 +99,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
}
|
}
|
||||||
logRoute := apiRouter.Group("/log")
|
logRoute := apiRouter.Group("/log")
|
||||||
logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
|
logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
|
||||||
|
logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
|
||||||
logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
|
logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
|
||||||
logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
|
logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
|
||||||
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
|
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func SetRelayRouter(router *gin.Engine) {
|
func SetRelayRouter(router *gin.Engine) {
|
||||||
|
router.Use(middleware.CORS())
|
||||||
// https://platform.openai.com/docs/api-reference/introduction
|
// https://platform.openai.com/docs/api-reference/introduction
|
||||||
modelsRouter := router.Group("/v1/models")
|
modelsRouter := router.Group("/v1/models")
|
||||||
modelsRouter.Use(middleware.TokenAuth())
|
modelsRouter.Use(middleware.TokenAuth())
|
||||||
|
|||||||
@@ -283,7 +283,9 @@ function App() {
|
|||||||
</Suspense>
|
</Suspense>
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
<Route path='*' element={NotFound} />
|
<Route path='*' element={
|
||||||
|
<NotFound />
|
||||||
|
} />
|
||||||
</Routes>
|
</Routes>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
|
import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
|
||||||
import { Link } from 'react-router-dom';
|
import { Link } from 'react-router-dom';
|
||||||
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
|
import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
|
||||||
|
|
||||||
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
|
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
|
||||||
import { renderGroup, renderNumber } from '../helpers/render';
|
import { renderGroup, renderNumber } from '../helpers/render';
|
||||||
@@ -24,7 +24,7 @@ function renderType(type) {
|
|||||||
}
|
}
|
||||||
type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
|
type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
|
||||||
}
|
}
|
||||||
return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
|
return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>;
|
||||||
}
|
}
|
||||||
|
|
||||||
function renderBalance(type, balance) {
|
function renderBalance(type, balance) {
|
||||||
@@ -55,6 +55,7 @@ const ChannelsTable = () => {
|
|||||||
const [searchKeyword, setSearchKeyword] = useState('');
|
const [searchKeyword, setSearchKeyword] = useState('');
|
||||||
const [searching, setSearching] = useState(false);
|
const [searching, setSearching] = useState(false);
|
||||||
const [updatingBalance, setUpdatingBalance] = useState(false);
|
const [updatingBalance, setUpdatingBalance] = useState(false);
|
||||||
|
const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test"));
|
||||||
|
|
||||||
const loadChannels = async (startIdx) => {
|
const loadChannels = async (startIdx) => {
|
||||||
const res = await API.get(`/api/channel/?p=${startIdx}`);
|
const res = await API.get(`/api/channel/?p=${startIdx}`);
|
||||||
@@ -96,7 +97,7 @@ const ChannelsTable = () => {
|
|||||||
});
|
});
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const manageChannel = async (id, action, idx) => {
|
const manageChannel = async (id, action, idx, value) => {
|
||||||
let data = { id };
|
let data = { id };
|
||||||
let res;
|
let res;
|
||||||
switch (action) {
|
switch (action) {
|
||||||
@@ -111,6 +112,23 @@ const ChannelsTable = () => {
|
|||||||
data.status = 2;
|
data.status = 2;
|
||||||
res = await API.put('/api/channel/', data);
|
res = await API.put('/api/channel/', data);
|
||||||
break;
|
break;
|
||||||
|
case 'priority':
|
||||||
|
if (value === '') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data.priority = parseInt(value);
|
||||||
|
res = await API.put('/api/channel/', data);
|
||||||
|
break;
|
||||||
|
case 'weight':
|
||||||
|
if (value === '') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data.weight = parseInt(value);
|
||||||
|
if (data.weight < 0) {
|
||||||
|
data.weight = 0;
|
||||||
|
}
|
||||||
|
res = await API.put('/api/channel/', data);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
const { success, message } = res.data;
|
const { success, message } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
@@ -135,9 +153,23 @@ const ChannelsTable = () => {
|
|||||||
return <Label basic color='green'>已启用</Label>;
|
return <Label basic color='green'>已启用</Label>;
|
||||||
case 2:
|
case 2:
|
||||||
return (
|
return (
|
||||||
<Label basic color='red'>
|
<Popup
|
||||||
已禁用
|
trigger={<Label basic color='red'>
|
||||||
</Label>
|
已禁用
|
||||||
|
</Label>}
|
||||||
|
content='本渠道被手动禁用'
|
||||||
|
basic
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
case 3:
|
||||||
|
return (
|
||||||
|
<Popup
|
||||||
|
trigger={<Label basic color='yellow'>
|
||||||
|
已禁用
|
||||||
|
</Label>}
|
||||||
|
content='本渠道被程序自动禁用'
|
||||||
|
basic
|
||||||
|
/>
|
||||||
);
|
);
|
||||||
default:
|
default:
|
||||||
return (
|
return (
|
||||||
@@ -208,6 +240,17 @@ const ChannelsTable = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const deleteAllDisabledChannels = async () => {
|
||||||
|
const res = await API.delete(`/api/channel/disabled`);
|
||||||
|
const { success, message, data } = res.data;
|
||||||
|
if (success) {
|
||||||
|
showSuccess(`已删除所有禁用渠道,共计 ${data} 个`);
|
||||||
|
await refresh();
|
||||||
|
} else {
|
||||||
|
showError(message);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const updateChannelBalance = async (id, name, idx) => {
|
const updateChannelBalance = async (id, name, idx) => {
|
||||||
const res = await API.get(`/api/channel/update_balance/${id}/`);
|
const res = await API.get(`/api/channel/update_balance/${id}/`);
|
||||||
const { success, message, balance } = res.data;
|
const { success, message, balance } = res.data;
|
||||||
@@ -274,7 +317,19 @@ const ChannelsTable = () => {
|
|||||||
onChange={handleKeywordChange}
|
onChange={handleKeywordChange}
|
||||||
/>
|
/>
|
||||||
</Form>
|
</Form>
|
||||||
|
{
|
||||||
|
showPrompt && (
|
||||||
|
<Message onDismiss={() => {
|
||||||
|
setShowPrompt(false);
|
||||||
|
setPromptShown("channel-test");
|
||||||
|
}}>
|
||||||
|
当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo
|
||||||
|
模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。
|
||||||
|
|
||||||
|
另外,OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
|
||||||
|
</Message>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Table basic compact size='small'>
|
<Table basic compact size='small'>
|
||||||
<Table.Header>
|
<Table.Header>
|
||||||
<Table.Row>
|
<Table.Row>
|
||||||
@@ -334,6 +389,14 @@ const ChannelsTable = () => {
|
|||||||
>
|
>
|
||||||
余额
|
余额
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
|
<Table.HeaderCell
|
||||||
|
style={{ cursor: 'pointer' }}
|
||||||
|
onClick={() => {
|
||||||
|
sortChannel('priority');
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
优先级
|
||||||
|
</Table.HeaderCell>
|
||||||
<Table.HeaderCell>操作</Table.HeaderCell>
|
<Table.HeaderCell>操作</Table.HeaderCell>
|
||||||
</Table.Row>
|
</Table.Row>
|
||||||
</Table.Header>
|
</Table.Header>
|
||||||
@@ -372,6 +435,22 @@ const ChannelsTable = () => {
|
|||||||
basic
|
basic
|
||||||
/>
|
/>
|
||||||
</Table.Cell>
|
</Table.Cell>
|
||||||
|
<Table.Cell>
|
||||||
|
<Popup
|
||||||
|
trigger={<Input type='number' defaultValue={channel.priority} onBlur={(event) => {
|
||||||
|
manageChannel(
|
||||||
|
channel.id,
|
||||||
|
'priority',
|
||||||
|
idx,
|
||||||
|
event.target.value
|
||||||
|
);
|
||||||
|
}}>
|
||||||
|
<input style={{ maxWidth: '60px' }} />
|
||||||
|
</Input>}
|
||||||
|
content='渠道选择优先级,越高越优先'
|
||||||
|
basic
|
||||||
|
/>
|
||||||
|
</Table.Cell>
|
||||||
<Table.Cell>
|
<Table.Cell>
|
||||||
<div>
|
<div>
|
||||||
<Button
|
<Button
|
||||||
@@ -440,7 +519,7 @@ const ChannelsTable = () => {
|
|||||||
|
|
||||||
<Table.Footer>
|
<Table.Footer>
|
||||||
<Table.Row>
|
<Table.Row>
|
||||||
<Table.HeaderCell colSpan='8'>
|
<Table.HeaderCell colSpan='9'>
|
||||||
<Button size='small' as={Link} to='/channel/add' loading={loading}>
|
<Button size='small' as={Link} to='/channel/add' loading={loading}>
|
||||||
添加新的渠道
|
添加新的渠道
|
||||||
</Button>
|
</Button>
|
||||||
@@ -449,6 +528,20 @@ const ChannelsTable = () => {
|
|||||||
</Button>
|
</Button>
|
||||||
<Button size='small' onClick={updateAllChannelsBalance}
|
<Button size='small' onClick={updateAllChannelsBalance}
|
||||||
loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
|
loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
|
||||||
|
<Popup
|
||||||
|
trigger={
|
||||||
|
<Button size='small' loading={loading}>
|
||||||
|
删除禁用渠道
|
||||||
|
</Button>
|
||||||
|
}
|
||||||
|
on='click'
|
||||||
|
flowing
|
||||||
|
hoverable
|
||||||
|
>
|
||||||
|
<Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}>
|
||||||
|
确认删除
|
||||||
|
</Button>
|
||||||
|
</Popup>
|
||||||
<Pagination
|
<Pagination
|
||||||
floated='right'
|
floated='right'
|
||||||
activePage={activePage}
|
activePage={activePage}
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ const GitHubOAuth = () => {
|
|||||||
|
|
||||||
let navigate = useNavigate();
|
let navigate = useNavigate();
|
||||||
|
|
||||||
const sendCode = async (code, count) => {
|
const sendCode = async (code, state, count) => {
|
||||||
const res = await API.get(`/api/oauth/github?code=${code}`);
|
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
if (message === 'bind') {
|
if (message === 'bind') {
|
||||||
@@ -36,13 +36,14 @@ const GitHubOAuth = () => {
|
|||||||
count++;
|
count++;
|
||||||
setPrompt(`出现错误,第 ${count} 次重试中...`);
|
setPrompt(`出现错误,第 ${count} 次重试中...`);
|
||||||
await new Promise((resolve) => setTimeout(resolve, count * 2000));
|
await new Promise((resolve) => setTimeout(resolve, count * 2000));
|
||||||
await sendCode(code, count);
|
await sendCode(code, state, count);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
let code = searchParams.get('code');
|
let code = searchParams.get('code');
|
||||||
sendCode(code, 0).then();
|
let state = searchParams.get('state');
|
||||||
|
sendCode(code, state, 0).then();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ import React, { useContext, useEffect, useState } from 'react';
|
|||||||
import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react';
|
import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react';
|
||||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
||||||
import { UserContext } from '../context/User';
|
import { UserContext } from '../context/User';
|
||||||
import { API, getLogo, showError, showSuccess } from '../helpers';
|
import { API, getLogo, showError, showSuccess, showWarning } from '../helpers';
|
||||||
|
import { onGitHubOAuthClicked } from './utils';
|
||||||
|
|
||||||
const LoginForm = () => {
|
const LoginForm = () => {
|
||||||
const [inputs, setInputs] = useState({
|
const [inputs, setInputs] = useState({
|
||||||
@@ -31,12 +32,6 @@ const LoginForm = () => {
|
|||||||
|
|
||||||
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
|
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
|
||||||
|
|
||||||
const onGitHubOAuthClicked = () => {
|
|
||||||
window.open(
|
|
||||||
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const onWeChatLoginClicked = () => {
|
const onWeChatLoginClicked = () => {
|
||||||
setShowWeChatLoginModal(true);
|
setShowWeChatLoginModal(true);
|
||||||
};
|
};
|
||||||
@@ -73,8 +68,14 @@ const LoginForm = () => {
|
|||||||
if (success) {
|
if (success) {
|
||||||
userDispatch({ type: 'login', payload: data });
|
userDispatch({ type: 'login', payload: data });
|
||||||
localStorage.setItem('user', JSON.stringify(data));
|
localStorage.setItem('user', JSON.stringify(data));
|
||||||
navigate('/');
|
if (username === 'root' && password === '123456') {
|
||||||
showSuccess('登录成功!');
|
navigate('/user/edit');
|
||||||
|
showSuccess('登录成功!');
|
||||||
|
showWarning('请立刻修改默认密码!');
|
||||||
|
} else {
|
||||||
|
navigate('/token');
|
||||||
|
showSuccess('登录成功!');
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
}
|
}
|
||||||
@@ -131,7 +132,7 @@ const LoginForm = () => {
|
|||||||
circular
|
circular
|
||||||
color='black'
|
color='black'
|
||||||
icon='github'
|
icon='github'
|
||||||
onClick={onGitHubOAuthClicked}
|
onClick={() => onGitHubOAuthClicked(status.github_client_id)}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<></>
|
<></>
|
||||||
|
|||||||
@@ -56,9 +56,10 @@ const LogsTable = () => {
|
|||||||
token_name: '',
|
token_name: '',
|
||||||
model_name: '',
|
model_name: '',
|
||||||
start_timestamp: timestamp2string(0),
|
start_timestamp: timestamp2string(0),
|
||||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
|
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||||
|
channel: ''
|
||||||
});
|
});
|
||||||
const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
|
const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
|
||||||
|
|
||||||
const [stat, setStat] = useState({
|
const [stat, setStat] = useState({
|
||||||
quota: 0,
|
quota: 0,
|
||||||
@@ -84,7 +85,7 @@ const LogsTable = () => {
|
|||||||
const getLogStat = async () => {
|
const getLogStat = async () => {
|
||||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||||
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
|
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
setStat(data);
|
setStat(data);
|
||||||
@@ -109,7 +110,7 @@ const LogsTable = () => {
|
|||||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||||
if (isAdminUser) {
|
if (isAdminUser) {
|
||||||
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||||
} else {
|
} else {
|
||||||
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||||
}
|
}
|
||||||
@@ -205,16 +206,9 @@ const LogsTable = () => {
|
|||||||
</Header>
|
</Header>
|
||||||
<Form>
|
<Form>
|
||||||
<Form.Group>
|
<Form.Group>
|
||||||
{
|
<Form.Input fluid label={'令牌名称'} width={3} value={token_name}
|
||||||
isAdminUser && (
|
|
||||||
<Form.Input fluid label={'用户名称'} width={2} value={username}
|
|
||||||
placeholder={'可选值'} name='username'
|
|
||||||
onChange={handleInputChange} />
|
|
||||||
)
|
|
||||||
}
|
|
||||||
<Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name}
|
|
||||||
placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
|
placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
|
||||||
<Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值'
|
<Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
|
||||||
name='model_name'
|
name='model_name'
|
||||||
onChange={handleInputChange} />
|
onChange={handleInputChange} />
|
||||||
<Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
|
<Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
|
||||||
@@ -225,6 +219,19 @@ const LogsTable = () => {
|
|||||||
onChange={handleInputChange} />
|
onChange={handleInputChange} />
|
||||||
<Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
|
<Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
|
||||||
</Form.Group>
|
</Form.Group>
|
||||||
|
{
|
||||||
|
isAdminUser && <>
|
||||||
|
<Form.Group>
|
||||||
|
<Form.Input fluid label={'渠道 ID'} width={3} value={channel}
|
||||||
|
placeholder='可选值' name='channel'
|
||||||
|
onChange={handleInputChange} />
|
||||||
|
<Form.Input fluid label={'用户名称'} width={3} value={username}
|
||||||
|
placeholder={'可选值'} name='username'
|
||||||
|
onChange={handleInputChange} />
|
||||||
|
|
||||||
|
</Form.Group>
|
||||||
|
</>
|
||||||
|
}
|
||||||
</Form>
|
</Form>
|
||||||
<Table basic compact size='small'>
|
<Table basic compact size='small'>
|
||||||
<Table.Header>
|
<Table.Header>
|
||||||
@@ -238,6 +245,17 @@ const LogsTable = () => {
|
|||||||
>
|
>
|
||||||
时间
|
时间
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
|
{
|
||||||
|
isAdminUser && <Table.HeaderCell
|
||||||
|
style={{ cursor: 'pointer' }}
|
||||||
|
onClick={() => {
|
||||||
|
sortLog('channel');
|
||||||
|
}}
|
||||||
|
width={1}
|
||||||
|
>
|
||||||
|
渠道
|
||||||
|
</Table.HeaderCell>
|
||||||
|
}
|
||||||
{
|
{
|
||||||
isAdminUser && <Table.HeaderCell
|
isAdminUser && <Table.HeaderCell
|
||||||
style={{ cursor: 'pointer' }}
|
style={{ cursor: 'pointer' }}
|
||||||
@@ -299,16 +317,16 @@ const LogsTable = () => {
|
|||||||
onClick={() => {
|
onClick={() => {
|
||||||
sortLog('quota');
|
sortLog('quota');
|
||||||
}}
|
}}
|
||||||
width={2}
|
width={1}
|
||||||
>
|
>
|
||||||
消耗额度
|
额度
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
<Table.HeaderCell
|
<Table.HeaderCell
|
||||||
style={{ cursor: 'pointer' }}
|
style={{ cursor: 'pointer' }}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
sortLog('content');
|
sortLog('content');
|
||||||
}}
|
}}
|
||||||
width={isAdminUser ? 4 : 5}
|
width={isAdminUser ? 4 : 6}
|
||||||
>
|
>
|
||||||
详情
|
详情
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
@@ -324,8 +342,13 @@ const LogsTable = () => {
|
|||||||
.map((log, idx) => {
|
.map((log, idx) => {
|
||||||
if (log.deleted) return <></>;
|
if (log.deleted) return <></>;
|
||||||
return (
|
return (
|
||||||
<Table.Row key={log.created_at}>
|
<Table.Row key={log.id}>
|
||||||
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
|
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
|
||||||
|
{
|
||||||
|
isAdminUser && (
|
||||||
|
<Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
|
||||||
|
)
|
||||||
|
}
|
||||||
{
|
{
|
||||||
isAdminUser && (
|
isAdminUser && (
|
||||||
<Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
|
<Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
|
||||||
@@ -345,7 +368,7 @@ const LogsTable = () => {
|
|||||||
|
|
||||||
<Table.Footer>
|
<Table.Footer>
|
||||||
<Table.Row>
|
<Table.Row>
|
||||||
<Table.HeaderCell colSpan={'9'}>
|
<Table.HeaderCell colSpan={'10'}>
|
||||||
<Select
|
<Select
|
||||||
placeholder='选择明细分类'
|
placeholder='选择明细分类'
|
||||||
options={LOG_OPTIONS}
|
options={LOG_OPTIONS}
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
|
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
|
||||||
import { API, showError, verifyJSON } from '../helpers';
|
import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers';
|
||||||
|
|
||||||
const OperationSetting = () => {
|
const OperationSetting = () => {
|
||||||
|
let now = new Date();
|
||||||
let [inputs, setInputs] = useState({
|
let [inputs, setInputs] = useState({
|
||||||
QuotaForNewUser: 0,
|
QuotaForNewUser: 0,
|
||||||
QuotaForInviter: 0,
|
QuotaForInviter: 0,
|
||||||
@@ -20,10 +21,11 @@ const OperationSetting = () => {
|
|||||||
DisplayInCurrencyEnabled: '',
|
DisplayInCurrencyEnabled: '',
|
||||||
DisplayTokenStatEnabled: '',
|
DisplayTokenStatEnabled: '',
|
||||||
ApproximateTokenEnabled: '',
|
ApproximateTokenEnabled: '',
|
||||||
RetryTimes: 0,
|
RetryTimes: 0
|
||||||
});
|
});
|
||||||
const [originInputs, setOriginInputs] = useState({});
|
const [originInputs, setOriginInputs] = useState({});
|
||||||
let [loading, setLoading] = useState(false);
|
let [loading, setLoading] = useState(false);
|
||||||
|
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
|
||||||
|
|
||||||
const getOptions = async () => {
|
const getOptions = async () => {
|
||||||
const res = await API.get('/api/option/');
|
const res = await API.get('/api/option/');
|
||||||
@@ -130,6 +132,17 @@ const OperationSetting = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const deleteHistoryLogs = async () => {
|
||||||
|
console.log(inputs);
|
||||||
|
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
|
||||||
|
const { success, message, data } = res.data;
|
||||||
|
if (success) {
|
||||||
|
showSuccess(`${data} 条日志已清理!`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
showError('日志清理失败:' + message);
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Grid columns={1}>
|
<Grid columns={1}>
|
||||||
<Grid.Column>
|
<Grid.Column>
|
||||||
@@ -179,12 +192,6 @@ const OperationSetting = () => {
|
|||||||
/>
|
/>
|
||||||
</Form.Group>
|
</Form.Group>
|
||||||
<Form.Group inline>
|
<Form.Group inline>
|
||||||
<Form.Checkbox
|
|
||||||
checked={inputs.LogConsumeEnabled === 'true'}
|
|
||||||
label='启用额度消费日志记录'
|
|
||||||
name='LogConsumeEnabled'
|
|
||||||
onChange={handleInputChange}
|
|
||||||
/>
|
|
||||||
<Form.Checkbox
|
<Form.Checkbox
|
||||||
checked={inputs.DisplayInCurrencyEnabled === 'true'}
|
checked={inputs.DisplayInCurrencyEnabled === 'true'}
|
||||||
label='以货币形式显示额度'
|
label='以货币形式显示额度'
|
||||||
@@ -208,6 +215,28 @@ const OperationSetting = () => {
|
|||||||
submitConfig('general').then();
|
submitConfig('general').then();
|
||||||
}}>保存通用设置</Form.Button>
|
}}>保存通用设置</Form.Button>
|
||||||
<Divider />
|
<Divider />
|
||||||
|
<Header as='h3'>
|
||||||
|
日志设置
|
||||||
|
</Header>
|
||||||
|
<Form.Group inline>
|
||||||
|
<Form.Checkbox
|
||||||
|
checked={inputs.LogConsumeEnabled === 'true'}
|
||||||
|
label='启用额度消费日志记录'
|
||||||
|
name='LogConsumeEnabled'
|
||||||
|
onChange={handleInputChange}
|
||||||
|
/>
|
||||||
|
</Form.Group>
|
||||||
|
<Form.Group widths={4}>
|
||||||
|
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
|
||||||
|
name='history_timestamp'
|
||||||
|
onChange={(e, { name, value }) => {
|
||||||
|
setHistoryTimestamp(value);
|
||||||
|
}} />
|
||||||
|
</Form.Group>
|
||||||
|
<Form.Button onClick={() => {
|
||||||
|
deleteHistoryLogs().then();
|
||||||
|
}}>清理历史日志</Form.Button>
|
||||||
|
<Divider />
|
||||||
<Header as='h3'>
|
<Header as='h3'>
|
||||||
监控设置
|
监控设置
|
||||||
</Header>
|
</Header>
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import { Link, useNavigate } from 'react-router-dom';
|
|||||||
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
|
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
|
||||||
import Turnstile from 'react-turnstile';
|
import Turnstile from 'react-turnstile';
|
||||||
import { UserContext } from '../context/User';
|
import { UserContext } from '../context/User';
|
||||||
|
import { onGitHubOAuthClicked } from './utils';
|
||||||
|
|
||||||
const PersonalSetting = () => {
|
const PersonalSetting = () => {
|
||||||
const [userState, userDispatch] = useContext(UserContext);
|
const [userState, userDispatch] = useContext(UserContext);
|
||||||
@@ -130,12 +131,6 @@ const PersonalSetting = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const openGitHubOAuth = () => {
|
|
||||||
window.open(
|
|
||||||
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const sendVerificationCode = async () => {
|
const sendVerificationCode = async () => {
|
||||||
setDisableButton(true);
|
setDisableButton(true);
|
||||||
if (inputs.email === '') return;
|
if (inputs.email === '') return;
|
||||||
@@ -249,7 +244,7 @@ const PersonalSetting = () => {
|
|||||||
</Modal>
|
</Modal>
|
||||||
{
|
{
|
||||||
status.github_oauth && (
|
status.github_oauth && (
|
||||||
<Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
|
<Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
<Button
|
<Button
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ const TokensTable = () => {
|
|||||||
let nextUrl;
|
let nextUrl;
|
||||||
|
|
||||||
if (nextLink) {
|
if (nextLink) {
|
||||||
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
|
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||||
} else {
|
} else {
|
||||||
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||||
}
|
}
|
||||||
@@ -138,7 +138,7 @@ const TokensTable = () => {
|
|||||||
let defaultUrl;
|
let defaultUrl;
|
||||||
|
|
||||||
if (chatLink) {
|
if (chatLink) {
|
||||||
defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`;
|
defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||||
} else {
|
} else {
|
||||||
defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||||
}
|
}
|
||||||
|
|||||||
20
web/src/components/utils.js
Normal file
20
web/src/components/utils.js
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import { API, showError } from '../helpers';
|
||||||
|
|
||||||
|
export async function getOAuthState() {
|
||||||
|
const res = await API.get('/api/oauth/state');
|
||||||
|
const { success, message, data } = res.data;
|
||||||
|
if (success) {
|
||||||
|
return data;
|
||||||
|
} else {
|
||||||
|
showError(message);
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function onGitHubOAuthClicked(github_client_id) {
|
||||||
|
const state = await getOAuthState();
|
||||||
|
if (!state) return;
|
||||||
|
window.open(
|
||||||
|
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -8,7 +8,10 @@ export const CHANNEL_OPTIONS = [
|
|||||||
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
|
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
|
||||||
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
|
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
|
||||||
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
|
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
|
||||||
|
{ key: 23, text: '腾讯混元', value: 23, color: 'teal' },
|
||||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||||
|
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||||
|
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||||
{ key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
|
{ key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
|
||||||
{ key: 2, text: '代理:API2D', value: 2, color: 'blue' },
|
{ key: 2, text: '代理:API2D', value: 2, color: 'blue' },
|
||||||
{ key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
|
{ key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
|
||||||
|
|||||||
@@ -186,4 +186,14 @@ export const verifyJSON = (str) => {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export function shouldShowPrompt(id) {
|
||||||
|
let prompt = localStorage.getItem(`prompt-${id}`);
|
||||||
|
return !prompt;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
export function setPromptShown(id) {
|
||||||
|
localStorage.setItem(`prompt-${id}`, 'true');
|
||||||
|
}
|
||||||
@@ -10,6 +10,22 @@ const MODEL_MAPPING_EXAMPLE = {
|
|||||||
'gpt-4-32k-0314': 'gpt-4-32k'
|
'gpt-4-32k-0314': 'gpt-4-32k'
|
||||||
};
|
};
|
||||||
|
|
||||||
|
function type2secretPrompt(type) {
|
||||||
|
// inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
|
||||||
|
switch (type) {
|
||||||
|
case 15:
|
||||||
|
return '按照如下格式输入:APIKey|SecretKey';
|
||||||
|
case 18:
|
||||||
|
return '按照如下格式输入:APPID|APISecret|APIKey';
|
||||||
|
case 22:
|
||||||
|
return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
|
||||||
|
case 23:
|
||||||
|
return '按照如下格式输入:AppId|SecretId|SecretKey';
|
||||||
|
default:
|
||||||
|
return '请输入渠道对应的鉴权密钥';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const EditChannel = () => {
|
const EditChannel = () => {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
@@ -53,7 +69,7 @@ const EditChannel = () => {
|
|||||||
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
|
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
|
||||||
break;
|
break;
|
||||||
case 17:
|
case 17:
|
||||||
localModels = ['qwen-v1', 'qwen-plus-v1'];
|
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
|
||||||
break;
|
break;
|
||||||
case 16:
|
case 16:
|
||||||
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
|
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
|
||||||
@@ -62,7 +78,10 @@ const EditChannel = () => {
|
|||||||
localModels = ['SparkDesk'];
|
localModels = ['SparkDesk'];
|
||||||
break;
|
break;
|
||||||
case 19:
|
case 19:
|
||||||
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'];
|
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
|
||||||
|
break;
|
||||||
|
case 23:
|
||||||
|
localModels = ['hunyuan'];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
setInputs((inputs) => ({ ...inputs, models: localModels }));
|
setInputs((inputs) => ({ ...inputs, models: localModels }));
|
||||||
@@ -160,7 +179,7 @@ const EditChannel = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
let localInputs = inputs;
|
let localInputs = inputs;
|
||||||
if (localInputs.base_url.endsWith('/')) {
|
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
|
||||||
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
||||||
}
|
}
|
||||||
if (localInputs.type === 3 && localInputs.other === '') {
|
if (localInputs.type === 3 && localInputs.other === '') {
|
||||||
@@ -169,9 +188,6 @@ const EditChannel = () => {
|
|||||||
if (localInputs.type === 18 && localInputs.other === '') {
|
if (localInputs.type === 18 && localInputs.other === '') {
|
||||||
localInputs.other = 'v2.1';
|
localInputs.other = 'v2.1';
|
||||||
}
|
}
|
||||||
if (localInputs.model_mapping === '') {
|
|
||||||
localInputs.model_mapping = '{}';
|
|
||||||
}
|
|
||||||
let res;
|
let res;
|
||||||
localInputs.models = localInputs.models.join(',');
|
localInputs.models = localInputs.models.join(',');
|
||||||
localInputs.group = localInputs.groups.join(',');
|
localInputs.group = localInputs.groups.join(',');
|
||||||
@@ -193,6 +209,24 @@ const EditChannel = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const addCustomModel = () => {
|
||||||
|
if (customModel.trim() === '') return;
|
||||||
|
if (inputs.models.includes(customModel)) return;
|
||||||
|
let localModels = [...inputs.models];
|
||||||
|
localModels.push(customModel);
|
||||||
|
let localModelOptions = [];
|
||||||
|
localModelOptions.push({
|
||||||
|
key: customModel,
|
||||||
|
text: customModel,
|
||||||
|
value: customModel
|
||||||
|
});
|
||||||
|
setModelOptions(modelOptions => {
|
||||||
|
return [...modelOptions, ...localModelOptions];
|
||||||
|
});
|
||||||
|
setCustomModel('');
|
||||||
|
handleInputChange(null, { name: 'models', value: localModels });
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Segment loading={loading}>
|
<Segment loading={loading}>
|
||||||
@@ -295,6 +329,20 @@ const EditChannel = () => {
|
|||||||
</Form.Field>
|
</Form.Field>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
inputs.type === 21 && (
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Input
|
||||||
|
label='知识库 ID'
|
||||||
|
name='other'
|
||||||
|
placeholder={'请输入知识库 ID,例如:123456'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.other}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Dropdown
|
<Form.Dropdown
|
||||||
label='模型'
|
label='模型'
|
||||||
@@ -322,29 +370,19 @@ const EditChannel = () => {
|
|||||||
}}>清除所有模型</Button>
|
}}>清除所有模型</Button>
|
||||||
<Input
|
<Input
|
||||||
action={
|
action={
|
||||||
<Button type={'button'} onClick={() => {
|
<Button type={'button'} onClick={addCustomModel}>填入</Button>
|
||||||
if (customModel.trim() === '') return;
|
|
||||||
if (inputs.models.includes(customModel)) return;
|
|
||||||
let localModels = [...inputs.models];
|
|
||||||
localModels.push(customModel);
|
|
||||||
let localModelOptions = [];
|
|
||||||
localModelOptions.push({
|
|
||||||
key: customModel,
|
|
||||||
text: customModel,
|
|
||||||
value: customModel
|
|
||||||
});
|
|
||||||
setModelOptions(modelOptions => {
|
|
||||||
return [...modelOptions, ...localModelOptions];
|
|
||||||
});
|
|
||||||
setCustomModel('');
|
|
||||||
handleInputChange(null, { name: 'models', value: localModels });
|
|
||||||
}}>填入</Button>
|
|
||||||
}
|
}
|
||||||
placeholder='输入自定义模型名称'
|
placeholder='输入自定义模型名称'
|
||||||
value={customModel}
|
value={customModel}
|
||||||
onChange={(e, { value }) => {
|
onChange={(e, { value }) => {
|
||||||
setCustomModel(value);
|
setCustomModel(value);
|
||||||
}}
|
}}
|
||||||
|
onKeyDown={(e) => {
|
||||||
|
if (e.key === 'Enter') {
|
||||||
|
addCustomModel();
|
||||||
|
e.preventDefault();
|
||||||
|
}
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
@@ -375,7 +413,7 @@ const EditChannel = () => {
|
|||||||
label='密钥'
|
label='密钥'
|
||||||
name='key'
|
name='key'
|
||||||
required
|
required
|
||||||
placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
|
placeholder={type2secretPrompt(inputs.type)}
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
value={inputs.key}
|
value={inputs.key}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
@@ -393,7 +431,7 @@ const EditChannel = () => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
inputs.type !== 3 && inputs.type !== 8 && (
|
inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
label='代理'
|
label='代理'
|
||||||
@@ -406,6 +444,20 @@ const EditChannel = () => {
|
|||||||
</Form.Field>
|
</Form.Field>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
inputs.type === 22 && (
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Input
|
||||||
|
label='私有部署地址'
|
||||||
|
name='base_url'
|
||||||
|
placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.base_url}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Button onClick={handleCancel}>取消</Button>
|
<Button onClick={handleCancel}>取消</Button>
|
||||||
<Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
|
<Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
|
||||||
</Form>
|
</Form>
|
||||||
|
|||||||
@@ -1,19 +1,12 @@
|
|||||||
import React from 'react';
|
import React from 'react';
|
||||||
import { Segment, Header } from 'semantic-ui-react';
|
import { Message } from 'semantic-ui-react';
|
||||||
|
|
||||||
const NotFound = () => (
|
const NotFound = () => (
|
||||||
<>
|
<>
|
||||||
<Header
|
<Message negative>
|
||||||
block
|
<Message.Header>页面不存在</Message.Header>
|
||||||
as="h4"
|
<p>请检查你的浏览器地址是否正确</p>
|
||||||
content="404"
|
</Message>
|
||||||
attached="top"
|
|
||||||
icon="info"
|
|
||||||
className="small-icon"
|
|
||||||
/>
|
|
||||||
<Segment attached="bottom">
|
|
||||||
未找到所请求的页面
|
|
||||||
</Segment>
|
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ const EditUser = () => {
|
|||||||
label='密码'
|
label='密码'
|
||||||
name='password'
|
name='password'
|
||||||
type={'password'}
|
type={'password'}
|
||||||
placeholder={'请输入新的密码'}
|
placeholder={'请输入新的密码,最短 8 位'}
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
value={password}
|
value={password}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
|
|||||||
Reference in New Issue
Block a user