Compare commits

...

114 Commits

Author SHA1 Message Date
JustSong
6855d0dc39 chore: add x-requested-with header in CORS setting 2023-06-17 15:30:14 +08:00
Kidultx
a43b1e2add feat: support API2GPT platform (#173)
* support API2GPT platform

* chore: update balance renderer

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-06-17 15:20:51 +08:00
Miniers
46c43396d8 feat: add token name to log (#172)
* add token name to log

* chore: update expression

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-06-17 14:56:03 +08:00
JustSong
6dcffca065 docs: update README 2023-06-17 12:55:48 +08:00
JustSong
d754620ef7 docs: update README 2023-06-17 12:07:58 +08:00
JustSong
21111126a2 docs: update README 2023-06-17 11:30:31 +08:00
JustSong
d91e7dcfdc docs: update README 2023-06-17 11:24:32 +08:00
JustSong
d79289ccdd fix: fix footer not updated asap 2023-06-17 11:03:01 +08:00
Leslie Leung
f89f6c7fa6 docs: deploy to Zeabur (#170)
* docs: deploy to Zeabur

* docs: deploy to Zeabur

* docs: deploy to Zeabur

* docs: update README

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-06-17 10:19:13 +08:00
Joe
b7d71b4f0a feat: support update AIProxy balance (#171)
* Add: support update AIProxy balance

* fix auth header

* chore: update balance renderer

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-06-17 10:08:04 +08:00
JustSong
70ed126ccb feat: return a not found response if requested a wrong API endpoints 2023-06-17 09:46:07 +08:00
JustSong
57b213a035 docs: update README 2023-06-16 16:40:54 +08:00
JustSong
549e944b95 docs: update README 2023-06-16 16:40:07 +08:00
JustSong
0cdab80a6e feat: record channel's used quota (close #137) 2023-06-16 16:02:00 +08:00
JustSong
760183a970 feat: record used quota & request count (close #102, #165) 2023-06-16 15:20:06 +08:00
JustSong
58fb18aace fix: do not record completion ratio anymore 2023-06-16 14:24:16 +08:00
张城铭
630156dc0a fix: the prompt field can be array type now (close #166, #167)
* fix: the prompt field can be array type now (close #166)

* fix: fix prompt type

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-06-16 14:20:25 +08:00
JustSong
5f23f59d1c docs: update notice 2023-06-16 00:24:38 +08:00
JustSong
538a5d7a9b docs: update issue template (#164) 2023-06-15 21:51:30 +08:00
JustSong
593e1926e9 feat: able to disable quota consumption recording (close #156) 2023-06-15 16:32:16 +08:00
quzard
e87ad1f402 chore: remove -0613 suffix for Azure (#163) 2023-06-14 16:33:03 +08:00
JustSong
07cccdc8c0 docs: update issue template 2023-06-14 15:13:05 +08:00
JustSong
f71f01662c docs: update issue template 2023-06-14 15:03:51 +08:00
JustSong
54d7a1c2e8 docs: update issue template 2023-06-14 15:02:36 +08:00
JustSong
f426f31bd7 docs: update issue template 2023-06-14 14:59:24 +08:00
JustSong
2930577cd6 docs: update issue template 2023-06-14 14:51:48 +08:00
JustSong
e09512177a docs: add issue templates 2023-06-14 14:48:31 +08:00
JustSong
d6dbaff3c2 fix: fix file not committed 2023-06-14 12:52:56 +08:00
JustSong
7f9577a386 feat: now one channel can belong to multiple groups (close #153) 2023-06-14 12:14:08 +08:00
JustSong
38668e7331 chore: update gpt3.5 completion ratio 2023-06-14 09:41:06 +08:00
JustSong
323f3d263a feat: add new released models 2023-06-14 09:12:14 +08:00
JustSong
0c34ed4c61 docs: update README 2023-06-13 17:45:01 +08:00
JustSong
7c7eb6b7ec fix: now the input field can be array type now (close #149) 2023-06-12 16:11:57 +08:00
JustSong
8b2ef666ef fix: fix OpenAI-SB balance not correct 2023-06-12 09:40:49 +08:00
JustSong
955d5f8707 fix: fix group list not correct (close #147) 2023-06-12 09:11:48 +08:00
quzard
47ca449e32 feat: add support for updating balance of channel typpe OpenAI-SB (#146, close #125)
* Add support for updating channel balance in OpenAISB

* fix: handel error

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-06-11 21:04:41 +08:00
JustSong
39481eb6c0 chore: add trailing slash for API calling 2023-06-11 16:33:40 +08:00
JustSong
69153e7231 docs: update README 2023-06-11 12:37:15 +08:00
JustSong
cdef10cad8 docs: update README 2023-06-11 11:11:47 +08:00
JustSong
077853416d chore: record ratio detail in log 2023-06-11 11:11:19 +08:00
JustSong
596446dba4 feat: able to set group ratio now (close #62, close #142) 2023-06-11 11:08:16 +08:00
JustSong
9d0bec83df chore: update prompt 2023-06-11 09:55:50 +08:00
JustSong
f97a9ce597 fix: correct OpenAI error code's type 2023-06-11 09:49:57 +08:00
JustSong
4339f45f74 feat: support /v1/moderations now (close #117) 2023-06-11 09:37:36 +08:00
JustSong
e398e0756b docs: update README 2023-06-10 20:43:32 +08:00
JustSong
64db39320a feat: now able to check all user's log 2023-06-10 20:40:23 +08:00
JustSong
0b4bf30908 docs: update README 2023-06-10 16:34:14 +08:00
JustSong
d29c273073 chore: add more log types 2023-06-10 16:31:40 +08:00
JustSong
74f508e847 feat: now user can check its topup & consume history (close #78, close #95) 2023-06-10 16:04:04 +08:00
JustSong
145bb14cb2 fix: fix not using proxy when update balance 2023-06-09 18:57:27 +08:00
JustSong
8901f03864 feat: support set proxy for channel OpenAI (close #139) 2023-06-09 18:30:01 +08:00
JustSong
813bf0bd66 refactor: enable model configuration on default group (close #143) 2023-06-09 18:05:51 +08:00
JustSong
45e9fd66e7 feat: able to check topup history & consumption history (#78, #95) 2023-06-09 16:59:00 +08:00
JustSong
e0d0674f81 fix: fix redemption code's quota not updated 2023-06-08 15:19:55 +08:00
JustSong
4b6adaec0b feat: support /v1/completions (close #115) 2023-06-08 14:54:02 +08:00
JustSong
9301b3fed3 chore: update test logic 2023-06-08 14:09:39 +08:00
JustSong
c6edb78ac9 docs: update README 2023-06-08 09:44:47 +08:00
JustSong
521ede2469 fix: able to manage root user now 2023-06-08 09:28:06 +08:00
JustSong
2c53424db8 feat: able to manage group now 2023-06-08 09:26:54 +08:00
JustSong
2ad22e1425 feat: support group now (close #17, close #72, close #85, close #104, close #136)
Co-authored-by: quzard <1191890118@qq.com>
2023-06-07 23:26:00 +08:00
JustSong
502515bbbd docs: update README 2023-06-07 15:31:18 +08:00
JustSong
1e1c6a828f fix: prompt user the feat is not implemented (#125) 2023-06-03 11:09:14 +08:00
JustSong
2847a08852 feat: the format of key is now constant with that of OpenAI 2023-06-03 10:53:25 +08:00
JustSong
98f1a627f0 fix: fix balance query (close #138) 2023-06-03 10:17:52 +08:00
JustSong
333e4216d2 fix: fix balance query (close #138) 2023-06-03 10:11:59 +08:00
zaunist
7e80e2da3a fix: add a blank VERSION file (#135) 2023-06-02 14:20:40 +08:00
dependabot[bot]
139624b8a4 chore: bump github.com/gin-gonic/gin from 1.9.0 to 1.9.1 (#134)
Bumps [github.com/gin-gonic/gin](https://github.com/gin-gonic/gin) from 1.9.0 to 1.9.1.
- [Release notes](https://github.com/gin-gonic/gin/releases)
- [Changelog](https://github.com/gin-gonic/gin/blob/master/CHANGELOG.md)
- [Commits](https://github.com/gin-gonic/gin/compare/v1.9.0...v1.9.1)

---
updated-dependencies:
- dependency-name: github.com/gin-gonic/gin
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-02 13:13:36 +08:00
JustSong
2f44aaa645 chore: update config prompt (close #133) 2023-06-01 18:15:53 +08:00
quzard
0f6958c57a fix: update common.RootUserEmail when root's email changed (#132) 2023-05-31 15:51:31 +08:00
quzard
5f045f8cf5 fix: update docker-compose.yml (#131)
使用相对路径,更通用
2023-05-31 15:37:59 +08:00
quzard
f19ee05351 fix: fetch root user's email if blank (#129) 2023-05-31 15:24:40 +08:00
JustSong
fa71daa8a7 fix: fix wrong implementation for /v1/models (close #128) 2023-05-31 14:43:29 +08:00
JustSong
54215dc303 chore: make channel test related code separated 2023-05-23 10:01:09 +08:00
JustSong
f9f42997b2 chore: only check OpenAI channel & custom channel 2023-05-23 10:00:36 +08:00
JustSong
25eab0b224 style: fix UI related problems 2023-05-22 22:41:39 +08:00
JustSong
34bce5b464 style: add positive attribute to submit buttons (close #113) 2023-05-22 22:30:11 +08:00
JustSong
d4794fc051 feat: return user's quota with billing api (close #92) 2023-05-22 17:10:31 +08:00
JustSong
8b43e0dd3f fix: add no-cache for index.html 2023-05-22 00:54:53 +08:00
JustSong
92c88fa273 fix: remove no-store for index.html 2023-05-22 00:44:27 +08:00
JustSong
38191d55be fix: do not cache index.html 2023-05-22 00:39:24 +08:00
JustSong
d9e39f5906 fix: disable channel with a whitelist 2023-05-21 20:58:00 +08:00
JustSong
17b7646c12 fix: fix unable to update custom channel's balance 2023-05-21 20:10:06 +08:00
JustSong
171b818504 feat: support channel remain quota query (close #79) 2023-05-21 16:09:54 +08:00
JustSong
bcca0cc0bc feat: PaLM support is WIP (#105) 2023-05-21 14:26:59 +08:00
JustSong
b92ec5e54c fix: show bind options only available (close #65) 2023-05-21 11:22:28 +08:00
JustSong
fa79e8b7a3 fix: use gpt-3.5's encoder if not found (close #110) 2023-05-21 11:11:19 +08:00
JustSong
1cc7c20183 chore: prompt user if redemption code not input 2023-05-21 10:32:47 +08:00
JustSong
2eee97e9b6 style: add comma to quota stat 2023-05-21 10:15:30 +08:00
JustSong
a3a1b612b0 chore: set initial quota for root user 2023-05-21 10:05:34 +08:00
JustSong
61e682ca47 feat: able to manage user's quota now 2023-05-21 10:01:02 +08:00
JustSong
b383983106 docs: update README 2023-05-21 09:18:23 +08:00
JustSong
cfd587117e feat: support channel AI Proxy now 2023-05-20 17:24:56 +08:00
JustSong
ef9dca28f5 chore: set default value for Azure's api version if not set 2023-05-19 22:13:29 +08:00
JustSong
741c0b9c18 docs: update README (#103) 2023-05-19 15:58:01 +08:00
JustSong
3711f4a741 feat: support channel ai.ls now (close #99) 2023-05-19 11:07:17 +08:00
quzard
7c6bf3e97b fix: make the token number calculation more accurate (#101)
* Make token calculation more accurate.

* fix: make the token number calculation more accurate

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-05-19 09:41:26 +08:00
JustSong
481ba41fbd docs: update README 2023-05-18 18:13:57 +08:00
JustSong
2779d6629c fix: add X-Accel-Buffering header on SSE response 2023-05-18 17:16:34 +08:00
JustSong
e509899daf docs: update README (close #97) 2023-05-18 16:18:45 +08:00
JustSong
b53cdbaf05 docs: update README 2023-05-18 15:57:40 +08:00
JustSong
ced89398a5 chore: rewrite 429 prompt text (close #96) 2023-05-18 15:27:15 +08:00
JustSong
09c2e3bcec docs: fix typo 2023-05-18 12:50:47 +08:00
JustSong
5cba800fa6 docs: fix typo 2023-05-18 12:50:19 +08:00
JustSong
2d39a135f2 feat: now slave server can sync options with master server (close #88) 2023-05-18 12:48:20 +08:00
JustSong
3c6834a79c feat: support redirecting frontend url now (close #89) 2023-05-18 12:26:18 +08:00
JustSong
6da3410823 fix: fix channel test error checking 2023-05-18 11:41:03 +08:00
JustSong
ceb289cb4d fix: handel error response from server correctly (close #90) 2023-05-18 11:11:15 +08:00
JustSong
6f8cc712b0 docs: update README 2023-05-17 23:26:30 +08:00
JustSong
ad01e1f3b3 fix: fix error log not recorded (close #83) 2023-05-17 20:20:48 +08:00
JustSong
cc1ef2ffd5 fix: fix stream mode checking (#83) 2023-05-17 20:10:09 +08:00
JustSong
7201bd1c97 fix: update api2d's base url (#83) 2023-05-17 18:47:25 +08:00
JustSong
73d5e0f283 feat: support dummy sk- prefix for token (#82) 2023-05-17 17:04:06 +08:00
JustSong
efc744ca35 feat: API /models & /models/:model implemented (close #68) 2023-05-17 10:42:52 +08:00
JustSong
e8da98139f fix: limit the shown text's length (close #80) 2023-05-16 21:33:59 +08:00
63 changed files with 2643 additions and 454 deletions

23
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file
View File

@@ -0,0 +1,23 @@
---
name: 报告问题
about: 使用简练详细的语言描述你遇到的问题
title: ''
labels: bug
assignees: ''
---
**例行检查**
+ [ ] 我已确认目前没有类似 issue
+ [ ] 我已确认我已升级到最新版本
+ [ ] 我理解并愿意跟进此 issue协助测试和提供反馈
+ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,不遵循规则的 issue 可能会被无视或直接关闭
**问题描述**
**复现步骤**
**预期结果**
**相关截图**
如果没有的话,请删除此节。

11
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,11 @@
blank_issues_enabled: false
contact_links:
- name: 项目群聊
url: https://openai.justsong.cn/
about: QQ 群828520184自动审核备注 One API
- name: 赞赏支持
url: https://iamazing.cn/page/reward
about: 请作者喝杯咖啡,以激励作者持续开发
- name: 付费部署或定制功能
url: https://openai.justsong.cn/
about: 加群后联系群主

View File

@@ -0,0 +1,18 @@
---
name: 功能请求
about: 使用简练详细的语言描述希望加入的新功能
title: ''
labels: enhancement
assignees: ''
---
**例行检查**
+ [ ] 我已确认目前没有类似 issue
+ [ ] 我已确认我已升级到最新版本
+ [ ] 我理解并愿意跟进此 issue协助测试和提供反馈
+ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,不遵循规则的 issue 可能会被无视或直接关闭
**功能描述**
**应用场景**

145
README.md
View File

@@ -29,49 +29,64 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
</p>
<p align="center">
<a href="https://github.com/songquanpeng/one-api/releases">程序下载</a>
·
<a href="https://github.com/songquanpeng/one-api#部署">部署教程</a>
·
<a href="https://github.com/songquanpeng/one-api#使用方法">使用方法</a>
·
<a href="https://github.com/songquanpeng/one-api/issues">意见反馈</a>
·
<a href="https://github.com/songquanpeng/one-api#截图展示">截图展示</a>
·
<a href="https://openai.justsong.cn/">在线演示</a>
·
<a href="https://github.com/songquanpeng/one-api#常见问题">常见问题</a>
·
<a href="https://iamazing.cn/page/reward">赞赏支持</a>
</p>
> **Warning**:从 `v0.2` 版本升级到 `v0.3` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.2-v0.3.sql)
> **Note**:使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本
> **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。
## 功能
1. 支持多种 API 访问渠道,欢迎 PR 或提 issue 添加更多渠道:
+ [x] OpenAI 官方通道
+ [x] OpenAI 官方通道(支持配置代理)
+ [x] **Azure OpenAI API**
+ [x] [API2D](https://api2d.com/r/197971)
+ [x] [CloseAI](https://console.openai-asia.com)
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`
+ [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [API2GPT](http://console.api2gpt.com/m/00002S)
+ [x] [CloseAI](https://console.openai-asia.com/r/2412)
+ [x] [AI.LS](https://ai.ls)
+ [x] [OpenAI Max](https://openaimax.com)
+ [x] [OhMyGPT](https://www.ohmygpt.com)
+ [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
2. 支持通过**负载均衡**的方式访问多个渠道。
3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
4. 支持**令牌管理**,设置令牌的过期时间和使用次数
5. 支持**兑换码管理**支持批量生成和导出兑换码,可使用兑换码为令牌进行充值
6. 支持**通道管理**批量创建通道
7. 支持发布公告,设置充值链接,设置新用户初始额度
8. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入
9. 支持通过系统访问令牌访问管理 API
10. 支持用户管理,支持**多种用户登录注册方式**
4. 支持**多机部署**[详见此处](#多机部署)
5. 支持**令牌管理**设置令牌的过期时间和使用次数
6. 支持**兑换码管理**支持批量生成和导出兑换码,可使用兑换码为账户进行充值
7. 支持**通道管理**,批量创建通道
8. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
9. 支持渠道**设置模型列表**
10. 支持**查看额度明细**
11. 支持发布公告,设置充值链接,设置新用户初始额度
12. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
13. 支持通过系统访问令牌访问管理 API。
14. 支持 Cloudflare Turnstile 用户校验。
15. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
11. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
16. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
## 部署
### 基于 Docker 进行部署
执行`docker run -d --restart always -p 3000:3000 -v /home/ubuntu/data/one-api:/data justsong/one-api`
部署命令`docker run --name one-api -d --restart always -p 3000:3000 -v /home/ubuntu/data/one-api:/data justsong/one-api`
更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR`
`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
@@ -90,13 +105,10 @@ server{
proxy_set_header X-Forwarded-For $remote_addr;
proxy_cache_bypass $http_upgrade;
proxy_set_header Accept-Encoding gzip;
proxy_buffering off; # 重要:关闭代理缓冲
}
}
```
注意,为了 SSE 正常工作,需要关闭 Nginx 的代理缓冲。
之后使用 Let's Encrypt 的 certbot 配置 HTTPS
```bash
# Ubuntu 安装 certbot
@@ -109,6 +121,8 @@ sudo certbot --nginx
sudo service nginx restart
```
初始账号用户名为 `root`,密码为 `123456`
### 手动部署
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
```shell
@@ -133,6 +147,55 @@ sudo service nginx restart
更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。
### 多机部署
1. 所有服务器 `SESSION_SECRET` 设置一样的值。
2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite请自行配置主备数据库同步。
3. 所有从服务器必须设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。
4. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。
环境变量的具体使用方法详见[此处](#环境变量)。
### 部署第三方服务配合 One API 使用
> 欢迎 PR 添加更多示例。
#### ChatGPT Next Web
项目主页https://github.com/Yidadaa/ChatGPT-Next-Web
```bash
docker run --name chat-next-web -d -p 3001:3000 -e BASE_URL=https://openai.justsong.cn yidadaa/chatgpt-next-web
```
注意修改端口号和 `BASE_URL`。
#### ChatGPT Web
项目主页https://github.com/Chanzhaoyu/chatgpt-web
```bash
docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://openai.justsong.cn -e OPENAI_API_KEY=sk-xxx chenzhaoyu94/chatgpt-web
```
注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。
### 部署到第三方平台
<details>
<summary><strong>部署到 Zeabur</strong></summary>
<div>
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用。
1. 首先 fork 一份代码。
2. 进入 [Zeabur](https://zeabur.com/),登录,进入控制台。
3. 新建一个 Project在 Service -> Add Service 选择 Marketplace选择 MySQL并记下连接参数用户名、密码、地址、端口
4. 复制链接参数,运行 ```create database `one-api` ``` 创建数据库。
5. 然后在 Service -> Add Service选择 Git第一次使用需要先授权选择你 fork 的仓库。
6. Deploy 会自动开始,先取消。进入下方 Variable添加一个 `PORT`,值为 `3000`,再添加一个 `SQL_DSN`,值为 `<username>:<password>@tcp(<addr>:<port>)/one-api` ,然后保存。 注意如果不填写 `SQL_DSN`,数据将无法持久化,重新部署后数据会丢失。
7. 选择 Redeploy。
8. 进入下方 Domains选择一个合适的域名前缀如 "my-one-api",最终域名为 "my-one-api.zeabur.app",也可以 CNAME 自己的域名。
9. 等待部署完成,点击生成的域名进入 One API。
</div>
</details>
## 配置
系统本身开箱即用。
@@ -140,11 +203,24 @@ sudo service nginx restart
等到系统启动后,使用 `root` 用户登录系统并做进一步的配置。
## 使用方
在`渠道`页面中添加你的 API Key之后在`令牌`页面中新增一个访问令牌。
## 使用方
在`渠道`页面中添加你的 API Key之后在`令牌`页面中新增访问令牌。
之后就可以使用你的令牌访问 One API 了,使用方式与 [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) 一致。
你需要在各种用到 OpenAI API 的地方设置 API Base 为你的 One API 的部署地址,例如:`https://openai.justsong.cn`API Key 则为你在 One API 中生成的令牌。
注意,具体的 API Base 的格式取决于你所使用的客户端。
```mermaid
graph LR
A(用户)
A --->|请求| B(One API)
B -->|中继请求| C(OpenAI)
B -->|中继请求| D(Azure)
B -->|中继请求| E(其他下游渠道)
```
可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。
注意,需要是管理员用户创建的令牌才能指定渠道 ID。
@@ -155,8 +231,12 @@ sudo service nginx restart
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
+ 例子:`SESSION_SECRET=random_string`
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite。
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 8.0 版本
+ 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/one-api`
4. `FRONTEND_BASE_URL`:设置之后将使用指定的前端地址,而非后端地址。
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。
+ 例子:`SYNC_FREQUENCY=60`
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
@@ -174,3 +254,20 @@ https://openai.justsong.cn
### 截图展示
![channel](https://user-images.githubusercontent.com/39998050/233837954-ae6683aa-5c4f-429f-a949-6645a83c9490.png)
![token](https://user-images.githubusercontent.com/39998050/233837971-dab488b7-6d96-43af-b640-a168e8d1c9bf.png)
## 常见问题
1. 账户额度足够为什么提示额度不足?
+ 请检查你的令牌额度是否足够,这个和账户额度是分开的。
+ 令牌额度仅供用户设置最大使用量,用户可自由设置。
2. 宝塔部署后访问出现空白页面?
+ 自动配置的问题,详见[#97](https://github.com/songquanpeng/one-api/issues/97)。
3. 提示无可用渠道?
+ 请检查的用户分组和渠道分组设置。
+ 以及渠道的模型设置。
## 注意
本项目为开源项目,请在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及法律法规的情况下使用,不得用于非法用途。
本项目使用 MIT 协议进行开源,请以某种方式保留 One API 的版权信息。
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。

0
VERSION Normal file
View File

View File

@@ -0,0 +1,17 @@
INSERT INTO abilities (`group`, model, channel_id, enabled)
SELECT c.`group`, m.model, c.id, 1
FROM channels c
CROSS JOIN (
SELECT 'gpt-3.5-turbo' AS model UNION ALL
SELECT 'gpt-3.5-turbo-0301' AS model UNION ALL
SELECT 'gpt-4' AS model UNION ALL
SELECT 'gpt-4-0314' AS model
) AS m
WHERE c.status = 1
AND NOT EXISTS (
SELECT 1
FROM abilities a
WHERE a.`group` = c.`group`
AND a.model = m.model
AND a.channel_id = c.id
);

View File

@@ -1,9 +1,10 @@
package common
import (
"github.com/google/uuid"
"sync"
"time"
"github.com/google/uuid"
)
var StartTime = time.Now().Unix() // unit: second
@@ -25,6 +26,7 @@ var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
@@ -34,6 +36,8 @@ var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
@@ -127,16 +131,24 @@ const (
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://openai.api2d.net", // 2
"https://oa.api2d.net", // 2
"", // 3
"https://api.openai-asia.com", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
}

26
common/gin.go Normal file
View File

@@ -0,0 +1,26 @@
package common
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
)
func UnmarshalBodyReusable(c *gin.Context, v any) error {
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
return err
}
err = c.Request.Body.Close()
if err != nil {
return err
}
err = json.Unmarshal(requestBody, &v)
if err != nil {
return err
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}

31
common/group-ratio.go Normal file
View File

@@ -0,0 +1,31 @@
package common
import "encoding/json"
var GroupRatio = map[string]float64{
"default": 1,
"vip": 1,
"svip": 1,
}
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
if err != nil {
SysError("Error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateGroupRatioByJSONString(jsonStr string) error {
GroupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &GroupRatio)
}
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
if !ok {
SysError("Group ratio not found: " + name)
return 1
}
return ratio
}

View File

@@ -2,16 +2,23 @@ package common
import "encoding/json"
// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://openai.com/pricing
// TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens
var ModelRatio = map[string]float64{
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-3.5-turbo": 1,
"gpt-3.5-turbo-0301": 1,
"gpt-4-32k-0613": 30,
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
@@ -26,8 +33,8 @@ var ModelRatio = map[string]float64{
"ada": 10,
"text-embedding-ada-002": 0.2,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 10,
"text-moderation-latest": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
}
func ModelRatio2JSONString() string {
@@ -39,6 +46,7 @@ func ModelRatio2JSONString() string {
}
func UpdateModelRatioByJSONString(jsonStr string) error {
ModelRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &ModelRatio)
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/google/uuid"
"html/template"
"log"
"math/rand"
"net"
"os/exec"
"runtime"
@@ -133,6 +134,29 @@ func GetUUID() string {
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}

41
controller/billing.go Normal file
View File

@@ -0,0 +1,41 @@
package controller
import (
"github.com/gin-gonic/gin"
"one-api/model"
)
func GetSubscription(c *gin.Context) {
userId := c.GetInt("id")
quota, err := model.GetUserQuota(userId)
if err != nil {
openAIError := OpenAIError{
Message: err.Error(),
Type: "one_api_error",
}
c.JSON(200, gin.H{
"error": openAIError,
})
return
}
subscription := OpenAISubscriptionResponse{
Object: "billing_subscription",
HasPaymentMethod: true,
SoftLimitUSD: float64(quota),
HardLimitUSD: float64(quota),
SystemHardLimitUSD: float64(quota),
}
c.JSON(200, subscription)
return
}
func GetUsage(c *gin.Context) {
//userId := c.GetInt("id")
// TODO: get usage from database
usage := OpenAIUsageResponse{
Object: "list",
TotalUsage: 0,
}
c.JSON(200, usage)
return
}

View File

@@ -0,0 +1,280 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
"github.com/gin-gonic/gin"
)
// https://github.com/songquanpeng/one-api/issues/79
type OpenAISubscriptionResponse struct {
Object string `json:"object"`
HasPaymentMethod bool `json:"has_payment_method"`
SoftLimitUSD float64 `json:"soft_limit_usd"`
HardLimitUSD float64 `json:"hard_limit_usd"`
SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
}
type OpenAIUsageDailyCost struct {
Timestamp float64 `json:"timestamp"`
LineItems []struct {
Name string `json:"name"`
Cost float64 `json:"cost"`
}
}
type OpenAIUsageResponse struct {
Object string `json:"object"`
//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
}
type OpenAISBUsageResponse struct {
Msg string `json:"msg"`
Data *struct {
Credit string `json:"credit"`
} `json:"data"`
}
type AIProxyUserOverviewResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
ErrorCode int `json:"error_code"`
Data struct {
TotalPoints float64 `json:"totalPoints"`
} `json:"data"`
}
type API2GPTUsageResponse struct {
Object string `json:"object"`
TotalGranted float64 `json:"total_granted"`
TotalUsed float64 `json:"total_used"`
TotalRemaining float64 `json:"total_remaining"`
}
// GetAuthHeader get auth header
func GetAuthHeader(token string) http.Header {
h := http.Header{}
h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
return h
}
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
client := &http.Client{}
req, err := http.NewRequest(method, url, nil)
if err != nil {
return nil, err
}
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
res, err := client.Do(req)
if err != nil {
return nil, err
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
err = res.Body.Close()
if err != nil {
return nil, err
}
return body, nil
}
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := OpenAISBUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if response.Data == nil {
return 0, errors.New(response.Msg)
}
balance, err := strconv.ParseFloat(response.Data.Credit, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
url := "https://aiproxy.io/api/report/getUserOverview"
headers := http.Header{}
headers.Add("Api-Key", channel.Key)
body, err := GetResponseBody("GET", url, channel, headers)
if err != nil {
return 0, err
}
response := AIProxyUserOverviewResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if !response.Success {
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
}
channel.UpdateBalance(response.Data.TotalPoints)
return response.Data.TotalPoints, nil
}
func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := API2GPTUsageResponse{}
err = json.Unmarshal(body, &response)
fmt.Print(response)
if err != nil {
return 0, err
}
channel.UpdateBalance(response.TotalRemaining)
return response.TotalRemaining, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type]
switch channel.Type {
case common.ChannelTypeOpenAI:
if channel.BaseURL != "" {
baseURL = channel.BaseURL
}
case common.ChannelTypeAzure:
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
baseURL = channel.BaseURL
case common.ChannelTypeOpenAISB:
return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:
return updateChannelAPI2GPTBalance(channel)
default:
return 0, errors.New("尚未实现")
}
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
subscription := OpenAISubscriptionResponse{}
err = json.Unmarshal(body, &subscription)
if err != nil {
return 0, err
}
now := time.Now()
startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
endDate := now.Format("2006-01-02")
if !subscription.HasPaymentMethod {
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
}
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
usage := OpenAIUsageResponse{}
err = json.Unmarshal(body, &usage)
if err != nil {
return 0, err
}
balance := subscription.HardLimitUSD - usage.TotalUsage/100
channel.UpdateBalance(balance)
return balance, nil
}
func UpdateChannelBalance(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
balance, err := updateChannelBalance(channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"balance": balance,
})
return
}
func updateAllChannelsBalance() error {
channels, err := model.GetAllChannels(0, 0, true)
if err != nil {
return err
}
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
// TODO: support Azure
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
continue
}
balance, err := updateChannelBalance(channel)
if err != nil {
continue
} else {
// err is nil & balance <= 0 means quota is used up
if balance <= 0 {
disableChannel(channel.Id, channel.Name, "余额不足")
}
}
}
return nil
}
func UpdateAllChannelsBalance(c *gin.Context) {
// TODO: make it async
err := updateAllChannelsBalance()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

204
controller/channel-test.go Normal file
View File

@@ -0,0 +1,204 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"sync"
"time"
)
func testChannel(channel *model.Channel, request *ChatRequest) error {
if request.Model == "" {
request.Model = "gpt-3.5-turbo"
if channel.Type == common.ChannelTypeAzure {
request.Model = "gpt-35-turbo"
}
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
} else {
if channel.Type == common.ChannelTypeCustom {
requestURL = channel.BaseURL
} else if channel.Type == common.ChannelTypeOpenAI && channel.BaseURL != "" {
requestURL = channel.BaseURL
}
requestURL += "/v1/chat/completions"
}
jsonData, err := json.Marshal(request)
if err != nil {
return err
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err
}
if response.Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
}
return nil
}
func buildTestRequest(c *gin.Context) *ChatRequest {
model_ := c.Query("model")
testRequest := &ChatRequest{
Model: model_,
MaxTokens: 1,
}
testMessage := Message{
Role: "user",
Content: "hi",
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
testRequest := buildTestRequest(c)
tik := time.Now()
err = testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
"time": consumedTime,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"time": consumedTime,
})
return
}
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
}
func testAllChannels(c *gin.Context) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
return errors.New("测试已在运行中")
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return err
}
testRequest := buildTestRequest(c)
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
tik := time.Now()
err := testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil || milliseconds > disableThreshold {
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
}
disableChannel(channel.Id, channel.Name, err.Error())
}
channel.UpdateResponseTime(milliseconds)
}
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
}()
return nil
}
func TestAllChannels(c *gin.Context) {
err := testAllChannels(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

View File

@@ -1,18 +1,12 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"sync"
"time"
)
func GetAllChannels(c *gin.Context) {
@@ -158,187 +152,3 @@ func UpdateChannel(c *gin.Context) {
})
return
}
func testChannel(channel *model.Channel, request *ChatRequest) error {
if request.Model == "" {
request.Model = "gpt-3.5-turbo"
if channel.Type == common.ChannelTypeAzure {
request.Model = "gpt-35-turbo"
}
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
} else {
if channel.Type == common.ChannelTypeCustom {
requestURL = channel.BaseURL
}
requestURL += "/v1/chat/completions"
}
jsonData, err := json.Marshal(request)
if err != nil {
return err
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err
}
if response.Error.Type != "" {
return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
}
return nil
}
func buildTestRequest(c *gin.Context) *ChatRequest {
model_ := c.Query("model")
testRequest := &ChatRequest{
Model: model_,
MaxTokens: 1,
}
testMessage := Message{
Role: "user",
Content: "hi",
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
testRequest := buildTestRequest(c)
tik := time.Now()
err = testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
"time": consumedTime,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"time": consumedTime,
})
return
}
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
// disable & notify
func disableChannel(channelId int, channelName string, err error) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, err.Error())
err = common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
}
func testAllChannels(c *gin.Context) error {
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
return errors.New("测试已在运行中")
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return err
}
testRequest := buildTestRequest(c)
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
tik := time.Now()
err := testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil || milliseconds > disableThreshold {
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
}
disableChannel(channel.Id, channel.Name, err)
}
channel.UpdateResponseTime(milliseconds)
}
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
}()
return nil
}
func TestAllChannels(c *gin.Context) {
err := testAllChannels(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

19
controller/group.go Normal file
View File

@@ -0,0 +1,19 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": groupNames,
})
}

86
controller/log.go Normal file
View File

@@ -0,0 +1,86 @@
package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/model"
"strconv"
)
func GetAllLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
logType, _ := strconv.Atoi(c.Query("type"))
logs, err := model.GetAllLogs(logType, p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}
func GetUserLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
userId := c.GetInt("id")
logType, _ := strconv.Atoi(c.Query("type"))
logs, err := model.GetUserLogs(userId, logType, p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}
func SearchAllLogs(c *gin.Context) {
keyword := c.Query("keyword")
logs, err := model.SearchAllLogs(keyword)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}
func SearchUserLogs(c *gin.Context) {
keyword := c.Query("keyword")
userId := c.GetInt("id")
logs, err := model.SearchUserLogs(userId, keyword)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}

256
controller/model.go Normal file
View File

@@ -0,0 +1,256 @@
package controller
import (
"fmt"
"github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/models/list
type OpenAIModelPermission struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group *string `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
}
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
func init() {
var permission []OpenAIModelPermission
permission = append(permission, OpenAIModelPermission{
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
Object: "model_permission",
Created: 1626777600,
AllowCreateEngine: true,
AllowSampling: true,
AllowLogprobs: true,
AllowSearchIndices: false,
AllowView: true,
AllowFineTuning: false,
Organization: "*",
Group: nil,
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
{
Id: "gpt-3.5-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0301",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0301",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k-0613",
Parent: nil,
},
{
Id: "gpt-4",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4",
Parent: nil,
},
{
Id: "gpt-4-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0314",
Parent: nil,
},
{
Id: "gpt-4-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0613",
Parent: nil,
},
{
Id: "gpt-4-32k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k",
Parent: nil,
},
{
Id: "gpt-4-32k-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0314",
Parent: nil,
},
{
Id: "gpt-4-32k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0613",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-embedding-ada-002",
Parent: nil,
},
{
Id: "text-davinci-003",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-003",
Parent: nil,
},
{
Id: "text-davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-002",
Parent: nil,
},
{
Id: "text-curie-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-curie-001",
Parent: nil,
},
{
Id: "text-babbage-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-babbage-001",
Parent: nil,
},
{
Id: "text-ada-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-ada-001",
Parent: nil,
},
{
Id: "text-moderation-latest",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-latest",
Parent: nil,
},
{
Id: "text-moderation-stable",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-stable",
Parent: nil,
},
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
}
}
func ListModels(c *gin.Context) {
c.JSON(200, gin.H{
"object": "list",
"data": openAIModels,
})
}
func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
openAIError := OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",
Code: "model_not_found",
}
c.JSON(200, gin.H{
"error": openAIError,
})
}
}

59
controller/relay-palm.go Normal file
View File

@@ -0,0 +1,59 @@
package controller
import (
"fmt"
"github.com/gin-gonic/gin"
)
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
type PaLMChatRequest struct {
Prompt []Message `json:"prompt"`
Temperature float64 `json:"temperature"`
CandidateCount int `json:"candidateCount"`
TopP float64 `json:"topP"`
TopK int `json:"topK"`
}
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
type PaLMChatResponse struct {
Candidates []Message `json:"candidates"`
Messages []Message `json:"messages"`
Filters []PaLMFilter `json:"filters"`
}
func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode {
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages))
for _, message := range openAIRequest.Messages {
var author string
if message.Role == "user" {
author = "0"
} else {
author = "1"
}
messages = append(messages, PaLMChatMessage{
Author: author,
Content: message.Content,
})
}
request := PaLMChatRequest{
Prompt: nil,
Temperature: openAIRequest.Temperature,
CandidateCount: openAIRequest.N,
TopP: openAIRequest.TopP,
TopK: openAIRequest.MaxTokens,
}
// TODO: forward request to PaLM & convert response
fmt.Print(request)
return nil
}

79
controller/relay-utils.go Normal file
View File

@@ -0,0 +1,79 @@
package controller
import (
"fmt"
"github.com/pkoukk/tiktoken-go"
"one-api/common"
"strings"
)
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
return tokenEncoder
}
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
}
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
func countTokenMessages(messages []Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if strings.HasPrefix(model, "gpt-3.5") {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else if strings.HasPrefix(model, "gpt-4") {
tokensPerMessage = 3
tokensPerName = 1
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil))
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum
}
func countTokenInput(input any, model string) int {
switch input.(type) {
case string:
return countTokenText(input.(string), model)
case []string:
text := ""
for _, s := range input.([]string) {
text += s
}
return countTokenText(text, model)
}
return 0
}
func countTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model)
token := tokenEncoder.Encode(text, nil, nil)
return len(token)
}

View File

@@ -4,10 +4,8 @@ import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"io"
"net/http"
"one-api/common"
@@ -16,8 +14,31 @@ import (
)
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content string `json:"content"`
Name *string `json:"name,omitempty"`
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModeration
)
// https://platform.openai.com/docs/api-reference/chat
type GeneralOpenAIRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt any `json:"prompt"`
Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
N int `json:"n"`
Input any `json:"input"`
}
type ChatRequest struct {
@@ -44,7 +65,12 @@ type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code string `json:"code"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
type TextResponse struct {
@@ -52,7 +78,7 @@ type TextResponse struct {
Error OpenAIError `json:"error"`
}
type StreamResponse struct {
type ChatCompletionsStreamResponse struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
@@ -61,55 +87,78 @@ type StreamResponse struct {
} `json:"choices"`
}
var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
func countToken(text string) int {
token := tokenEncoder.Encode(text, nil, nil)
return len(token)
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
func Relay(c *gin.Context) {
err := relayHelper(c)
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModeration
}
err := relayHelper(c, relayMode)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
if err.StatusCode == http.StatusTooManyRequests {
err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
})
if common.AutomaticDisableChannelEnabled {
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key") {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err)
disableChannel(channelId, channelName, err.Message)
}
}
}
func relayHelper(c *gin.Context) error {
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
openAIError := OpenAIError{
Message: err.Error(),
Type: "one_api_error",
Code: code,
}
return &OpenAIErrorWithStatusCode{
OpenAIError: openAIError,
StatusCode: statusCode,
}
}
func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
tokenId := c.GetInt("token_id")
consumeQuota := c.GetBool("consume_quota")
var textRequest TextRequest
if consumeQuota || channelType == common.ChannelTypeAzure {
requestBody, err := io.ReadAll(c.Request.Body)
group := c.GetString("group")
var textRequest GeneralOpenAIRequest
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return err
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
err = c.Request.Body.Close()
if err != nil {
return err
}
err = json.Unmarshal(requestBody, &textRequest)
if err != nil {
return err
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
}
if relayMode == RelayModeModeration && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if channelType == common.ChannelTypeCustom {
baseURL = c.GetString("base_url")
} else if channelType == common.ChannelTypeOpenAI {
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeAzure {
@@ -128,28 +177,38 @@ func relayHelper(c *gin.Context) error {
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
} else if channelType == common.ChannelTypePaLM {
err := relayPaLM(textRequest, c)
return err
}
var promptText string
for _, message := range textRequest.Messages {
promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
var promptTokens int
switch relayMode {
case RelayModeChatCompletions:
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
case RelayModeCompletions:
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
case RelayModeModeration:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
promptTokens := countToken(promptText) + 3
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
ratio := common.GetModelRatio(textRequest.Model)
modelRatio := common.GetModelRatio(textRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
if consumeQuota {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return err
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
}
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
if err != nil {
return err
return errorWrapper(err, "new_request_failed", http.StatusOK)
}
if channelType == common.ChannelTypeAzure {
key := c.Request.Header.Get("Authorization")
@@ -164,40 +223,48 @@ func relayHelper(c *gin.Context) error {
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
return errorWrapper(err, "do_request_failed", http.StatusOK)
}
err = req.Body.Close()
if err != nil {
return err
return errorWrapper(err, "close_request_body_failed", http.StatusOK)
}
err = c.Request.Body.Close()
if err != nil {
return err
return errorWrapper(err, "close_request_body_failed", http.StatusOK)
}
var textResponse TextResponse
isStream := resp.Header.Get("Content-Type") == "text/event-stream"
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
var streamResponseText string
defer func() {
if consumeQuota {
quota := 0
usingGPT4 := strings.HasPrefix(textRequest.Model, "gpt-4")
completionRatio := 1
if usingGPT4 {
completionRatio := 1.34 // default for gpt-3
if strings.HasPrefix(textRequest.Model, "gpt-4") {
completionRatio = 2
}
if isStream {
completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
quota = promptTokens + countToken(completionText)*completionRatio
responseTokens := countTokenText(streamResponseText, textRequest.Model)
quota = promptTokens + int(float64(responseTokens)*completionRatio)
} else {
quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
quota = textResponse.Usage.PromptTokens + int(float64(textResponse.Usage.CompletionTokens)*completionRatio)
}
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("Error consuming token remain quota: " + err.Error())
}
tokenName := c.GetString("token_name")
userId := c.GetInt("id")
model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("通过令牌「%s」使用模型 %s 消耗 %d 点额度(模型倍率 %.2f,分组倍率 %.2f", tokenName, textRequest.Model, quota, modelRatio, groupRatio))
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
@@ -223,17 +290,34 @@ func relayHelper(c *gin.Context) error {
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // must be something wrong!
common.SysError("Invalid stream response: " + data)
continue
}
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
var streamResponse StreamResponse
err = json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("Error unmarshalling stream response: " + err.Error())
return
}
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Delta.Content
switch relayMode {
case RelayModeChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err = json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("Error unmarshalling stream response: " + err.Error())
return
}
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Delta.Content
}
case RelayModeCompletions:
var streamResponse CompletionsStreamResponse
err = json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("Error unmarshalling stream response: " + err.Error())
return
}
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Text
}
}
}
}
@@ -243,6 +327,7 @@ func relayHelper(c *gin.Context) error {
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -257,50 +342,72 @@ func relayHelper(c *gin.Context) error {
})
err = resp.Body.Close()
if err != nil {
return err
return errorWrapper(err, "close_response_body_failed", http.StatusOK)
}
return nil
} else {
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return err
return errorWrapper(err, "read_response_body_failed", http.StatusOK)
}
err = resp.Body.Close()
if err != nil {
return err
return errorWrapper(err, "close_response_body_failed", http.StatusOK)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return err
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
}
if textResponse.Error.Type != "" {
return errors.New(fmt.Sprintf("type %s, code %s, message %s",
textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message))
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the client will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return err
return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
}
err = resp.Body.Close()
if err != nil {
return err
return errorWrapper(err, "close_response_body_failed", http.StatusOK)
}
return nil
}
}
func RelayNotImplemented(c *gin.Context) {
err := OpenAIError{
Message: "API not implemented",
Type: "one_api_error",
Param: "",
Code: "api_not_implemented",
}
c.JSON(http.StatusOK, gin.H{
"error": gin.H{
"message": "Not Implemented",
"type": "one_api_error",
},
"error": err,
})
}
func RelayNotFound(c *gin.Context) {
err := OpenAIError{
Message: fmt.Sprintf("API not found: %s:%s", c.Request.Method, c.Request.URL.Path),
Type: "one_api_error",
Param: "",
Code: "api_not_found",
}
c.JSON(http.StatusOK, gin.H{
"error": err,
})
}

View File

@@ -119,7 +119,7 @@ func AddToken(c *gin.Context) {
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
Key: common.GetUUID(),
Key: common.GenerateKey(),
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime,

View File

@@ -2,6 +2,7 @@ package controller
import (
"encoding/json"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
@@ -228,7 +229,7 @@ func GetUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
if myRole <= user.Role {
if myRole <= user.Role && myRole != common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权获取同级或更高等级用户的信息",
@@ -326,14 +327,14 @@ func UpdateUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
if myRole <= originUser.Role {
if myRole <= originUser.Role && myRole != common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息",
})
return
}
if myRole <= updatedUser.Role {
if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
@@ -351,6 +352,9 @@ func UpdateUser(c *gin.Context) {
})
return
}
if originUser.Quota != updatedUser.Quota {
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %d 点修改为 %d 点", originUser.Quota, updatedUser.Quota))
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -467,6 +471,13 @@ func CreateUser(c *gin.Context) {
})
return
}
if err := common.Validate.Struct(&user); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "输入不合法 " + err.Error(),
})
return
}
if user.DisplayName == "" {
user.DisplayName = user.Username
}
@@ -648,6 +659,9 @@ func EmailBind(c *gin.Context) {
})
return
}
if user.Role == common.RoleRootUser {
common.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",

View File

@@ -9,8 +9,8 @@ services:
ports:
- "3000:3000"
volumes:
- /home/ubuntu/data/one-api:/data
- /home/ubuntu/data/one-api/logs:/app/logs
- ./data:/data
- ./logs:/app/logs
# environment:
# REDIS_CONN_STRING: redis://default:redispw@localhost:49153
# SESSION_SECRET: random_string

19
go.mod
View File

@@ -8,12 +8,12 @@ require (
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.0
github.com/go-playground/validator/v10 v10.12.0
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.14.0
github.com/go-redis/redis/v8 v8.11.5
github.com/google/uuid v1.3.0
github.com/pkoukk/tiktoken-go v0.1.1
golang.org/x/crypto v0.8.0
golang.org/x/crypto v0.9.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/sqlite v1.4.3
gorm.io/gorm v1.24.0
@@ -21,11 +21,12 @@ require (
require (
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect
github.com/bytedance/sonic v1.8.8 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
@@ -39,17 +40,17 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.3 // indirect
github.com/mattn/go-isatty v0.0.18 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.7 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.9.0 // indirect
golang.org/x/sys v0.7.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

41
go.sum
View File

@@ -1,8 +1,8 @@
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff h1:RmdPFa+slIr4SCBg4st/l/vZWVe9QJKMXGO60Bxbe04=
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.8.8 h1:Kj4AYbZSeENfyXicsYppYKO0K2YWab+i2UTSY7Ukz9Q=
github.com/bytedance/sonic v1.8.8/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
@@ -17,6 +17,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
@@ -29,8 +31,8 @@ github.com/gin-contrib/static v0.0.1 h1:JVxuvHPuUfkoul12N7dtQw7KRn/pSMq7Ue1Va9Sw
github.com/gin-contrib/static v0.0.1/go.mod h1:CSxeF+wep05e0kCOsqWdAWbSszmc31zTIbD8TvWl7Hs=
github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
github.com/gin-gonic/gin v1.9.0 h1:OjyFBKICoexlu99ctXNR2gg+c5pKrKMuyjgARg9qeY8=
github.com/gin-gonic/gin v1.9.0/go.mod h1:W1Me9+hsUSyj3CePGrd1/QrKJMSJ1Tu/0hFEH89961k=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
@@ -43,8 +45,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-playground/validator/v10 v10.12.0 h1:E4gtWgxWxp8YSxExrQFv5BpCahla0PVF2oTTEYaWQGI=
github.com/go-playground/validator/v10 v10.12.0/go.mod h1:hCAPuzYvKdP33pxWa+2+6AIKXEKqjIUyqsNCtbsSJrA=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
@@ -89,12 +91,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.2.3 h1:6BE2vPT0lqoz3fmOesHZiaiFh7889ssCo2GMvLCfiuA=
github.com/leodido/go-urn v1.2.3/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
@@ -108,8 +110,8 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.7 h1:muncTPStnKRos5dpVKULv2FVd4bMOhNePj9CjgDb8Us=
github.com/pelletier/go-toml/v2 v2.0.7/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo=
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
@@ -128,8 +130,9 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
@@ -142,11 +145,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -154,8 +157,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=

View File

@@ -47,6 +47,13 @@ func main() {
// Initialize options
model.InitOptionMap()
if os.Getenv("SYNC_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
if err != nil {
common.FatalLog(err)
}
go model.SyncOptions(frequency)
}
// Initialize HTTP server
server := gin.Default()

View File

@@ -85,6 +85,8 @@ func RootAuth() func(c *gin.Context) {
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
key := c.Request.Header.Get("Authorization")
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(key, "sk-")
parts := strings.Split(key, "-")
key = parts[0]
token, err := model.ValidateUserToken(key)
@@ -110,6 +112,7 @@ func TokenAuth() func(c *gin.Context) {
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {

View File

@@ -6,7 +6,11 @@ import (
func Cache() func(c *gin.Context) {
return func(c *gin.Context) {
c.Header("Cache-Control", "max-age=604800") // one week
if c.Request.RequestURI == "/" {
c.Header("Cache-Control", "no-cache")
} else {
c.Header("Cache-Control", "max-age=604800") // one week
}
c.Next()
}
}

View File

@@ -10,6 +10,6 @@ func CORS() gin.HandlerFunc {
config.AllowAllOrigins = true
config.AllowCredentials = true
config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
config.AllowHeaders = []string{"Origin", "Content-Length", "Content-Type", "Authorization", "Accept", "Connection"}
config.AllowHeaders = []string{"Origin", "Content-Length", "Content-Type", "Authorization", "Accept", "Connection", "x-requested-with"}
return cors.New(config)
}

View File

@@ -7,10 +7,18 @@ import (
"one-api/common"
"one-api/model"
"strconv"
"strings"
)
type ModelRequest struct {
Model string `json:"model"`
}
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
userId := c.GetInt("id")
userGroup, _ := model.GetUserGroup(userId)
c.Set("group", userGroup)
var channel *model.Channel
channelId, ok := c.Get("channelId")
if ok {
@@ -48,8 +56,24 @@ func Distribute() func(c *gin.Context) {
}
} else {
// Select a channel for the user
var err error
channel, err = model.GetRandomChannel()
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
c.JSON(200, gin.H{
"error": gin.H{
"message": "无效的请求",
"type": "one_api_error",
},
})
c.Abort()
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
c.JSON(200, gin.H{
"error": gin.H{
@@ -65,11 +89,9 @@ func Distribute() func(c *gin.Context) {
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure {
c.Set("base_url", channel.BaseURL)
if channel.Type == common.ChannelTypeAzure {
c.Set("api_version", channel.Other)
}
c.Set("base_url", channel.BaseURL)
if channel.Type == common.ChannelTypeAzure {
c.Set("api_version", channel.Other)
}
c.Next()
}

72
model/ability.go Normal file
View File

@@ -0,0 +1,72 @@
package model
import (
"one-api/common"
"strings"
)
type Ability struct {
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
} else {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
}
if err != nil {
return nil, err
}
channel := Channel{}
err = DB.First(&channel, "id = ?", ability.ChannelId).Error
return &channel, err
}
func (channel *Channel) AddAbilities() error {
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
ability := Ability{
Group: group,
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
}
abilities = append(abilities, ability)
}
}
return DB.Create(&abilities).Error
}
func (channel *Channel) DeleteAbilities() error {
return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
}
// UpdateAbilities updates abilities of this channel.
// Make sure the channel is completed before calling this function.
func (channel *Channel) UpdateAbilities() error {
// A quick and dirty way to update abilities
// First delete all abilities of this channel
err := channel.DeleteAbilities()
if err != nil {
return err
}
// Then add new abilities
err = channel.AddAbilities()
if err != nil {
return err
}
return nil
}
func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}

View File

@@ -1,22 +1,27 @@
package model
import (
_ "gorm.io/driver/sqlite"
"gorm.io/gorm"
"one-api/common"
)
type Channel struct {
Id int `json:"id"`
Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Weight int `json:"weight"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL string `json:"base_url" gorm:"column:base_url"`
Other string `json:"other"`
Id int `json:"id"`
Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Weight int `json:"weight"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL string `json:"base_url" gorm:"column:base_url"`
Other string `json:"other"`
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -47,13 +52,12 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
}
func GetRandomChannel() (*Channel, error) {
// TODO: consider weight
channel := Channel{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
}
return &channel, err
}
@@ -61,18 +65,36 @@ func GetRandomChannel() (*Channel, error) {
func BatchInsertChannels(channels []Channel) error {
var err error
err = DB.Create(&channels).Error
return err
if err != nil {
return err
}
for _, channel_ := range channels {
err = channel_.AddAbilities()
if err != nil {
return err
}
}
return nil
}
func (channel *Channel) Insert() error {
var err error
err = DB.Create(channel).Error
if err != nil {
return err
}
err = channel.AddAbilities()
return err
}
func (channel *Channel) Update() error {
var err error
err = DB.Model(channel).Updates(channel).Error
if err != nil {
return err
}
DB.Model(channel).First(channel, "id = ?", channel.Id)
err = channel.UpdateAbilities()
return err
}
@@ -86,15 +108,40 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
}
}
func (channel *Channel) UpdateBalance(balance float64) {
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
BalanceUpdatedTime: common.GetTimestamp(),
Balance: balance,
}).Error
if err != nil {
common.SysError("failed to update balance: " + err.Error())
}
}
func (channel *Channel) Delete() error {
var err error
err = DB.Delete(channel).Error
if err != nil {
return err
}
err = channel.DeleteAbilities()
return err
}
func UpdateChannelStatusById(id int, status int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil {
common.SysError("failed to update ability status: " + err.Error())
}
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
}
}
func UpdateChannelUsedQuota(id int, quota int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
common.SysError("failed to update channel used quota: " + err.Error())
}
}

70
model/log.go Normal file
View File

@@ -0,0 +1,70 @@
package model
import (
"gorm.io/gorm"
"one-api/common"
)
type Log struct {
Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"`
CreatedAt int64 `json:"created_at" gorm:"bigint"`
Type int `json:"type" gorm:"index"`
Content string `json:"content"`
}
const (
LogTypeUnknown = iota
LogTypeTopup
LogTypeConsume
LogTypeManage
LogTypeSystem
)
func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled {
return
}
log := &Log{
UserId: userId,
CreatedAt: common.GetTimestamp(),
Type: logType,
Content: content,
}
err := DB.Create(log).Error
if err != nil {
common.SysError("failed to record log: " + err.Error())
}
}
func GetAllLogs(logType int, startIdx int, num int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB
} else {
tx = DB.Where("type = ?", logType)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err
}
func GetUserLogs(userId int, logType int, startIdx int, num int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB.Where("user_id = ?", userId)
} else {
tx = DB.Where("user_id = ? and type = ?", userId, logType)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
return logs, err
}
func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
return logs, err
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err
}

View File

@@ -26,6 +26,7 @@ func createRootAccountIfNeed() error {
Status: common.UserStatusEnabled,
DisplayName: "Root User",
AccessToken: common.GetUUID(),
Quota: 100000000,
}
DB.Create(&rootUser)
}
@@ -74,6 +75,14 @@ func InitDB() (err error) {
if err != nil {
return err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return err
}
err = createRootAccountIfNeed()
return err
} else {

View File

@@ -4,6 +4,7 @@ import (
"one-api/common"
"strconv"
"strings"
"time"
)
type Option struct {
@@ -33,6 +34,7 @@ func InitOptionMap() {
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
@@ -57,8 +59,13 @@ func InitOptionMap() {
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
func loadOptionsFromDatabase() {
options, _ := AllOption()
for _, option := range options {
err := updateOptionMap(option.Key, option.Value)
@@ -68,6 +75,14 @@ func InitOptionMap() {
}
}
func SyncOptions(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("Syncing options from database")
loadOptionsFromDatabase()
}
}
func UpdateOption(key string, value string) error {
// Save to database first
option := Option{
@@ -120,6 +135,8 @@ func updateOptionMap(key string, value string) (err error) {
common.RegisterEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
case "LogConsumeEnabled":
common.LogConsumeEnabled = boolValue
}
}
switch key {
@@ -164,6 +181,8 @@ func updateOptionMap(key string, value string) (err error) {
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
case "ChannelDisableThreshold":

View File

@@ -2,7 +2,7 @@ package model
import (
"errors"
_ "gorm.io/driver/sqlite"
"fmt"
"one-api/common"
)
@@ -66,6 +66,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
common.SysError("更新兑换码状态失败:" + err.Error())
}
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %d 点额度", redemption.Quota))
}()
return redemption.Quota, nil
}
@@ -84,7 +85,7 @@ func (redemption *Redemption) SelectUpdate() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (redemption *Redemption) Update() error {
var err error
err = DB.Model(redemption).Select("name", "status", "redeemed_time").Updates(redemption).Error
err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error
return err
}

View File

@@ -3,16 +3,14 @@ package model
import (
"errors"
"fmt"
_ "gorm.io/driver/sqlite"
"gorm.io/gorm"
"one-api/common"
"strings"
)
type Token struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
@@ -38,7 +36,6 @@ func ValidateUserToken(key string) (token *Token, err error) {
if key == "" {
return nil, errors.New("未提供 token")
}
key = strings.Replace(key, "Bearer ", "", 1)
token = &Token{}
err = DB.Where("`key` = ?", key).First(token).Error
if err == nil {

View File

@@ -2,6 +2,7 @@ package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
"strings"
@@ -19,10 +20,12 @@ type User struct {
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
Balance int `json:"balance" gorm:"type:int;default:0"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
}
func GetMaxUserId() int {
@@ -73,8 +76,14 @@ func (user *User) Insert() error {
}
user.Quota = common.QuotaForNewUser
user.AccessToken = common.GetUUID()
err = DB.Create(user).Error
return err
result := DB.Create(user)
if result.Error != nil {
return result.Error
}
if common.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %d 点额度", common.QuotaForNewUser))
}
return nil
}
func (user *User) Update(updatePassword bool) error {
@@ -230,6 +239,11 @@ func GetUserEmail(id int) (email string, err error) {
return email, err
}
func GetUserGroup(id int) (group string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
return group, err
}
func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
@@ -250,3 +264,15 @@ func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
return email
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
"request_count": gorm.Expr("request_count + ?", 1),
},
).Error
if err != nil {
common.SysError("Failed to update user used quota and request count: " + err.Error())
}
}

View File

@@ -63,9 +63,12 @@ func SetApiRouter(router *gin.Engine) {
{
channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels)
channelRoute.GET("/models", controller.ListModels)
channelRoute.GET("/:id", controller.GetChannel)
channelRoute.GET("/test", controller.TestAllChannels)
channelRoute.GET("/test/:id", controller.TestChannel)
channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance)
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/:id", controller.DeleteChannel)
@@ -90,5 +93,15 @@ func SetApiRouter(router *gin.Engine) {
redemptionRoute.PUT("/", controller.UpdateRedemption)
redemptionRoute.DELETE("/:id", controller.DeleteRedemption)
}
logRoute := apiRouter.Group("/log")
logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
groupRoute := apiRouter.Group("/group")
groupRoute.Use(middleware.AdminAuth())
{
groupRoute.GET("/", controller.GetGroups)
}
}
}

View File

@@ -8,11 +8,14 @@ import (
)
func SetDashboardRouter(router *gin.Engine) {
apiRouter := router.Group("/dashboard")
apiRouter := router.Group("/")
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
apiRouter.Use(middleware.GlobalAPIRateLimit())
apiRouter.Use(middleware.TokenAuth())
{
apiRouter.GET("/billing/credit_grants", controller.GetTokenStatus)
apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription)
apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription)
apiRouter.GET("/dashboard/billing/usage", controller.GetUsage)
apiRouter.GET("/v1/dashboard/billing/usage", controller.GetUsage)
}
}

View File

@@ -2,12 +2,24 @@ package router
import (
"embed"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"os"
"strings"
)
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
SetApiRouter(router)
SetDashboardRouter(router)
SetRelayRouter(router)
setWebRouter(router, buildFS, indexPage)
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if frontendBaseUrl == "" {
SetWebRouter(router, buildFS, indexPage)
} else {
frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/")
router.NoRoute(func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI))
})
}
}

View File

@@ -8,12 +8,16 @@ import (
func SetRelayRouter(router *gin.Engine) {
// https://platform.openai.com/docs/api-reference/introduction
modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth())
{
modelsRouter.GET("/", controller.ListModels)
modelsRouter.GET("/:model", controller.RetrieveModel)
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{
relayV1Router.GET("/models", controller.Relay)
relayV1Router.GET("/models/:model", controller.Relay)
relayV1Router.POST("/completions", controller.RelayNotImplemented)
relayV1Router.POST("/completions", controller.Relay)
relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/generations", controller.RelayNotImplemented)
@@ -33,6 +37,6 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.Relay)
}
}

View File

@@ -7,15 +7,22 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/controller"
"one-api/middleware"
"strings"
)
func setWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
router.Use(gzip.Gzip(gzip.DefaultCompression))
router.Use(middleware.GlobalWebRateLimit())
router.Use(middleware.Cache())
router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build")))
router.NoRoute(func(c *gin.Context) {
if strings.HasPrefix(c.Request.RequestURI, "/v1") {
controller.RelayNotFound(c)
return
}
c.Header("Cache-Control", "no-cache")
c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage)
})
}

View File

@@ -22,6 +22,7 @@ import EditChannel from './pages/Channel/EditChannel';
import Redemption from './pages/Redemption';
import EditRedemption from './pages/Redemption/EditRedemption';
import TopUp from './pages/TopUp';
import Log from './pages/Log';
const Home = lazy(() => import('./pages/Home'));
const About = lazy(() => import('./pages/About'));
@@ -250,6 +251,14 @@ function App() {
</PrivateRoute>
}
/>
<Route
path='/log'
element={
<PrivateRoute>
<Log />
</PrivateRoute>
}
/>
<Route
path='/about'
element={

View File

@@ -4,6 +4,7 @@ import { Link } from 'react-router-dom';
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber } from '../helpers/render';
function renderTimestamp(timestamp) {
return (
@@ -26,12 +27,29 @@ function renderType(type) {
return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
}
function renderBalance(type, balance) {
switch (type) {
case 1: // OpenAI
case 8: // 自定义
return <span>${balance.toFixed(2)}</span>;
case 5: // OpenAI-SB
return <span>¥{(balance / 10000).toFixed(2)}</span>;
case 10: // AI Proxy
return <span>{renderNumber(balance)}</span>;
case 12: // API2GPT
return <span>¥{balance.toFixed(2)}</span>;
default:
return <span>不支持</span>;
}
}
const ChannelsTable = () => {
const [channels, setChannels] = useState([]);
const [loading, setLoading] = useState(true);
const [activePage, setActivePage] = useState(1);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [updatingBalance, setUpdatingBalance] = useState(false);
const loadChannels = async (startIdx) => {
const res = await API.get(`/api/channel/?p=${startIdx}`);
@@ -63,7 +81,7 @@ const ChannelsTable = () => {
const refresh = async () => {
setLoading(true);
await loadChannels(0);
}
};
useEffect(() => {
loadChannels(0)
@@ -127,7 +145,7 @@ const ChannelsTable = () => {
const renderResponseTime = (responseTime) => {
let time = responseTime / 1000;
time = time.toFixed(2) + "";
time = time.toFixed(2) + '';
if (responseTime === 0) {
return <Label basic color='grey'>未测试</Label>;
} else if (responseTime <= 1000) {
@@ -179,11 +197,38 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/test`);
const { success, message } = res.data;
if (success) {
showInfo("已成功开始测试所有已启用通道,请刷新页面查看结果。");
showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。');
} else {
showError(message);
}
}
};
const updateChannelBalance = async (id, name, idx) => {
const res = await API.get(`/api/channel/update_balance/${id}/`);
const { success, message, balance } = res.data;
if (success) {
let newChannels = [...channels];
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].balance = balance;
newChannels[realIdx].balance_updated_time = Date.now() / 1000;
setChannels(newChannels);
showInfo(`通道 ${name} 余额更新成功!`);
} else {
showError(message);
}
};
const updateAllChannelsBalance = async () => {
setUpdatingBalance(true);
const res = await API.get(`/api/channel/update_balance`);
const { success, message } = res.data;
if (success) {
showInfo('已更新完毕所有已启用通道余额!');
} else {
showError(message);
}
setUpdatingBalance(false);
};
const handleKeywordChange = async (e, { value }) => {
setSearchKeyword(value.trim());
@@ -236,6 +281,14 @@ const ChannelsTable = () => {
>
名称
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortChannel('group');
}}
>
分组
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
@@ -263,10 +316,10 @@ const ChannelsTable = () => {
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortChannel('test_time');
sortChannel('balance');
}}
>
测试时间
余额
</Table.HeaderCell>
<Table.HeaderCell>操作</Table.HeaderCell>
</Table.Row>
@@ -284,10 +337,25 @@ const ChannelsTable = () => {
<Table.Row key={channel.id}>
<Table.Cell>{channel.id}</Table.Cell>
<Table.Cell>{channel.name ? channel.name : '无'}</Table.Cell>
<Table.Cell>{renderGroup(channel.group)}</Table.Cell>
<Table.Cell>{renderType(channel.type)}</Table.Cell>
<Table.Cell>{renderStatus(channel.status)}</Table.Cell>
<Table.Cell>{renderResponseTime(channel.response_time)}</Table.Cell>
<Table.Cell>{channel.test_time ? renderTimestamp(channel.test_time) : "未测试"}</Table.Cell>
<Table.Cell>
<Popup
content={channel.test_time ? renderTimestamp(channel.test_time) : '未测试'}
key={channel.id}
trigger={renderResponseTime(channel.response_time)}
basic
/>
</Table.Cell>
<Table.Cell>
<Popup
content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
key={channel.id}
trigger={renderBalance(channel.type, channel.balance)}
basic
/>
</Table.Cell>
<Table.Cell>
<div>
<Button
@@ -299,6 +367,16 @@ const ChannelsTable = () => {
>
测试
</Button>
<Button
size={'small'}
positive
loading={updatingBalance}
onClick={() => {
updateChannelBalance(channel.id, channel.name, idx);
}}
>
更新余额
</Button>
<Popup
trigger={
<Button size='small' negative>
@@ -346,13 +424,15 @@ const ChannelsTable = () => {
<Table.Footer>
<Table.Row>
<Table.HeaderCell colSpan='7'>
<Table.HeaderCell colSpan='8'>
<Button size='small' as={Link} to='/channel/add' loading={loading}>
添加新的渠道
</Button>
<Button size='small' loading={loading} onClick={testAllChannels}>
测试所有已启用通道
</Button>
<Button size='small' onClick={updateAllChannelsBalance}
loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
<Pagination
floated='right'
activePage={activePage}

View File

@@ -1,11 +1,31 @@
import React from 'react';
import React, { useEffect, useState } from 'react';
import { Container, Segment } from 'semantic-ui-react';
import { getFooterHTML, getSystemName } from '../helpers';
const Footer = () => {
const systemName = getSystemName();
const footer = getFooterHTML();
const [footer, setFooter] = useState(getFooterHTML());
let remainCheckTimes = 5;
const loadFooter = () => {
let footer_html = localStorage.getItem('footer_html');
if (footer_html) {
setFooter(footer_html);
}
};
useEffect(() => {
const timer = setInterval(() => {
if (remainCheckTimes <= 0) {
clearInterval(timer);
return;
}
remainCheckTimes--;
loadFooter();
}, 200);
return () => clearTimeout(timer);
}, []);
return (
<Segment vertical>

View File

@@ -41,6 +41,11 @@ const headerButtons = [
icon: 'user',
admin: true,
},
{
name: '日志',
to: '/log',
icon: 'book',
},
{
name: '设置',
to: '/setting',

View File

@@ -0,0 +1,256 @@
import React, { useEffect, useState } from 'react';
import { Button, Label, Pagination, Select, Table } from 'semantic-ui-react';
import { API, isAdmin, showError, timestamp2string } from '../helpers';
import { ITEMS_PER_PAGE } from '../constants';
function renderTimestamp(timestamp) {
return (
<>
{timestamp2string(timestamp)}
</>
);
}
const MODE_OPTIONS = [
{ key: 'all', text: '全部用户', value: 'all' },
{ key: 'self', text: '当前用户', value: 'self' },
];
const LOG_OPTIONS = [
{ key: '0', text: '全部', value: 0 },
{ key: '1', text: '充值', value: 1 },
{ key: '2', text: '消费', value: 2 },
{ key: '3', text: '管理', value: 3 },
{ key: '4', text: '系统', value: 4 }
];
function renderType(type) {
switch (type) {
case 1:
return <Label basic color='green'> 充值 </Label>;
case 2:
return <Label basic color='olive'> 消费 </Label>;
case 3:
return <Label basic color='orange'> 管理 </Label>;
case 4:
return <Label basic color='purple'> 系统 </Label>;
default:
return <Label basic color='black'> 未知 </Label>;
}
}
const LogsTable = () => {
const [logs, setLogs] = useState([]);
const [loading, setLoading] = useState(true);
const [activePage, setActivePage] = useState(1);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [logType, setLogType] = useState(0);
const [mode, setMode] = useState('self'); // all, self
const showModePanel = isAdmin();
const loadLogs = async (startIdx) => {
let url = `/api/log/self/?p=${startIdx}&type=${logType}`;
if (mode === 'all') {
url = `/api/log/?p=${startIdx}&type=${logType}`;
}
const res = await API.get(url);
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
setLogs(data);
} else {
let newLogs = logs;
newLogs.push(...data);
setLogs(newLogs);
}
} else {
showError(message);
}
setLoading(false);
};
const onPaginationChange = (e, { activePage }) => {
(async () => {
if (activePage === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
await loadLogs(activePage - 1);
}
setActivePage(activePage);
})();
};
const refresh = async () => {
setLoading(true);
await loadLogs(0);
};
useEffect(() => {
loadLogs(0)
.then()
.catch((reason) => {
showError(reason);
});
}, []);
useEffect(() => {
refresh().then();
}, [mode, logType]);
const searchLogs = async () => {
if (searchKeyword === '') {
// if keyword is blank, load files instead.
await loadLogs(0);
setActivePage(1);
return;
}
setSearching(true);
const res = await API.get(`/api/log/self/search?keyword=${searchKeyword}`);
const { success, message, data } = res.data;
if (success) {
setLogs(data);
setActivePage(1);
} else {
showError(message);
}
setSearching(false);
};
const handleKeywordChange = async (e, { value }) => {
setSearchKeyword(value.trim());
};
const sortLog = (key) => {
if (logs.length === 0) return;
setLoading(true);
let sortedLogs = [...logs];
sortedLogs.sort((a, b) => {
return ('' + a[key]).localeCompare(b[key]);
});
if (sortedLogs[0].id === logs[0].id) {
sortedLogs.reverse();
}
setLogs(sortedLogs);
setLoading(false);
};
return (
<>
<Table basic>
<Table.Header>
<Table.Row>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('created_time');
}}
width={3}
>
时间
</Table.HeaderCell>
{
showModePanel && (
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('user_id');
}}
width={1}
>
用户
</Table.HeaderCell>
)
}
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('type');
}}
width={2}
>
类型
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('content');
}}
width={showModePanel ? 10 : 11}
>
详情
</Table.HeaderCell>
</Table.Row>
</Table.Header>
<Table.Body>
{logs
.slice(
(activePage - 1) * ITEMS_PER_PAGE,
activePage * ITEMS_PER_PAGE
)
.map((log, idx) => {
if (log.deleted) return <></>;
return (
<Table.Row key={log.created_at}>
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
{
showModePanel && (
<Table.Cell><Label>{log.user_id}</Label></Table.Cell>
)
}
<Table.Cell>{renderType(log.type)}</Table.Cell>
<Table.Cell>{log.content}</Table.Cell>
</Table.Row>
);
})}
</Table.Body>
<Table.Footer>
<Table.Row>
<Table.HeaderCell colSpan={showModePanel ? '5' : '4'}>
{
showModePanel && (
<Select
placeholder='选择模式'
options={MODE_OPTIONS}
style={{ marginRight: '8px' }}
name='mode'
value={mode}
onChange={(e, { name, value }) => {
setMode(value);
}}
/>
)
}
<Select
placeholder='选择明细分类'
options={LOG_OPTIONS}
style={{ marginRight: '8px' }}
name='logType'
value={logType}
onChange={(e, { name, value }) => {
setLogType(value);
}}
/>
<Button size='small' onClick={refresh} loading={loading}>刷新</Button>
<Pagination
floated='right'
activePage={activePage}
onPageChange={onPaginationChange}
size='small'
siblingRange={1}
totalPages={
Math.ceil(logs.length / ITEMS_PER_PAGE) +
(logs.length % ITEMS_PER_PAGE === 0 ? 1 : 0)
}
/>
</Table.HeaderCell>
</Table.Row>
</Table.Footer>
</Table>
</>
);
};
export default LogsTable;

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState } from 'react';
import { Button, Divider, Form, Grid, Header, Modal } from 'semantic-ui-react';
import { Button, Divider, Form, Grid, Header, Message, Modal } from 'semantic-ui-react';
import { API, showError, showSuccess } from '../helpers';
import { marked } from 'marked';
@@ -10,13 +10,13 @@ const OtherSetting = () => {
About: '',
SystemName: '',
Logo: '',
HomePageContent: '',
HomePageContent: ''
});
let [loading, setLoading] = useState(false);
const [showUpdateModal, setShowUpdateModal] = useState(false);
const [updateData, setUpdateData] = useState({
tag_name: '',
content: '',
content: ''
});
const getOptions = async () => {
@@ -43,7 +43,7 @@ const OtherSetting = () => {
setLoading(true);
const res = await API.put('/api/option/', {
key,
value,
value
});
const { success, message } = res.data;
if (success) {
@@ -97,7 +97,7 @@ const OtherSetting = () => {
} else {
setUpdateData({
tag_name: tag_name,
content: marked.parse(body),
content: marked.parse(body)
});
setShowUpdateModal(true);
}
@@ -153,7 +153,7 @@ const OtherSetting = () => {
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
/>
</Form.Group>
<Form.Button onClick={()=>submitOption('HomePageContent')}>保存首页内容</Form.Button>
<Form.Button onClick={() => submitOption('HomePageContent')}>保存首页内容</Form.Button>
<Form.Group widths='equal'>
<Form.TextArea
label='关于'
@@ -165,6 +165,7 @@ const OtherSetting = () => {
/>
</Form.Group>
<Form.Button onClick={submitAbout}>保存关于</Form.Button>
<Message>移除 One API 的版权标识必须首先获得授权后续版本将通过授权码强制执行</Message>
<Form.Group widths='equal'>
<Form.Input
label='页脚'

View File

@@ -112,13 +112,17 @@ const PersonalSetting = () => {
<Button onClick={generateAccessToken}>生成系统访问令牌</Button>
<Divider />
<Header as='h3'>账号绑定</Header>
<Button
onClick={() => {
setShowWeChatBindModal(true);
}}
>
绑定微信账号
</Button>
{
status.wechat_login && (
<Button
onClick={() => {
setShowWeChatBindModal(true);
}}
>
绑定微信账号
</Button>
)
}
<Modal
onClose={() => setShowWeChatBindModal(false)}
onOpen={() => setShowWeChatBindModal(true)}
@@ -148,7 +152,11 @@ const PersonalSetting = () => {
</Modal.Description>
</Modal.Content>
</Modal>
<Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
{
status.github_oauth && (
<Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
)
}
<Button
onClick={() => {
setShowEmailBindModal(true);

View File

@@ -30,9 +30,11 @@ const SystemSetting = () => {
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
ModelRatio: '',
GroupRatio: '',
TopUpLink: '',
AutomaticDisableChannelEnabled: '',
ChannelDisableThreshold: 0,
LogConsumeEnabled: '',
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
@@ -67,6 +69,7 @@ const SystemSetting = () => {
case 'TurnstileCheckEnabled':
case 'RegisterEnabled':
case 'AutomaticDisableChannelEnabled':
case 'LogConsumeEnabled':
value = inputs[key] === 'true' ? 'false' : 'true';
break;
default:
@@ -101,6 +104,7 @@ const SystemSetting = () => {
name === 'QuotaRemindThreshold' ||
name === 'PreConsumedQuota' ||
name === 'ModelRatio' ||
name === 'GroupRatio' ||
name === 'TopUpLink'
) {
setInputs((inputs) => ({ ...inputs, [name]: value }));
@@ -131,6 +135,13 @@ const SystemSetting = () => {
}
await updateOption('ModelRatio', inputs.ModelRatio);
}
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
if (!verifyJSON(inputs.GroupRatio)) {
showError('分组倍率不是合法的 JSON 字符串');
return;
}
await updateOption('GroupRatio', inputs.GroupRatio);
}
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
await updateOption('TopUpLink', inputs.TopUpLink);
}
@@ -329,6 +340,23 @@ const SystemSetting = () => {
placeholder='为一个 JSON 文本,键为模型名称,值为倍率'
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
label='分组倍率'
name='GroupRatio'
onChange={handleInputChange}
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
value={inputs.GroupRatio}
placeholder='为一个 JSON 文本,键为分组名称,值为倍率'
/>
</Form.Group>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
label='启用额度消费日志记录'
name='LogConsumeEnabled'
onChange={handleInputChange}
/>
<Form.Button onClick={submitOperationConfig}>保存运营设置</Form.Button>
<Divider />
<Header as='h3'>

View File

@@ -238,11 +238,12 @@ const TokensTable = () => {
size={'small'}
positive
onClick={async () => {
if (await copy(token.key)) {
let key = "sk-" + token.key;
if (await copy(key)) {
showSuccess('已复制到剪贴板!');
} else {
showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。');
setSearchKeyword(token.key);
setSearchKeyword(key);
}
}}
>

View File

@@ -4,6 +4,7 @@ import { Link } from 'react-router-dom';
import { API, showError, showSuccess } from '../helpers';
import { ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber, renderText } from '../helpers/render';
function renderRole(role) {
switch (role) {
@@ -64,7 +65,7 @@ const UsersTable = () => {
(async () => {
const res = await API.post('/api/user/manage', {
username,
action,
action
});
const { success, message } = res.data;
if (success) {
@@ -158,6 +159,14 @@ const UsersTable = () => {
<Table basic>
<Table.Header>
<Table.Row>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortUser('id');
}}
>
ID
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
@@ -169,10 +178,10 @@ const UsersTable = () => {
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortUser('display_name');
sortUser('group');
}}
>
显示名称
分组
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
@@ -188,7 +197,7 @@ const UsersTable = () => {
sortUser('quota');
}}
>
剩余额度
统计信息
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
@@ -220,10 +229,23 @@ const UsersTable = () => {
if (user.deleted) return <></>;
return (
<Table.Row key={user.id}>
<Table.Cell>{user.username}</Table.Cell>
<Table.Cell>{user.display_name}</Table.Cell>
<Table.Cell>{user.email ? user.email : '无'}</Table.Cell>
<Table.Cell>{user.quota}</Table.Cell>
<Table.Cell>{user.id}</Table.Cell>
<Table.Cell>
<Popup
content={user.email ? user.email : '未绑定邮箱地址'}
key={user.display_name}
header={user.display_name ? user.display_name : user.username}
trigger={<span>{renderText(user.username, 10)}</span>}
hoverable
/>
</Table.Cell>
<Table.Cell>{renderGroup(user.group)}</Table.Cell>
<Table.Cell>{user.email ? renderText(user.email, 30) : '无'}</Table.Cell>
<Table.Cell>
<Popup content='剩余额度' trigger={<Label>{renderNumber(user.quota)}</Label>} />
<Popup content='已用额度' trigger={<Label>{renderNumber(user.used_quota)}</Label>} />
<Popup content='请求次数' trigger={<Label>{renderNumber(user.request_count)}</Label>} />
</Table.Cell>
<Table.Cell>{renderRole(user.role)}</Table.Cell>
<Table.Cell>{renderStatus(user.status)}</Table.Cell>
<Table.Cell>
@@ -284,7 +306,6 @@ const UsersTable = () => {
size={'small'}
as={Link}
to={'/user/edit/' + user.id}
disabled={user.role === 100}
>
编辑
</Button>
@@ -297,7 +318,7 @@ const UsersTable = () => {
<Table.Footer>
<Table.Row>
<Table.HeaderCell colSpan='7'>
<Table.HeaderCell colSpan='8'>
<Button size='small' as={Link} to='/user/add' loading={loading}>
添加新的用户
</Button>

View File

@@ -1,10 +1,13 @@
export const CHANNEL_OPTIONS = [
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
{ key: 2, text: 'API2D', value: 2, color: 'blue' },
{ key: 8, text: '自定义', value: 8, color: 'pink' },
{ key: 3, text: 'Azure', value: 3, color: 'olive' },
{ key: 2, text: 'API2D', value: 2, color: 'blue' },
{ key: 4, text: 'CloseAI', value: 4, color: 'teal' },
{ key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
{ key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
{ key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
{ key: 8, text: '自定义', value: 8, color: 'pink' }
{ key: 9, text: 'AI.LS', value: 9, color: 'yellow' },
{ key: 10, text: 'AI Proxy', value: 10, color: 'purple' },
{ key: 12, text: 'API2GPT', value: 12, color: 'blue' }
];

38
web/src/helpers/render.js Normal file
View File

@@ -0,0 +1,38 @@
import { Label } from 'semantic-ui-react';
export function renderText(text, limit) {
if (text.length > limit) {
return text.slice(0, limit - 3) + '...';
}
return text;
}
export function renderGroup(group) {
if (group === '') {
return <Label>default</Label>;
}
let groups = group.split(',');
groups.sort();
return <>
{groups.map((group) => {
if (group === 'vip' || group === 'pro') {
return <Label color='yellow'>{group}</Label>;
} else if (group === 'svip' || group === 'premium') {
return <Label color='red'>{group}</Label>;
}
return <Label>{group}</Label>;
})}
</>;
}
export function renderNumber(num) {
if (num >= 1000000000) {
return (num / 1000000000).toFixed(1) + 'B';
} else if (num >= 1000000) {
return (num / 1000000).toFixed(1) + 'M';
} else if (num >= 10000) {
return (num / 1000).toFixed(1) + 'k';
} else {
return num;
}
}

View File

@@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
import { useParams } from 'react-router-dom';
import { API, showError, showSuccess } from '../../helpers';
import { API, showError, showInfo, showSuccess } from '../../helpers';
import { CHANNEL_OPTIONS } from '../../constants';
const EditChannel = () => {
@@ -14,12 +14,17 @@ const EditChannel = () => {
type: 1,
key: '',
base_url: '',
other: ''
other: '',
models: [],
groups: ['default']
};
const [batch, setBatch] = useState(false);
const [inputs, setInputs] = useState(originInputs);
const [modelOptions, setModelOptions] = useState([]);
const [groupOptions, setGroupOptions] = useState([]);
const [basicModels, setBasicModels] = useState([]);
const [fullModels, setFullModels] = useState([]);
const handleInputChange = (e, { name, value }) => {
console.log(name, value);
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
@@ -27,26 +32,74 @@ const EditChannel = () => {
let res = await API.get(`/api/channel/${channelId}`);
const { success, message, data } = res.data;
if (success) {
data.password = '';
if (data.models === "") {
data.models = []
} else {
data.models = data.models.split(",")
}
if (data.group === "") {
data.groups = []
} else {
data.groups = data.group.split(",")
}
setInputs(data);
} else {
showError(message);
}
setLoading(false);
};
const fetchModels = async () => {
try {
let res = await API.get(`/api/channel/models`);
setModelOptions(res.data.data.map((model) => ({
key: model.id,
text: model.id,
value: model.id,
})));
setFullModels(res.data.data.map((model) => model.id));
setBasicModels(res.data.data.filter((model) => !model.id.startsWith("gpt-4")).map((model) => model.id));
} catch (error) {
showError(error.message);
}
};
const fetchGroups = async () => {
try {
let res = await API.get(`/api/group/`);
setGroupOptions(res.data.data.map((group) => ({
key: group,
text: group,
value: group,
})));
} catch (error) {
showError(error.message);
}
};
useEffect(() => {
if (isEdit) {
loadChannel().then();
}
fetchModels().then();
fetchGroups().then();
}, []);
const submit = async () => {
if (!isEdit && (inputs.name === '' || inputs.key === '')) return;
if (!isEdit && (inputs.name === '' || inputs.key === '')) {
showInfo('请填写渠道名称和渠道密钥!');
return;
}
let localInputs = inputs;
if (localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
}
if (localInputs.type === 3 && localInputs.other === '') {
localInputs.other = '2023-03-15-preview';
}
let res;
localInputs.models = localInputs.models.join(",")
localInputs.group = localInputs.groups.join(",")
if (isEdit) {
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
} else {
@@ -83,7 +136,9 @@ const EditChannel = () => {
inputs.type === 3 && (
<>
<Message>
注意<strong>模型部署名称必须和模型名称保持一致</strong> One API model
注意<strong>模型部署名称必须和模型名称保持一致</strong> One API model
参数替换为你的部署名称模型名称中的点会被剔除<a target='_blank'
href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>
</Message>
<Form.Field>
<Form.Input
@@ -132,6 +187,58 @@ const EditChannel = () => {
autoComplete='new-password'
/>
</Form.Field>
<Form.Field>
<Form.Dropdown
label='分组'
placeholder={'请选择分组'}
name='groups'
fluid
multiple
selection
allowAdditions
additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
onChange={handleInputChange}
value={inputs.groups}
autoComplete='new-password'
options={groupOptions}
/>
</Form.Field>
<Form.Field>
<Form.Dropdown
label='模型'
placeholder={'请选择该通道所支持的模型'}
name='models'
fluid
multiple
selection
onChange={handleInputChange}
value={inputs.models}
autoComplete='new-password'
options={modelOptions}
/>
</Form.Field>
<div style={{ lineHeight: '40px', marginBottom: '12px'}}>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: basicModels });
}}>填入基础模型</Button>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: fullModels });
}}>填入所有模型</Button>
</div>
{
inputs.type === 1 && (
<Form.Field>
<Form.Input
label='代理'
name='base_url'
placeholder={'请输入 OpenAI API 代理地址如果不需要请留空格式为https://api.openai.com'}
onChange={handleInputChange}
value={inputs.base_url}
autoComplete='new-password'
/>
</Form.Field>
)
}
{
batch ? <Form.Field>
<Form.TextArea
@@ -151,7 +258,7 @@ const EditChannel = () => {
onChange={handleInputChange}
value={inputs.key}
autoComplete='new-password'
/>
/>
</Form.Field>
}
{
@@ -164,7 +271,7 @@ const EditChannel = () => {
/>
)
}
<Button onClick={submit}>提交</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>

View File

@@ -0,0 +1,14 @@
import React from 'react';
import { Header, Segment } from 'semantic-ui-react';
import LogsTable from '../../components/LogsTable';
const Token = () => (
<>
<Segment>
<Header as='h3'>额度明细</Header>
<LogsTable />
</Segment>
</>
);
export default Token;

View File

@@ -111,7 +111,7 @@ const EditRedemption = () => {
</Form.Field>
</>
}
<Button onClick={submit}>提交</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>

View File

@@ -106,6 +106,34 @@ const EditToken = () => {
required={!isEdit}
/>
</Form.Field>
<Form.Field>
<Form.Input
label='过期时间'
name='expired_time'
placeholder={'请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss-1 表示无限制'}
onChange={handleInputChange}
value={expired_time}
autoComplete='new-password'
type='datetime-local'
/>
</Form.Field>
<div style={{ lineHeight: '40px' }}>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 0, 0);
}}>永不过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(1, 0, 0, 0);
}}>一个月后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 1, 0, 0);
}}>一天后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 1, 0);
}}>一小时后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 0, 1);
}}>一分钟后过期</Button>
</div>
<Message>注意令牌的额度仅用于限制令牌本身的最大额度使用量实际的使用受到账户的剩余额度限制</Message>
<Form.Field>
<Form.Input
@@ -119,36 +147,10 @@ const EditToken = () => {
disabled={unlimited_quota}
/>
</Form.Field>
<Button type={'button'} style={{ marginBottom: '14px' }} onClick={() => {
<Button type={'button'} onClick={() => {
setUnlimitedQuota();
}}>{unlimited_quota ? '取消无限额度' : '设置为无限额度'}</Button>
<Form.Field>
<Form.Input
label='过期时间'
name='expired_time'
placeholder={'请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss-1 表示无限制'}
onChange={handleInputChange}
value={expired_time}
autoComplete='new-password'
type='datetime-local'
/>
</Form.Field>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 0, 0);
}}>永不过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(1, 0, 0, 0);
}}>一个月后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 1, 0, 0);
}}>一天后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 1, 0);
}}>一小时后过期</Button>
<Button type={'button'} onClick={() => {
setExpiredTime(0, 0, 0, 1);
}}>一分钟后过期</Button>
<Button onClick={submit}>提交</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>

View File

@@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Grid, Header, Segment, Statistic } from 'semantic-ui-react';
import { API, showError, showSuccess } from '../../helpers';
import { API, showError, showInfo, showSuccess } from '../../helpers';
const TopUp = () => {
const [redemptionCode, setRedemptionCode] = useState('');
@@ -9,6 +9,7 @@ const TopUp = () => {
const topUp = async () => {
if (redemptionCode === '') {
showInfo('请输入充值码!')
return;
}
const res = await API.post('/api/user/topup', {
@@ -80,7 +81,7 @@ const TopUp = () => {
<Grid.Column>
<Statistic.Group widths='one'>
<Statistic>
<Statistic.Value>{userQuota}</Statistic.Value>
<Statistic.Value>{userQuota.toLocaleString()}</Statistic.Value>
<Statistic.Label>剩余额度</Statistic.Label>
</Statistic>
</Statistic.Group>

View File

@@ -65,7 +65,7 @@ const AddUser = () => {
required
/>
</Form.Field>
<Button type={'submit'} onClick={submit}>
<Button positive type={'submit'} onClick={submit}>
提交
</Button>
</Form>

View File

@@ -14,12 +14,27 @@ const EditUser = () => {
github_id: '',
wechat_id: '',
email: '',
quota: 0,
group: 'default'
});
const { username, display_name, password, github_id, wechat_id, email } =
const [groupOptions, setGroupOptions] = useState([]);
const { username, display_name, password, github_id, wechat_id, email, quota, group } =
inputs;
const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
const fetchGroups = async () => {
try {
let res = await API.get(`/api/group/`);
setGroupOptions(res.data.data.map((group) => ({
key: group,
text: group,
value: group,
})));
} catch (error) {
showError(error.message);
}
};
const loadUser = async () => {
let res = undefined;
@@ -39,12 +54,19 @@ const EditUser = () => {
};
useEffect(() => {
loadUser().then();
if (userId) {
fetchGroups().then();
}
}, []);
const submit = async () => {
let res = undefined;
if (userId) {
res = await API.put(`/api/user/`, { ...inputs, id: parseInt(userId) });
let data = { ...inputs, id: parseInt(userId) };
if (typeof data.quota === 'string') {
data.quota = parseInt(data.quota);
}
res = await API.put(`/api/user/`, data);
} else {
res = await API.put(`/api/user/self`, inputs);
}
@@ -92,6 +114,37 @@ const EditUser = () => {
autoComplete='new-password'
/>
</Form.Field>
{
userId && <>
<Form.Field>
<Form.Dropdown
label='分组'
placeholder={'请选择分组'}
name='group'
fluid
search
selection
allowAdditions
additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
onChange={handleInputChange}
value={inputs.group}
autoComplete='new-password'
options={groupOptions}
/>
</Form.Field>
<Form.Field>
<Form.Input
label='剩余额度'
name='quota'
placeholder={'请输入新的剩余额度'}
onChange={handleInputChange}
value={quota}
type={'number'}
autoComplete='new-password'
/>
</Form.Field>
</>
}
<Form.Field>
<Form.Input
label='已绑定的 GitHub 账户'
@@ -122,7 +175,7 @@ const EditUser = () => {
readOnly
/>
</Form.Field>
<Button onClick={submit}>提交</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>