Compare commits

...

52 Commits

Author SHA1 Message Date
JustSong
9a2662af0d feat: show token info when quota is not enough (close #1274) 2024-04-05 12:42:14 +08:00
JustSong
77f9e75654 fix: fix IsValidSubnet 2024-04-05 12:40:03 +08:00
JustSong
5b41f57423 feat: support stepfun's models 2024-04-05 12:32:05 +08:00
JustSong
0bb7db0b44 fix: do not detect quota field in error message (close #1276) 2024-04-05 12:11:50 +08:00
JustSong
4d61b9937b feat: support feishu login now 2024-04-05 12:10:43 +08:00
JustSong
68605800af feat: add subnet validation (#1275) 2024-04-05 10:18:42 +08:00
JustSong
c49778c254 feat: now able to limit ip range for token now (close #1275) 2024-04-05 10:09:16 +08:00
JustSong
f02c7138ea docs: update README 2024-04-05 01:35:14 +08:00
JustSong
ca3228855a docs: update API docs 2024-04-05 01:29:22 +08:00
JustSong
f8cc63f00b feat: add user info to topup link 2024-04-05 01:23:11 +08:00
JustSong
0a37aa4cbd docs: add API docs 2024-04-05 01:10:30 +08:00
JustSong
054b00b725 docs: add API docs 2024-04-05 00:40:48 +08:00
JustSong
76569bb0b6 chore: disable channel when error message contain credit or balance 2024-04-05 00:31:41 +08:00
JustSong
1994256bac chore: disable channel when error message contain quota 2024-04-05 00:18:26 +08:00
JustSong
1f80b0a39f chore: add omitempty for xunfei functions 2024-04-05 00:13:37 +08:00
manjieqi
f73f2e51df feat: update baidu model name & ratio (#1253)
* 修正百度模型名称

* 更新百度模型名称,并保留旧版兼容以及修正单价

* chore: add more model and adjust order

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-05 00:02:15 +08:00
Yang Fei
6f036bd0c9 feat: add embedding-2 support for zhipu (#1273)
* 增加对智谱embedding-2模型的支持

* fix: fix usage & ratio

---------

Co-authored-by: yangfei <yangfei@xuyao.info>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-04 23:32:59 +08:00
JustSong
fb90747c23 fix: fix /v1/models return null data when no models available 2024-04-04 18:53:42 +08:00
JustSong
ed70881a58 fix: fix token create 2024-04-04 11:18:21 +08:00
JustSong
8b9fa3d6e4 fix: fix GetGroupModels 2024-04-04 02:58:21 +08:00
JustSong
8b9813d63b feat: /v1/models now only return available models 2024-04-04 02:44:59 +08:00
JustSong
dc7aaf2de5 feat: able to set model limitation for token (close #178) 2024-04-04 02:08:18 +08:00
JustSong
065da8ef8c fix: fix ali function call (#1242) 2024-04-04 00:46:30 +08:00
JustSong
e3cfb1fa52 feat: use given usage if available in stream mode 2024-03-31 23:41:52 +08:00
JustSong
f89ae5ad58 feat: initial function call support for xunfei 2024-03-31 23:12:29 +08:00
JustSong
06a3fc5421 chore: update GeneralOpenAIRequest 2024-03-31 22:23:42 +08:00
ManJieqi
a9c464ec5a fix: update model-ratio.go 修正文心计费模型名称
统一文心计费模型名称
2024-03-30 11:06:31 +08:00
JustSong
3f3c13c98c feat: support top_k for claude (close #1239) 2024-03-30 10:47:07 +08:00
JustSong
2ba28c72cb feat: support function call for ali (close #1242) 2024-03-30 10:43:26 +08:00
JustSong
5e81e19bc8 fix: fix SQL channel selection algo (#1197) 2024-03-27 19:09:27 +08:00
JustSong
96d7a99312 fix: fix autofilled models are not correct 2024-03-24 23:12:32 +08:00
JustSong
24be9de098 chore: update copy 2024-03-24 23:01:03 +08:00
JustSong
5b349efff9 chore: fix berry copy 2024-03-24 22:57:24 +08:00
JustSong
f76c46d648 feat: add gemini-1.5-pro (#1211) 2024-03-24 22:50:09 +08:00
JustSong
cdfdeea3b4 feat: return token when calling post /api/token (close #1208) 2024-03-24 22:24:41 +08:00
JustSong
56ddbb842a fix: return pre-consumed quota when error happened for audio (close #1217) 2024-03-24 22:20:41 +08:00
JustSong
99f81a267c fix: fix xunfei error handling (close #1218) 2024-03-24 22:14:45 +08:00
xietong
c243cd5535 feat: 支持 ollama 的 embedding 接口 (#1221)
* 增加ollama的embedding接口

* chore: fix function name

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-03-24 21:51:31 +08:00
GuangxiaoLong
e96b173abe feat: 移除 azure model 的 TrimSuffix (#1193) 2024-03-24 21:47:46 +08:00
Benny
4ae311e964 docs: update README (#1186) 2024-03-17 21:06:36 +08:00
JustSong
b14cb748d8 chore: update copy 2024-03-17 19:39:00 +08:00
Ian Li
ade19ba4a2 feat: update default API version for Azure OpenAI (#994)
* feat: Update default API version for Azure OpenAI.

* chore: update other theme

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-03-17 19:34:21 +08:00
Ian Li
4d86d021c4 feat: support Azure OpenAI TTS. (#1177) 2024-03-17 19:30:50 +08:00
shuirong
7a44adb5a7 fix: fix panel cards style (#1171) 2024-03-17 19:26:12 +08:00
Benny
9821bc7281 feat: add user list sorting and pagination enhancements (#1178)
* feat: add user list sorting and pagination enhancements

* feat: add user list sorting for THEME=air

* feat: add token list sorting and pagination enhancements

* feat: add token list sorting for THEME=air
2024-03-17 19:25:36 +08:00
JustSong
08831881f1 feat: increase initial root user quota and support INITIAL_ROOT_TOKEN now (#1105) 2024-03-17 19:09:44 +08:00
JustSong
0eb2272bb7 chore: update copy 2024-03-17 18:12:49 +08:00
JustSong
704ec1a827 chore: update theme berry 2024-03-17 17:48:57 +08:00
Ghostz
1d7470d6ad fix: fix lingyiwanwu model ratio (#1182) 2024-03-17 17:04:29 +08:00
JustSong
1185303346 chore: update comments 2024-03-17 14:10:35 +08:00
JustSong
c212fcf8d7 docs: update readme 2024-03-17 14:00:33 +08:00
JustSong
c285e000cc chore: remove default scroll bar 2024-03-16 16:16:44 +08:00
101 changed files with 1679 additions and 402 deletions

1
.gitignore vendored
View File

@@ -8,3 +8,4 @@ build
logs
data
/web/node_modules
cmd.md

View File

@@ -241,17 +241,19 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Example: `SESSION_SECRET=random_string`
3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0.
+ Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
4. `LOG_SQL_DSN`: When set, a separate database will be used for the `logs` table; please use MySQL or PostgreSQL.
+ Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs`
5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
+ Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
+ Example: `SYNC_FREQUENCY=60`
6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
+ Example: `NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
+ Example: `CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
+ Example: `POLLING_INTERVAL=5`
### Command Line Parameters

View File

@@ -242,17 +242,18 @@ graph LR
+ 例: `SESSION_SECRET=random_string`
3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。
+ 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる
4. `LOG_SQL_DSN`: 設定ると、`logs`テーブルには独立したデータベースが使用されます。MySQLまたはPostgreSQLを使用してください
5. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。
+ 例: `FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。
6. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。
+ 例: `SYNC_FREQUENCY=60`
6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master``slave` である。設定されていない場合、デフォルトは `master`
7. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master``slave` である。設定されていない場合、デフォルトは `master`
+ 例: `NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。
8. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。
+ 例: `CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。
9. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。
+ 例: `CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。
10. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。
+ 例: `POLLING_INTERVAL=5`
### コマンドラインパラメータ

View File

@@ -81,13 +81,14 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [Groq](https://wow.groq.com/)
+ [x] [Ollama](https://github.com/ollama/ollama)
+ [x] [零一万物](https://platform.lingyiwanwu.com/)
+ [x] [阶跃星辰](https://platform.stepfun.com/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
5. 支持**多机部署**[详见此处](#多机部署)。
6. 支持**令牌管理**,设置令牌的过期时间额度。
6. 支持**令牌管理**,设置令牌的过期时间额度、允许的 IP 范围以及允许的模型访问
7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
8. 支持**道管理**,批量创建道。
8. 支持**道管理**,批量创建道。
9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
10. 支持渠道**设置模型列表**。
11. 支持**查看额度明细**。
@@ -101,10 +102,11 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
20. 支持通过系统访问令牌访问管理 APIbearer token用以替代 cookie你可以自行抓包来查看 API 的用法)
20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。
21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ 支持使用飞书进行授权登录。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
@@ -349,38 +351,40 @@ graph LR
+ `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。
+ 如果报错 `Error 1040: Too many connections`,请适当减小该值。
+ `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置
4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL
5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`MEMORY_CACHE_ENABLED=true`
6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
+ 例子:`SYNC_FREQUENCY=60`
7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
+ 例子:`NODE_TYPE=slave`
8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5`
11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`BATCH_UPDATE_ENABLED=true`
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+ 例子:`BATCH_UPDATE_INTERVAL=5`
13. 请求频率限制:
14. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
14. 编码器缓存设置:
15. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
16. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
17. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
17. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
18. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
19. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
20. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
21. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
22. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
23. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
@@ -419,7 +423,7 @@ https://openai.justsong.cn
+ 检查你的接口地址和 API Key 有没有填对。
+ 检查是否启用了 HTTPS浏览器会拦截 HTTPS 域名下的 HTTP 请求。
6. 报错:`当前分组负载已饱和,请稍后再试`
+ 上游道 429 了。
+ 上游道 429 了。
7. 升级之后我的数据会丢失吗?
+ 如果使用 MySQL不会。
+ 如果使用 SQLite需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
@@ -427,8 +431,8 @@ https://openai.justsong.cn
+ 一般情况下不需要,系统将在初始化的时候自动调整。
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`
+ 这是检测到 ability 表里有些记录的道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的道。
+ 对于每一个道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该道支持该模型。
+ 这是检测到 ability 表里有些记录的道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的道。
+ 对于每一个道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该道支持该模型。
## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统

View File

@@ -66,6 +66,9 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var LarkClientId = ""
var LarkClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
@@ -136,3 +139,5 @@ var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")

View File

@@ -71,6 +71,7 @@ const (
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeLingYiWanWu
ChannelTypeStepFun
ChannelTypeDummy
)
@@ -108,6 +109,7 @@ var ChannelBaseURLs = []string{
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
"https://api.lingyiwanwu.com", // 31
"https://api.stepfun.com", // 32
}
const (

6
common/conv/any.go Normal file
View File

@@ -0,0 +1,6 @@
package conv
func AsString(v any) string {
str, _ := v.(string)
return str
}

View File

@@ -72,21 +72,34 @@ var ModelRatio = map[string]float64{
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
"ERNIE-Bot-8k": 0.024 * RMB,
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"bge-large-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-Bot-8K-0922": 0.024 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Lite-8K": 0.003 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
// https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
"glm-3-turbo": 0.005 * RMB,
"embedding-2": 0.0005 * RMB,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
@@ -133,9 +146,9 @@ var ModelRatio = map[string]float64{
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
"gemma-7b-it": 0.1 / 1000 * USD,
// https://platform.lingyiwanwu.com/docs#-计费单元
"yi-34b-chat-0205": 2.5 / 1000000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000000 * RMB,
"yi-vl-plus": 6.0 / 1000000 * RMB,
"yi-34b-chat-0205": 2.5 / 1000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000 * RMB,
"yi-vl-plus": 6.0 / 1000 * RMB,
}
var CompletionRatio = map[string]float64{}
@@ -248,6 +261,9 @@ func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "mistral-") {
return 3
}
if strings.HasPrefix(name, "gemini-") {
return 3
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.7

25
common/network/ip.go Normal file
View File

@@ -0,0 +1,25 @@
package network
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"net"
)
func IsValidSubnet(subnet string) error {
_, _, err := net.ParseCIDR(subnet)
if err != nil {
return fmt.Errorf("failed to parse subnet: %w", err)
}
return nil
}
func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool {
_, ipNet, err := net.ParseCIDR(subnet)
if err != nil {
logger.Errorf(ctx, "failed to parse subnet: %s", err.Error())
return false
}
return ipNet.Contains(net.ParseIP(ip))
}

19
common/network/ip_test.go Normal file
View File

@@ -0,0 +1,19 @@
package network
import (
"context"
"testing"
. "github.com/smartystreets/goconvey/convey"
)
func TestIsIpInSubnet(t *testing.T) {
ctx := context.Background()
ip1 := "192.168.0.5"
ip2 := "125.216.250.89"
subnet := "192.168.0.0/24"
Convey("TestIsIpInSubnet", t, func() {
So(IsIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
So(IsIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
})
}

View File

@@ -1,4 +1,4 @@
package controller
package auth
import (
"bytes"
@@ -11,6 +11,7 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -159,7 +160,7 @@ func GitHubOAuth(c *gin.Context) {
})
return
}
setupLogin(&user, c)
controller.SetupLogin(&user, c)
}
func GitHubBind(c *gin.Context) {

201
controller/auth/lark.go Normal file
View File

@@ -0,0 +1,201 @@
package auth
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
)
type LarkOAuthResponse struct {
AccessToken string `json:"access_token"`
}
type LarkUser struct {
Name string `json:"name"`
OpenID string `json:"open_id"`
}
func getLarkUserInfoByCode(code string) (*LarkUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{
"client_id": config.LarkClientId,
"client_secret": config.LarkClientSecret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress),
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse LarkOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
}
var larkUser LarkUser
err = json.NewDecoder(res2.Body).Decode(&larkUser)
if err != nil {
return nil, err
}
return &larkUser, nil
}
func LarkOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LarkBind(c)
return
}
code := c.Query("code")
larkUser, err := getLarkUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LarkId: larkUser.OpenID,
}
if model.IsLarkIdAlreadyTaken(user.LarkId) {
err := user.FillUserByLarkId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if config.RegisterEnabled {
user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1)
if larkUser.Name != "" {
user.DisplayName = larkUser.Name
} else {
user.DisplayName = "Lark User"
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
controller.SetupLogin(&user, c)
}
func LarkBind(c *gin.Context) {
code := c.Query("code")
larkUser, err := getLarkUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LarkId: larkUser.OpenID,
}
if model.IsLarkIdAlreadyTaken(user.LarkId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该飞书账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LarkId = larkUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -1,4 +1,4 @@
package controller
package auth
import (
"encoding/json"
@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -109,7 +110,7 @@ func WeChatAuth(c *gin.Context) {
})
return
}
setupLogin(&user, c)
controller.SetupLogin(&user, c)
}
func WeChatBind(c *gin.Context) {

View File

@@ -197,7 +197,7 @@ func testChannels(notify bool, scope string) error {
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
err := message.Notify(message.ByAll, "道测试完成", "", "道测试完成,如果没有收到禁用通知,说明所有道都正常")
err := message.Notify(message.ByAll, "道测试完成", "", "道测试完成,如果没有收到禁用通知,说明所有道都正常")
if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}

View File

@@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) {
"email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId,
"lark_client_id": config.LarkClientId,
"system_name": config.SystemName,
"logo": config.Logo,
"footer_html": config.Footer,

View File

@@ -4,12 +4,14 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http"
"strings"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -120,9 +122,41 @@ func DashboardListModels(c *gin.Context) {
}
func ListModels(c *gin.Context) {
ctx := c.Request.Context()
var availableModels []string
if c.GetString("available_models") != "" {
availableModels = strings.Split(c.GetString("available_models"), ",")
} else {
userId := c.GetInt("id")
userGroup, _ := model.CacheGetUserGroup(userId)
availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
}
modelSet := make(map[string]bool)
for _, availableModel := range availableModels {
modelSet[availableModel] = true
}
availableOpenAIModels := make([]OpenAIModels, 0)
for _, model := range openAIModels {
if _, ok := modelSet[model.Id]; ok {
modelSet[model.Id] = false
availableOpenAIModels = append(availableOpenAIModels, model)
}
}
for modelName, ok := range modelSet {
if ok {
availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Root: modelName,
Parent: nil,
})
}
}
c.JSON(200, gin.H{
"object": "list",
"data": openAIModels,
"data": availableOpenAIModels,
})
}
@@ -142,3 +176,30 @@ func RetrieveModel(c *gin.Context) {
})
}
}
func GetUserAvailableModels(c *gin.Context) {
ctx := c.Request.Context()
id := c.GetInt("id")
userGroup, err := model.CacheGetUserGroup(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
models, err := model.CacheGetGroupModels(ctx, userGroup)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": models,
})
return
}

View File

@@ -1,10 +1,12 @@
package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -16,7 +18,10 @@ func GetAllTokens(c *gin.Context) {
if p < 0 {
p = 0
}
tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage)
order := c.Query("order")
tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -101,6 +106,19 @@ func GetTokenStatus(c *gin.Context) {
})
}
func validateToken(c *gin.Context, token model.Token) error {
if len(token.Name) > 30 {
return fmt.Errorf("令牌名称过长")
}
if token.Subnet != nil && *token.Subnet != "" {
err := network.IsValidSubnet(*token.Subnet)
if err != nil {
return fmt.Errorf("无效的网段:%s", err.Error())
}
}
return nil
}
func AddToken(c *gin.Context) {
token := model.Token{}
err := c.ShouldBindJSON(&token)
@@ -111,13 +129,15 @@ func AddToken(c *gin.Context) {
})
return
}
if len(token.Name) > 30 {
err = validateToken(c, token)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌名称过长",
"message": fmt.Sprintf("参数错误:%s", err.Error()),
})
return
}
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
@@ -127,6 +147,8 @@ func AddToken(c *gin.Context) {
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
Models: token.Models,
Subnet: token.Subnet,
}
err = cleanToken.Insert()
if err != nil {
@@ -139,6 +161,7 @@ func AddToken(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": cleanToken,
})
return
}
@@ -173,10 +196,11 @@ func UpdateToken(c *gin.Context) {
})
return
}
if len(token.Name) > 30 {
err = validateToken(c, token)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌名称过长",
"message": fmt.Sprintf("参数错误:%s", err.Error()),
})
return
}
@@ -212,6 +236,8 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.Models = token.Models
cleanToken.Subnet = token.Subnet
}
err = cleanToken.Update()
if err != nil {

View File

@@ -58,11 +58,11 @@ func Login(c *gin.Context) {
})
return
}
setupLogin(&user, c)
SetupLogin(&user, c)
}
// setup session & cookies and then return user info
func setupLogin(user *model.User, c *gin.Context) {
func SetupLogin(user *model.User, c *gin.Context) {
session := sessions.Default(c)
session.Set("id", user.Id)
session.Set("username", user.Username)
@@ -184,7 +184,10 @@ func GetAllUsers(c *gin.Context) {
if p < 0 {
p = 0
}
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage)
order := c.DefaultQuery("order", "")
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -192,12 +195,12 @@ func GetAllUsers(c *gin.Context) {
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": users,
})
return
}
func SearchUsers(c *gin.Context) {
@@ -767,3 +770,38 @@ func TopUp(c *gin.Context) {
})
return
}
type adminTopUpRequest struct {
UserId int `json:"user_id"`
Quota int `json:"quota"`
Remark string `json:"remark"`
}
func AdminTopUp(c *gin.Context) {
req := adminTopUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
err = model.IncreaseUserQuota(req.UserId, int64(req.Quota))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if req.Remark == "" {
req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
}
model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

53
docs/API.md Normal file
View File

@@ -0,0 +1,53 @@
# 使用 API 操控 & 扩展 One API
> 欢迎提交 PR 在此放上你的拓展项目。
例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。
又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。
## 鉴权
One API 支持两种鉴权方式Cookie 和 Token对于 Token参照下图获取
![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/c15281a7-83ed-47cb-a1f6-913cb6bf4a7c)
之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API
![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1273b7ae-cb60-4c0d-93a6-b1cbc039c4f8)
## 请求格式与响应格式
One API 使用 JSON 格式进行请求和响应。
对于响应体,一般格式如下:
```json
{
"message": "请求信息",
"success": true,
"data": {}
}
```
## API 列表
> 当前 API 列表不全,请自行通过浏览器抓取前端请求
如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。
### 获取当前登录用户信息
**GET** `/api/user/self`
### 为给定用户充值额度
**POST** `/api/topup`
```json
{
"user_id": 1,
"quota": 100000,
"remark": "充值 100000 额度"
}
```
## 其他
### 充值链接上的附加参数
One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如:
`https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837`
你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。
注意,不是所有主题都支持该功能,欢迎 PR 补齐。

4
go.mod
View File

@@ -15,6 +15,7 @@ require (
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.17.0
golang.org/x/image v0.14.0
@@ -37,6 +38,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
@@ -47,6 +49,7 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/jtolds/gls v4.20.0+incompatible // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
@@ -55,6 +58,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/smarty/assertions v1.15.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect

12
go.sum
View File

@@ -56,11 +56,13 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
@@ -85,6 +87,8 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
@@ -127,6 +131,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -177,8 +185,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=

View File

@@ -8,12 +8,12 @@
"确认删除": "Confirm Delete",
"确认绑定": "Confirm Binding",
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.",
"\"道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
"道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s",
"\"道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
"道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s",
"测试已在运行中": "Test is already running",
"响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs",
"道测试完成": "Channel test completed",
"道测试完成,如果没有收到禁用通知,说明所有道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal",
"道测试完成": "Channel test completed",
"道测试完成,如果没有收到禁用通知,说明所有道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal",
"无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!",
"返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!",
"管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub",
@@ -119,11 +119,11 @@
" 个月 ": " M ",
" 年 ": " y ",
"未测试": "Not tested",
"道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
"已成功开始测试所有道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
"已成功开始测试所有已启用道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
"道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
"已更新完毕所有已启用道余额!": "The balance of all enabled channels has been updated!",
"道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
"已成功开始测试所有道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
"已成功开始测试所有已启用道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
"道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
"已更新完毕所有已启用道余额!": "The balance of all enabled channels has been updated!",
"搜索渠道的 ID名称和密钥 ...": "Search for channel ID, name and key ...",
"名称": "Name",
"分组": "Group",
@@ -141,9 +141,9 @@
"启用": "Enable",
"编辑": "Edit",
"添加新的渠道": "Add a new channel",
"测试所有道": "Test all channels",
"测试所有已启用道": "Test all enabled channels",
"更新所有已启用道余额": "Update the balance of all enabled channels",
"测试所有道": "Test all channels",
"测试所有已启用道": "Test all enabled channels",
"更新所有已启用道余额": "Update the balance of all enabled channels",
"刷新": "Refresh",
"处理中...": "Processing...",
"绑定成功!": "Binding succeeded!",
@@ -207,11 +207,11 @@
"监控设置": "Monitoring Settings",
"最长响应时间": "Longest Response Time",
"单位秒": "Unit in seconds",
"当运行道全部测试时": "When all operating channels are tested",
"超过此时间将自动禁用道": "Channels will be automatically disabled if this time is exceeded",
"当运行道全部测试时": "When all operating channels are tested",
"超过此时间将自动禁用道": "Channels will be automatically disabled if this time is exceeded",
"额度提醒阈值": "Quota reminder threshold",
"低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this",
"失败时自动禁用道": "Automatically disable the channel when it fails",
"失败时自动禁用道": "Automatically disable the channel when it fails",
"保存监控设置": "Save Monitoring Settings",
"额度设置": "Quota Settings",
"新用户初始额度": "Initial quota for new users",
@@ -405,7 +405,7 @@
"镜像": "Mirror",
"请输入镜像站地址格式为https://domain.com可不填不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used",
"模型": "Model",
"请选择该道所支持的模型": "Please select the model supported by the channel",
"请选择该道所支持的模型": "Please select the model supported by the channel",
"填入基础模型": "Fill in the basic model",
"填入所有模型": "Fill in all models",
"清除所有模型": "Clear all models",
@@ -515,7 +515,7 @@
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
"Homepage URL 填": "Fill in the Homepage URL",
"Authorization callback URL 填": "Fill in the Authorization callback URL",
"请为道命名": "Please name the channel",
"请为道命名": "Please name the channel",
"此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
"模型重定向": "Model redirection",
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",

View File

@@ -1,10 +1,12 @@
package middleware
import (
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/model"
"net/http"
"strings"
@@ -88,6 +90,7 @@ func RootAuth() func(c *gin.Context) {
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
ctx := c.Request.Context()
key := c.Request.Header.Get("Authorization")
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(key, "sk-")
@@ -98,6 +101,12 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
if token.Subnet != nil && *token.Subnet != "" {
if !network.IsIpInSubnet(ctx, c.ClientIP(), *token.Subnet) {
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s当前 ip%s", *token.Subnet, c.ClientIP()))
return
}
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
@@ -107,6 +116,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
requestModel, err := getRequestModel(c)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
c.Set("request_model", requestModel)
if token.Models != nil && *token.Models != "" {
c.Set("available_models", *token.Models)
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
return
}
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)

View File

@@ -2,14 +2,12 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type ModelRequest struct {
@@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) {
return
}
} else {
// Select a channel for the user
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
requestModel := c.GetString("request_model")
var err error
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}
requestModel = modelRequest.Model
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
if channel != nil {
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"

View File

@@ -1,9 +1,12 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"strings"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
@@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.Abort()
logger.Error(c.Request.Context(), message)
}
func getRequestModel(c *gin.Context) (string, error) {
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}
return modelRequest.Model, nil
}
func isModelInList(modelName string, models string) bool {
modelList := strings.Split(models, ",")
for _, model := range modelList {
if modelName == model {
return true
}
}
return false
}

View File

@@ -1,7 +1,10 @@
package model
import (
"context"
"github.com/songquanpeng/one-api/common"
"gorm.io/gorm"
"sort"
"strings"
)
@@ -13,7 +16,7 @@ type Ability struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
ability := Ability{}
groupCol := "`group`"
trueVal := "1"
@@ -23,8 +26,13 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
}
var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
var channelQuery *gorm.DB
if ignoreFirstPriority {
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
} else {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
}
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {
@@ -82,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error {
func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
func GetGroupModels(ctx context.Context, group string) ([]string, error) {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var models []string
err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error
if err != nil {
return nil, err
}
sort.Strings(models)
return models, err
}

View File

@@ -21,6 +21,7 @@ var (
UserId2GroupCacheSeconds = config.SyncFrequency
UserId2QuotaCacheSeconds = config.SyncFrequency
UserId2StatusCacheSeconds = config.SyncFrequency
GroupModelsCacheSeconds = config.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
@@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
if !common.RedisEnabled {
return GetGroupModels(ctx, group)
}
modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
if err == nil {
return strings.Split(modelsStr, ","), nil
}
models, err := GetGroupModels(ctx, group)
if err != nil {
return nil, err
}
err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
if err != nil {
logger.SysError("Redis set group models error: " + err.Error())
}
return models, nil
}
var group2model2channels map[string]map[string][]*Channel
var channelSyncLock sync.RWMutex
@@ -205,7 +225,7 @@ func SyncChannelCache(frequency int) {
func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
if !config.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()

View File

@@ -51,6 +51,21 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordTopupLog(userId int, content string, quota int) {
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
CreatedAt: helper.GetTimestamp(),
Type: LogTypeTopup,
Content: content,
Quota: quota,
}
err := LOG_DB.Create(log).Error
if err != nil {
logger.SysError("failed to record log: " + err.Error())
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled {

View File

@@ -23,7 +23,7 @@ func CreateRootAccountIfNeed() error {
var user User
//if user.Status != util.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil {
logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456")
hashedPassword, err := common.Password2Hash("123456")
if err != nil {
return err
@@ -35,9 +35,25 @@ func CreateRootAccountIfNeed() error {
Status: common.UserStatusEnabled,
DisplayName: "Root User",
AccessToken: helper.GetUUID(),
Quota: 100000000,
Quota: 500000000000000,
}
DB.Create(&rootUser)
if config.InitialRootToken != "" {
logger.SysLog("creating initial root token as requested")
token := Token{
Id: 1,
UserId: rootUser.Id,
Key: config.InitialRootToken,
Status: common.TokenStatusEnabled,
Name: "Initial Root Token",
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: -1,
RemainQuota: 500000000000000,
UnlimitedQuota: true,
}
DB.Create(&token)
}
}
return nil
}

View File

@@ -172,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) {
config.GitHubClientId = value
case "GitHubClientSecret":
config.GitHubClientSecret = value
case "LarkClientId":
config.LarkClientId = value
case "LarkClientSecret":
config.LarkClientSecret = value
case "Footer":
config.Footer = value
case "SystemName":

View File

@@ -14,7 +14,7 @@ type Redemption struct {
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Quota int64 `json:"quota" gorm:"default:100"`
Quota int64 `json:"quota" gorm:"bigint;default:100"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
Count int `json:"count" gorm:"-:all"` // only for api request

View File

@@ -12,23 +12,36 @@ import (
)
type Token struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int64 `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"default:0"` // used quota
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
Models *string `json:"models" gorm:"default:''"` // allowed models
Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
var tokens []*Token
var err error
err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error
query := DB.Where("user_id = ?", userId)
switch order {
case "remain_quota":
query = query.Order("unlimited_quota desc, remain_quota desc")
case "used_quota":
query = query.Order("used_quota desc")
default:
query = query.Order("id desc")
}
err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
return tokens, err
}
@@ -50,7 +63,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
return nil, errors.New("令牌验证失败")
}
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("令牌额度已用尽")
return nil, fmt.Errorf("令牌 %s#%d额度已用尽", token.Name, token.Id)
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
@@ -110,7 +123,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
return err
}

View File

@@ -24,11 +24,12 @@ type User struct {
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int64 `json:"quota" gorm:"type:int;default:0"`
UsedQuota int64 `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Quota int64 `json:"quota" gorm:"bigint;default:0"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
@@ -40,8 +41,21 @@ func GetMaxUserId() int {
return user.Id
}
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted).Find(&users).Error
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
switch order {
case "quota":
query = query.Order("quota desc")
case "used_quota":
query = query.Order("used_quota desc")
case "request_count":
query = query.Order("request_count desc")
default:
query = query.Order("id desc")
}
err = query.Find(&users).Error
return users, err
}
@@ -193,6 +207,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
func (user *User) FillUserByLarkId() error {
if user.LarkId == "" {
return errors.New("lark id 为空!")
}
DB.Where(User{LarkId: user.LarkId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -221,6 +243,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsLarkIdAlreadyTaken(githubId string) bool {
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}

View File

@@ -31,17 +31,17 @@ func notifyRootUser(subject string, content string) {
func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
subject := fmt.Sprintf("道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
subject := fmt.Sprintf("道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
func MetricDisableChannel(channelId int, successRate float64) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
subject := fmt.Sprintf("道 #%d 已被禁用", channelId)
content := fmt.Sprintf("该渠道在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100)
subject := fmt.Sprintf("道 #%d 已被禁用", channelId)
content := fmt.Sprintf("该渠道#%d在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
channelId, config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100)
notifyRootUser(subject, content)
}
@@ -49,7 +49,7 @@ func MetricDisableChannel(channelId int, successRate float64) {
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
subject := fmt.Sprintf("道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("道「%s」#%d已被启用", channelName, channelId)
subject := fmt.Sprintf("道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}

View File

@@ -48,6 +48,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
MaxTokens: request.MaxTokens,
Temperature: request.Temperature,
TopP: request.TopP,
TopK: request.TopK,
ResultFormat: "message",
Tools: request.Tools,
},
}
}
@@ -117,19 +120,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
}
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := openai.TextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
Choices: response.Output.Choices,
Usage: model.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
@@ -140,10 +135,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
}
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
if len(aliResponse.Output.Choices) == 0 {
return nil
}
aliChoice := aliResponse.Output.Choices[0]
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.Delta = aliChoice.Message
if aliChoice.FinishReason != "null" {
finishReason := aliChoice.FinishReason
choice.FinishReason = &finishReason
}
response := openai.ChatCompletionsStreamResponse{
@@ -204,6 +203,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
if response == nil {
return true
}
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
@@ -226,6 +228,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
ctx := c.Request.Context()
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -235,6 +238,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
logger.Debugf(ctx, "response body: %s\n", responseBody)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -1,5 +1,10 @@
package ali
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
@@ -11,13 +16,15 @@ type Input struct {
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
ResultFormat string `json:"result_format,omitempty"`
Tools []model.Tool `json:"tools,omitempty"`
}
type ChatRequest struct {
@@ -62,8 +69,9 @@ type Usage struct {
}
type Output struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
//Text string `json:"text"`
//FinishReason string `json:"finish_reason"`
Choices []openai.TextResponseChoice `json:"choices"`
}
type ChatResponse struct {

View File

@@ -38,6 +38,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
MaxTokens: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokens == 0 {

View File

@@ -38,16 +38,26 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-4.0-8K":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-Speed-8K":
suffix += "ernie_speed"
case "ERNIE-Speed-128K":
suffix += "ernie-speed-128k"
case "ERNIE-Lite-8K":
suffix += "ernie-lite-8k"
case "ERNIE-Tiny-8K":
suffix += "ernie-tiny-8k"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
@@ -59,7 +69,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
case "tao-8k":
suffix += "tao_8k"
default:
suffix += meta.ActualModelName
suffix += strings.ToLower(meta.ActualModelName)
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
var accessToken string

View File

@@ -1,11 +1,18 @@
package baidu
var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"ERNIE-4.0-8K",
"ERNIE-Bot-8K-0922",
"ERNIE-3.5-8K",
"ERNIE-Lite-8K-0922",
"ERNIE-Speed-8K",
"ERNIE-3.5-4K-0205",
"ERNIE-3.5-8K-0205",
"ERNIE-3.5-8K-1222",
"ERNIE-Lite-8K",
"ERNIE-Speed-128K",
"ERNIE-Tiny-8K",
"BLOOMZ-7B",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",

View File

@@ -1,6 +1,8 @@
package gemini
// https://ai.google.dev/models/gemini
var ModelList = []string{
"gemini-pro", "gemini-1.0-pro-001",
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
"gemini-pro-vision", "gemini-1.0-pro-vision-001",
}

View File

@@ -3,13 +3,14 @@ package ollama
import (
"errors"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
@@ -22,6 +23,9 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
if meta.Mode == constant.RelayModeEmbeddings {
fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
}
return fullRequestURL, nil
}
@@ -37,7 +41,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
}
switch relayMode {
case constant.RelayModeEmbeddings:
return nil, errors.New("not supported")
ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
return ollamaEmbeddingRequest, nil
default:
return ConvertRequest(*request), nil
}
@@ -51,7 +56,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp)
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}

View File

@@ -5,6 +5,10 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
@@ -12,9 +16,6 @@ import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
@@ -139,6 +140,64 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return nil, &usage
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: request.Model,
Prompt: strings.Join(request.ParseInput(), " "),
}
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var ollamaResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&ollamaResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if ollamaResponse.Error != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: ollamaResponse.Error,
Type: "ollama_error",
Param: "",
Code: "ollama_error",
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, 1),
Model: "text-embedding-v1",
Usage: model.Usage{TotalTokens: 0},
}
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: 0,
Embedding: response.Embedding,
})
return &openAIEmbeddingResponse
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
ctx := context.TODO()
var ollamaResponse ChatResponse

View File

@@ -35,3 +35,13 @@ type ChatResponse struct {
EvalDuration int `json:"eval_duration,omitempty"`
Error string `json:"error,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
}
type EmbeddingResponse struct {
Error string `json:"error,omitempty"`
Embedding []float64 `json:"embedding,omitempty"`
}

View File

@@ -31,11 +31,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := meta.ActualModelName
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
//https://github.com/songquanpeng/one-api/issues/1191
// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
case common.ChannelTypeMinimax:
@@ -73,8 +70,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText, _ = StreamHandler(c, resp, meta.Mode)
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
err, responseText, usage = StreamHandler(c, resp, meta.Mode)
if usage == nil {
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
}
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/relay/channel/stepfun"
)
var CompatibleChannels = []int{
@@ -20,6 +21,7 @@ var CompatibleChannels = []int{
common.ChannelTypeMistral,
common.ChannelTypeGroq,
common.ChannelTypeLingYiWanWu,
common.ChannelTypeStepFun,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
@@ -40,6 +42,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "groq", groq.ModelList
case common.ChannelTypeLingYiWanWu:
return "lingyiwanwu", lingyiwanwu.ModelList
case common.ChannelTypeStepFun:
return "stepfun", stepfun.ModelList
default:
return "openai", ModelList
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
@@ -53,7 +54,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
continue // just ignore the error
}
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.Content
responseText += conv.AsString(choice.Delta.Content)
}
if streamResponse.Usage != nil {
usage = streamResponse.Usage

View File

@@ -118,12 +118,9 @@ type ImageResponse struct {
}
type ChatCompletionsStreamResponseChoice struct {
Index int `json:"index"`
Delta struct {
Content string `json:"content"`
Role string `json:"role,omitempty"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
Index int `json:"index"`
Delta model.Message `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {

View File

@@ -0,0 +1,7 @@
package stepfun
var ModelList = []string{
"step-1-32k",
"step-1v-32k",
"step-1-200k",
}

View File

@@ -10,6 +10,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
@@ -129,7 +130,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
responseText += conv.AsString(response.Choices[0].Delta.Content)
}
jsonResponse, err := json.Marshal(response)
if err != nil {

View File

@@ -26,7 +26,11 @@ import (
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
var lastToolCalls []model.Tool
for _, message := range request.Messages {
if message.ToolCalls != nil {
lastToolCalls = message.ToolCalls
}
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
@@ -39,9 +43,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages
if len(lastToolCalls) != 0 {
for _, toolCall := range lastToolCalls {
xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function)
}
}
return &xunfeiRequest
}
func getToolCalls(response *ChatResponse) []model.Tool {
var toolCalls []model.Tool
if len(response.Payload.Choices.Text) == 0 {
return toolCalls
}
item := response.Payload.Choices.Text[0]
if item.FunctionCall == nil {
return toolCalls
}
toolCall := model.Tool{
Id: fmt.Sprintf("call_%s", helper.GetUUID()),
Type: "function",
Function: *item.FunctionCall,
}
toolCalls = append(toolCalls, toolCall)
return toolCalls
}
func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []ChatResponseTextItem{
@@ -53,8 +81,9 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
ToolCalls: getToolCalls(response),
},
FinishReason: constant.StopFinishReason,
}
@@ -78,6 +107,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
choice.Delta.ToolCalls = getToolCalls(xunfeiResponse)
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &constant.StopFinishReason
}
@@ -121,7 +151,7 @@ func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil
}
common.SetEventStreamHeaders(c)
var usage model.Usage
@@ -151,7 +181,7 @@ func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId strin
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil
}
var usage model.Usage
var content string
@@ -171,11 +201,7 @@ func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId strin
}
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
{
Content: "",
},
}
return openai.ErrorWrapper(err, "xunfei_empty_response_detected", http.StatusInternalServerError), nil
}
xunfeiResponse.Payload.Choices.Text[0].Content = content
@@ -202,15 +228,21 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl,
if err != nil {
return nil, nil, err
}
_, msg, err := conn.ReadMessage()
if err != nil {
return nil, nil, err
}
dataChan := make(chan ChatResponse)
stopChan := make(chan bool)
go func() {
for {
_, msg, err := conn.ReadMessage()
if err != nil {
logger.SysError("error reading stream response: " + err.Error())
break
if msg == nil {
_, msg, err = conn.ReadMessage()
if err != nil {
logger.SysError("error reading stream response: " + err.Error())
break
}
}
var response ChatResponse
err = json.Unmarshal(msg, &response)
@@ -218,6 +250,7 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl,
logger.SysError("error unmarshalling stream response: " + err.Error())
break
}
msg = nil
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()

View File

@@ -26,13 +26,18 @@ type ChatRequest struct {
Message struct {
Text []Message `json:"text"`
} `json:"message"`
Functions struct {
Text []model.Function `json:"text,omitempty"`
} `json:"functions,omitempty"`
} `json:"payload"`
}
type ChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
ContentType string `json:"content_type"`
FunctionCall *model.Function `json:"function_call"`
}
type ChatResponse struct {

View File

@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
@@ -35,6 +36,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
if a.APIVersion == "v4" {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
}
if meta.Mode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
}
method := "invoke"
if meta.IsStream {
method = "sse-invoke"
@@ -53,18 +57,24 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
// TopP (0.0, 1.0)
request.TopP = math.Min(0.99, request.TopP)
request.TopP = math.Max(0.01, request.TopP)
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
// TopP (0.0, 1.0)
request.TopP = math.Min(0.99, request.TopP)
request.TopP = math.Max(0.01, request.TopP)
// Temperature (0.0, 1.0)
request.Temperature = math.Min(0.99, request.Temperature)
request.Temperature = math.Max(0.01, request.Temperature)
a.SetVersionByModeName(request.Model)
if a.APIVersion == "v4" {
return request, nil
// Temperature (0.0, 1.0)
request.Temperature = math.Min(0.99, request.Temperature)
request.Temperature = math.Max(0.01, request.Temperature)
a.SetVersionByModeName(request.Model)
if a.APIVersion == "v4" {
return request, nil
}
return ConvertRequest(*request), nil
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
@@ -84,14 +94,26 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
if a.APIVersion == "v4" {
return a.DoResponseV4(c, resp, meta)
}
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp)
if meta.Mode == constant.RelayModeEmbeddings {
err, usage = EmbeddingsHandler(c, resp)
} else {
err, usage = Handler(c, resp)
}
}
return
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "embedding-2",
Input: request.Input.(string),
}
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}

View File

@@ -2,5 +2,5 @@ package zhipu
var ModelList = []string{
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
"glm-4", "glm-4v", "glm-3-turbo",
"glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
}

View File

@@ -254,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var zhipuResponse EmbeddingRespone
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
Model: response.Model,
Usage: model.Usage{
PromptTokens: response.PromptTokens,
CompletionTokens: response.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
},
}
for _, item := range response.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}

View File

@@ -44,3 +44,21 @@ type tokenData struct {
Token string
ExpiryTime time.Time
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input string `json:"input"`
}
type EmbeddingRespone struct {
Model string `json:"model"`
Object string `json:"object"`
Embeddings []EmbeddingData `json:"data"`
model.Usage `json:"usage"`
}
type EmbeddingData struct {
Index int `json:"index"`
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
}

View File

@@ -83,6 +83,24 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
succeed := false
defer func() {
if succeed {
return
}
if preConsumedQuota > 0 {
// we need to roll back the pre-consumed quota
defer func(ctx context.Context) {
go func() {
// negative means add quota back for token & user
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
}
}()
}(c.Request.Context())
}
}()
// map model name
modelMapping := c.GetString("model_mapping")
@@ -104,10 +122,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
if channelType == common.ChannelTypeAzure {
apiVersion := util.GetAzureAPIVersion(c)
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
if relayMode == constant.RelayModeAudioTranscription {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
} else if relayMode == constant.RelayModeAudioSpeech {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion)
}
}
requestBody := &bytes.Buffer{}
@@ -123,7 +146,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
if (relayMode == constant.RelayModeAudioTranscription || relayMode == constant.RelayModeAudioSpeech) && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@@ -188,20 +211,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
if resp.StatusCode != http.StatusOK {
if preConsumedQuota > 0 {
// we need to roll back the pre-consumed quota
defer func(ctx context.Context) {
go func() {
// negative means add quota back for token & user
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
}
}()
}(c.Request.Context())
}
return util.RelayErrorHandler(resp)
}
succeed = true
quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) {
go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)

View File

@@ -61,7 +61,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
if meta.ChannelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := util.GetAzureAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion)
}

View File

@@ -5,25 +5,29 @@ type ResponseFormat struct {
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
Model string `json:"model,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
Functions any `json:"functions,omitempty"`
User string `json:"user,omitempty"`
Prompt any `json:"prompt,omitempty"`
Input any `json:"input,omitempty"`
EncodingFormat string `json:"encoding_format,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {

View File

@@ -1,9 +1,10 @@
package model
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"`
Name *string `json:"name,omitempty"`
ToolCalls []Tool `json:"tool_calls,omitempty"`
}
func (m Message) IsStringContent() bool {

14
relay/model/tool.go Normal file
View File

@@ -0,0 +1,14 @@
package model
type Tool struct {
Id string `json:"id,omitempty"`
Type string `json:"type"`
Function Function `json:"function"`
}
type Function struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters any `json:"parameters,omitempty"` // request
Arguments any `json:"arguments,omitempty"` // response
}

View File

@@ -46,6 +46,15 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
//if strings.Contains(err.Message, "quota") {
// return true
//}
if strings.Contains(err.Message, "credit") {
return true
}
if strings.Contains(err.Message, "balance") {
return true
}
return false
}

View File

@@ -2,6 +2,7 @@ package router
import (
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/controller/auth"
"github.com/songquanpeng/one-api/middleware"
"github.com/gin-contrib/gzip"
@@ -21,11 +22,13 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp)
userRoute := apiRouter.Group("/user")
{
@@ -43,6 +46,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", controller.TopUp)
selfRoute.GET("/available_models", controller.GetUserAvailableModels)
}
adminRoute := userRoute.Group("/")

View File

@@ -2,6 +2,9 @@
> 每个文件夹代表一个主题,欢迎提交你的主题
> [!WARNING]
> 不是每一个主题都及时同步了所有功能,由于精力有限,优先更新默认主题,其他主题欢迎 & 期待 PR
## 提交新的主题
> 欢迎在页面底部保留你和 One API 的版权信息以及指向链接
@@ -9,7 +12,7 @@
1.`web` 文件夹下新建一个文件夹,文件夹名为主题名。
2. 把你的主题文件放到这个文件夹下。
3. 修改你的 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。
4. 修改 `common/constants.go` 中的 `ValidThemes`,把你的主题名称注册进去。
4. 修改 `common/config/config.go` 中的 `ValidThemes`,把你的主题名称注册进去。
5. 修改 `web/THEMES` 文件,这里也需要同步修改。
## 主题列表

View File

@@ -437,7 +437,7 @@ const ChannelsTable = () => {
if (success) {
record.response_time = time * 1000;
record.test_time = Date.now() / 1000;
showInfo(`${record.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
showInfo(`${record.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else {
showError(message);
}
@@ -447,7 +447,7 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/test?scope=${scope}`);
const { success, message } = res.data;
if (success) {
showInfo('已成功开始测试道,请刷新页面查看结果。');
showInfo('已成功开始测试道,请刷新页面查看结果。');
} else {
showError(message);
}
@@ -470,7 +470,7 @@ const ChannelsTable = () => {
if (success) {
record.balance = balance;
record.balance_updated_time = Date.now() / 1000;
showInfo(`${record.name} 余额更新成功!`);
showInfo(`${record.name} 余额更新成功!`);
} else {
showError(message);
}
@@ -481,7 +481,7 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/update_balance`);
const { success, message } = res.data;
if (success) {
showInfo('已更新完毕所有已启用道余额!');
showInfo('已更新完毕所有已启用道余额!');
} else {
showError(message);
}
@@ -490,7 +490,7 @@ const ChannelsTable = () => {
const batchDeleteChannels = async () => {
if (selectedChannels.length === 0) {
showError('请先选择要删除的道!');
showError('请先选择要删除的道!');
return;
}
setLoading(true);
@@ -501,7 +501,7 @@ const ChannelsTable = () => {
const res = await API.post(`/api/channel/batch`, { ids: ids });
const { success, message, data } = res.data;
if (success) {
showSuccess(`已删除 ${data}道!`);
showSuccess(`已删除 ${data}道!`);
await refresh();
} else {
showError(message);
@@ -513,7 +513,7 @@ const ChannelsTable = () => {
const res = await API.post(`/api/channel/fix`);
const { success, message, data } = res.data;
if (success) {
showSuccess(`已修复 ${data}道!`);
showSuccess(`已修复 ${data}道!`);
await refresh();
} else {
showError(message);
@@ -633,7 +633,7 @@ const ChannelsTable = () => {
onConfirm={() => { testChannels("all") }}
position={isMobile() ? 'top' : 'left'}
>
<Button theme="light" type="warning" style={{ marginRight: 8 }}>测试所有</Button>
<Button theme="light" type="warning" style={{ marginRight: 8 }}>测试所有</Button>
</Popconfirm>
<Popconfirm
title="确定?"
@@ -648,16 +648,16 @@ const ChannelsTable = () => {
okType={'secondary'}
onConfirm={updateAllChannelsBalance}
>
<Button theme="light" type="secondary" style={{ marginRight: 8 }}>更新所有已启用道余额</Button>
<Button theme="light" type="secondary" style={{ marginRight: 8 }}>更新所有已启用道余额</Button>
</Popconfirm> */}
<Popconfirm
title="确定是否要删除禁用道?"
title="确定是否要删除禁用道?"
content="此修改将不可逆"
okType={'danger'}
onConfirm={deleteAllDisabledChannels}
position={isMobile() ? 'top' : 'left'}
>
<Button theme="light" type="danger" style={{ marginRight: 8 }}>删除禁用</Button>
<Button theme="light" type="danger" style={{ marginRight: 8 }}>删除禁用</Button>
</Popconfirm>
<Button theme="light" type="primary" style={{ marginRight: 8 }} onClick={refresh}>刷新</Button>
@@ -673,7 +673,7 @@ const ChannelsTable = () => {
setEnableBatchDelete(v);
}}></Switch>
<Popconfirm
title="确定是否要删除所选道?"
title="确定是否要删除所选道?"
content="此修改将不可逆"
okType={'danger'}
onConfirm={batchDeleteChannels}
@@ -681,7 +681,7 @@ const ChannelsTable = () => {
position={'top'}
>
<Button disabled={!enableBatchDelete} theme="light" type="danger"
style={{ marginRight: 8 }}>删除所选道</Button>
style={{ marginRight: 8 }}>删除所选道</Button>
</Popconfirm>
<Popconfirm
title="确定是否要修复数据库一致性?"

View File

@@ -261,7 +261,7 @@ const OperationSetting = () => {
value={inputs.ChannelDisableThreshold}
type='number'
min='0'
placeholder='单位秒,当运行道全部测试时,超过此时间将自动禁用道'
placeholder='单位秒,当运行道全部测试时,超过此时间将自动禁用道'
/>
<Form.Input
label='额度提醒阈值'
@@ -277,13 +277,13 @@ const OperationSetting = () => {
<Form.Group inline>
<Form.Checkbox
checked={inputs.AutomaticDisableChannelEnabled === 'true'}
label='失败时自动禁用道'
label='失败时自动禁用道'
name='AutomaticDisableChannelEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.AutomaticEnableChannelEnabled === 'true'}
label='成功时自动启用道'
label='成功时自动启用道'
name='AutomaticEnableChannelEnabled'
onChange={handleInputChange}
/>

View File

@@ -247,6 +247,8 @@ const TokensTable = () => {
const [editingToken, setEditingToken] = useState({
id: undefined
});
const [orderBy, setOrderBy] = useState('');
const [dropdownVisible, setDropdownVisible] = useState(false);
const closeEdit = () => {
setShowEdit(false);
@@ -269,7 +271,7 @@ const TokensTable = () => {
let pageData = tokens.slice((activePage - 1) * pageSize, activePage * pageSize);
const loadTokens = async (startIdx) => {
setLoading(true);
const res = await API.get(`/api/token/?p=${startIdx}&size=${pageSize}`);
const res = await API.get(`/api/token/?p=${startIdx}&size=${pageSize}&order=${orderBy}`);
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
@@ -289,7 +291,7 @@ const TokensTable = () => {
(async () => {
if (activePage === Math.ceil(tokens.length / pageSize) + 1) {
// In this case we have to load more data and then append them.
await loadTokens(activePage - 1);
await loadTokens(activePage - 1, orderBy);
}
setActivePage(activePage);
})();
@@ -392,12 +394,12 @@ const TokensTable = () => {
};
useEffect(() => {
loadTokens(0)
loadTokens(0, orderBy)
.then()
.catch((reason) => {
showError(reason);
});
}, [pageSize]);
}, [pageSize, orderBy]);
const removeRecord = key => {
let newDataSource = [...tokens];
@@ -452,6 +454,7 @@ const TokensTable = () => {
// if keyword is blank, load files instead.
await loadTokens(0);
setActivePage(1);
setOrderBy('');
return;
}
setSearching(true);
@@ -520,6 +523,23 @@ const TokensTable = () => {
}
};
const handleOrderByChange = (e, { value }) => {
setOrderBy(value);
setActivePage(1);
setDropdownVisible(false);
};
const renderSelectedOption = (orderBy) => {
switch (orderBy) {
case 'remain_quota':
return '按剩余额度排序';
case 'used_quota':
return '按已用额度排序';
default:
return '默认排序';
}
};
return (
<>
<EditToken refresh={refresh} editingToken={editingToken} visiable={showEdit} handleClose={closeEdit}></EditToken>
@@ -579,6 +599,21 @@ const TokensTable = () => {
await copyText(keys);
}
}>复制所选令牌到剪贴板</Button>
<Dropdown
trigger="click"
position="bottomLeft"
visible={dropdownVisible}
onVisibleChange={(visible) => setDropdownVisible(visible)}
render={
<Dropdown.Menu>
<Dropdown.Item onClick={() => handleOrderByChange('', { value: '' })}>默认排序</Dropdown.Item>
<Dropdown.Item onClick={() => handleOrderByChange('', { value: 'remain_quota' })}>按剩余额度排序</Dropdown.Item>
<Dropdown.Item onClick={() => handleOrderByChange('', { value: 'used_quota' })}>按已用额度排序</Dropdown.Item>
</Dropdown.Menu>
}
>
<Button style={{ marginLeft: '10px' }}>{renderSelectedOption(orderBy)}</Button>
</Dropdown>
</>
);
};

View File

@@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react';
import { API, showError, showSuccess } from '../helpers';
import { Button, Form, Popconfirm, Space, Table, Tag, Tooltip } from '@douyinfe/semi-ui';
import { Button, Form, Popconfirm, Space, Table, Tag, Tooltip, Dropdown } from '@douyinfe/semi-ui';
import { ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber, renderQuota } from '../helpers/render';
import AddUser from '../pages/User/AddUser';
@@ -139,6 +139,8 @@ const UsersTable = () => {
const [editingUser, setEditingUser] = useState({
id: undefined
});
const [orderBy, setOrderBy] = useState('');
const [dropdownVisible, setDropdownVisible] = useState(false);
const setCount = (data) => {
if (data.length >= (activePage) * ITEMS_PER_PAGE) {
@@ -162,7 +164,7 @@ const UsersTable = () => {
};
const loadUsers = async (startIdx) => {
const res = await API.get(`/api/user/?p=${startIdx}`);
const res = await API.get(`/api/user/?p=${startIdx}&order=${orderBy}`);
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
@@ -184,19 +186,19 @@ const UsersTable = () => {
(async () => {
if (activePage === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
await loadUsers(activePage - 1);
await loadUsers(activePage - 1, orderBy);
}
setActivePage(activePage);
})();
};
useEffect(() => {
loadUsers(0)
loadUsers(0, orderBy)
.then()
.catch((reason) => {
showError(reason);
});
}, []);
}, [orderBy]);
const manageUser = async (username, action, record) => {
const res = await API.post('/api/user/manage', {
@@ -239,6 +241,7 @@ const UsersTable = () => {
// if keyword is blank, load files instead.
await loadUsers(0);
setActivePage(1);
setOrderBy('');
return;
}
setSearching(true);
@@ -301,6 +304,25 @@ const UsersTable = () => {
}
};
const handleOrderByChange = (e, { value }) => {
setOrderBy(value);
setActivePage(1);
setDropdownVisible(false);
};
const renderSelectedOption = (orderBy) => {
switch (orderBy) {
case 'quota':
return '按剩余额度排序';
case 'used_quota':
return '按已用额度排序';
case 'request_count':
return '按请求次数排序';
default:
return '默认排序';
}
};
return (
<>
<AddUser refresh={refresh} visible={showAddUser} handleClose={closeAddUser}></AddUser>
@@ -331,6 +353,22 @@ const UsersTable = () => {
setShowAddUser(true);
}
}>添加用户</Button>
<Dropdown
trigger="click"
position="bottomLeft"
visible={dropdownVisible}
onVisibleChange={(visible) => setDropdownVisible(visible)}
render={
<Dropdown.Menu>
<Dropdown.Item onClick={() => handleOrderByChange('', { value: '' })}>默认排序</Dropdown.Item>
<Dropdown.Item onClick={() => handleOrderByChange('', { value: 'quota' })}>按剩余额度排序</Dropdown.Item>
<Dropdown.Item onClick={() => handleOrderByChange('', { value: 'used_quota' })}>按已用额度排序</Dropdown.Item>
<Dropdown.Item onClick={() => handleOrderByChange('', { value: 'request_count' })}>按请求次数排序</Dropdown.Item>
</Dropdown.Menu>
}
>
<Button style={{ marginLeft: '10px' }}>{renderSelectedOption(orderBy)}</Button>
</Dropdown>
</>
);
};

View File

@@ -103,3 +103,14 @@ code {
display: none !important;
}
}
/* 隐藏浏览器默认的滚动条 */
body {
overflow: hidden;
}
/* 自定义滚动条样式 */
body::-webkit-scrollbar {
width: 0; /* 隐藏滚动条的宽度 */
}

View File

@@ -230,7 +230,7 @@ const EditChannel = (props) => {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
}
if (localInputs.type === 3 && localInputs.other === '') {
localInputs.other = '2023-06-01-preview';
localInputs.other = '2024-03-01-preview';
}
if (localInputs.type === 18 && localInputs.other === '') {
localInputs.other = 'v2.1';
@@ -348,7 +348,7 @@ const EditChannel = (props) => {
<Input
label='默认 API 版本'
name='azure_other'
placeholder={'请输入默认 API 版本例如2023-06-01-preview该配置可以被实际的请求查询参数所覆盖'}
placeholder={'请输入默认 API 版本例如2024-03-01-preview该配置可以被实际的请求查询参数所覆盖'}
onChange={value => {
handleInputChange('other', value)
}}

View File

@@ -49,7 +49,7 @@ const typeConfig = {
base_url: "请填写AZURE_OPENAI_ENDPOINT",
// 注意:通过判断 `other` 是否有值来判断是否需要显示 `other` 输入框, 默认是没有值的
other: "请输入默认API版本例如2023-06-01-preview",
other: "请输入默认API版本例如2024-03-01-preview",
},
modelGroup: "openai", // 模型组名称,这个值是给 填入渠道支持模型 按钮使用的。 填入渠道支持模型 按钮会根据这个值来获取模型组,如果填写默认是 openai
},

Binary file not shown.

Before

Width:  |  Height:  |  Size: 40 KiB

After

Width:  |  Height:  |  Size: 4.2 KiB

View File

@@ -3,186 +3,192 @@ export const CHANNEL_OPTIONS = {
key: 1,
text: 'OpenAI',
value: 1,
color: 'primary'
color: 'success'
},
14: {
key: 14,
text: 'Anthropic Claude',
value: 14,
color: 'info'
color: 'primary'
},
3: {
key: 3,
text: 'Azure OpenAI',
value: 3,
color: 'secondary'
color: 'success'
},
11: {
key: 11,
text: 'Google PaLM2',
value: 11,
color: 'orange'
color: 'warning'
},
24: {
key: 24,
text: 'Google Gemini',
value: 24,
color: 'orange'
color: 'warning'
},
28: {
key: 28,
text: 'Mistral AI',
value: 28,
color: 'orange'
color: 'warning'
},
15: {
key: 15,
text: '百度文心千帆',
value: 15,
color: 'default'
color: 'primary'
},
17: {
key: 17,
text: '阿里通义千问',
value: 17,
color: 'default'
color: 'primary'
},
18: {
key: 18,
text: '讯飞星火认知',
value: 18,
color: 'default'
color: 'primary'
},
16: {
key: 16,
text: '智谱 ChatGLM',
value: 16,
color: 'default'
color: 'primary'
},
19: {
key: 19,
text: '360 智脑',
value: 19,
color: 'default'
color: 'primary'
},
25: {
key: 25,
text: 'Moonshot AI',
value: 25,
color: 'default'
color: 'primary'
},
23: {
key: 23,
text: '腾讯混元',
value: 23,
color: 'default'
color: 'primary'
},
26: {
key: 26,
text: '百川大模型',
value: 26,
color: 'default'
color: 'primary'
},
27: {
key: 27,
text: 'MiniMax',
value: 27,
color: 'default'
color: 'primary'
},
29: {
key: 29,
text: 'Groq',
value: 29,
color: 'default'
color: 'primary'
},
30: {
key: 30,
text: 'Ollama',
value: 30,
color: 'default'
color: 'primary'
},
31: {
key: 31,
text: '零一万物',
value: 31,
color: 'default'
color: 'primary'
},
32: {
key: 32,
text: '阶跃星辰',
value: 32,
color: 'primary'
},
8: {
key: 8,
text: '自定义渠道',
value: 8,
color: 'primary'
color: 'error'
},
22: {
key: 22,
text: '知识库FastGPT',
value: 22,
color: 'default'
color: 'success'
},
21: {
key: 21,
text: '知识库AI Proxy',
value: 21,
color: 'purple'
color: 'success'
},
20: {
key: 20,
text: '代理OpenRouter',
value: 20,
color: 'primary'
color: 'success'
},
2: {
key: 2,
text: '代理API2D',
value: 2,
color: 'primary'
color: 'success'
},
5: {
key: 5,
text: '代理OpenAI-SB',
value: 5,
color: 'primary'
color: 'success'
},
7: {
key: 7,
text: '代理OhMyGPT',
value: 7,
color: 'primary'
color: 'success'
},
10: {
key: 10,
text: '代理AI Proxy',
value: 10,
color: 'primary'
color: 'success'
},
4: {
key: 4,
text: '代理CloseAI',
value: 4,
color: 'primary'
color: 'success'
},
6: {
key: 6,
text: '代理OpenAI Max',
value: 6,
color: 'primary'
color: 'success'
},
9: {
key: 9,
text: '代理AI.LS',
value: 9,
color: 'primary'
color: 'success'
},
12: {
key: 12,
text: '代理API2GPT',
value: 12,
color: 'primary'
color: 'success'
},
13: {
key: 13,
text: '代理AIGC2D',
value: 13,
color: 'primary'
color: 'success'
}
};

View File

@@ -51,7 +51,7 @@ const Register = () => {
<Grid item xs={12}>
<Grid item container direction="column" alignItems="center" xs={12}>
<Typography component={Link} to="/login" variant="subtitle1" sx={{ textDecoration: 'none' }}>
已经有帐号了?点击登录
已经有帐号了点击登录
</Typography>
</Grid>
</Grid>

View File

@@ -180,7 +180,7 @@ const LoginForm = ({ ...others }) => {
{({ errors, handleBlur, handleChange, handleSubmit, isSubmitting, touched, values }) => (
<form noValidate onSubmit={handleSubmit} {...others}>
<FormControl fullWidth error={Boolean(touched.username && errors.username)} sx={{ ...theme.typography.customInput }}>
<InputLabel htmlFor="outlined-adornment-username-login">用户名</InputLabel>
<InputLabel htmlFor="outlined-adornment-username-login">用户名 / 邮箱</InputLabel>
<OutlinedInput
id="outlined-adornment-username-login"
type="text"

View File

@@ -296,7 +296,7 @@ const RegisterForm = ({ ...others }) => {
<Box sx={{ mt: 2 }}>
<AnimateButton>
<Button disableElevation disabled={isSubmitting} fullWidth size="large" type="submit" variant="contained" color="primary">
Sign up
注册
</Button>
</AnimateButton>
</Box>

View File

@@ -3,6 +3,19 @@ import Label from "ui-component/Label";
import Stack from "@mui/material/Stack";
import Divider from "@mui/material/Divider";
function name2color(name) {
switch (name) {
case "default":
return "info";
case "vip":
return "warning"
case "svip":
return "error"
default:
return "info"
}
}
const GroupLabel = ({ group }) => {
let groups = [];
if (group === "") {
@@ -14,7 +27,7 @@ const GroupLabel = ({ group }) => {
return (
<Stack divider={<Divider orientation="vertical" flexItem />} spacing={0.5}>
{groups.map((group, index) => {
return <Label key={index}>{group}</Label>;
return <Label key={index} color={name2color(group)}>{group}</Label>;
})}
</Stack>
);

View File

@@ -10,6 +10,7 @@ const ChannelTableHead = () => {
<TableCell>类型</TableCell>
<TableCell>状态</TableCell>
<TableCell>响应时间</TableCell>
<TableCell>已消耗</TableCell>
<TableCell>余额</TableCell>
<TableCell>优先级</TableCell>
<TableCell>操作</TableCell>

View File

@@ -93,7 +93,7 @@ export default function ChannelTableRow({
test_time: Date.now() / 1000,
response_time: time * 1000,
});
showInfo(`${item.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
showInfo(`${item.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
}
};
@@ -170,6 +170,9 @@ export default function ChannelTableRow({
handle_action={handleResponseTime}
/>
</TableCell>
<TableCell>
{renderNumber(item.used_quota)}
</TableCell>
<TableCell>
<Tooltip
title={"点击更新余额"}
@@ -240,9 +243,9 @@ export default function ChannelTableRow({
</Popover>
<Dialog open={openDelete} onClose={handleDeleteClose}>
<DialogTitle>删除</DialogTitle>
<DialogTitle>删除</DialogTitle>
<DialogContent>
<DialogContentText>是否删除 {item.name}</DialogContentText>
<DialogContentText>是否删除 {item.name}</DialogContentText>
</DialogContent>
<DialogActions>
<Button onClick={handleDeleteClose}>关闭</Button>

View File

@@ -135,7 +135,7 @@ export default function ChannelPage() {
const res = await API.get(`/api/channel/test`);
const { success, message } = res.data;
if (success) {
showInfo('已成功开始测试所有道,请刷新页面查看结果。');
showInfo('已成功开始测试所有道,请刷新页面查看结果。');
} else {
showError(message);
}
@@ -159,7 +159,7 @@ export default function ChannelPage() {
const res = await API.get(`/api/channel/update_balance`);
const { success, message } = res.data;
if (success) {
showInfo('已更新完毕所有已启用道余额!');
showInfo('已更新完毕所有已启用道余额!');
} else {
showError(message);
}
@@ -193,20 +193,14 @@ export default function ChannelPage() {
return (
<>
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={5}>
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={2.5}>
<Typography variant="h4">渠道</Typography>
<Button variant="contained" color="primary" startIcon={<IconPlus />} onClick={() => handleOpenModal(0)}>
新建渠道
</Button>
</Stack>
<Stack mb={5}>
<Alert severity="info">
OpenAI 渠道已经不再支持通过 key 获取余额因此余额显示为 0对于支持的渠道类型请点击余额进行刷新
</Alert>
</Stack>
<Card>
<Box component="form" onSubmit={searchChannels} noValidate>
<Box component="form" onSubmit={searchChannels} noValidate sx={{marginTop: 2}}>
<TableToolBar filterName={searchKeyword} handleFilterName={handleSearchKeyword} placeholder={'搜索渠道的 ID名称和密钥 ...'} />
</Box>
<Toolbar
@@ -220,7 +214,7 @@ export default function ChannelPage() {
>
<Container>
{matchUpMd ? (
<ButtonGroup variant="outlined" aria-label="outlined small primary button group">
<ButtonGroup variant="outlined" aria-label="outlined small primary button group" sx={{marginBottom: 2}}>
<Button onClick={handleRefresh} startIcon={<IconRefresh width={'18px'} />}>
刷新
</Button>

View File

@@ -41,7 +41,7 @@ const typeConfig = {
},
prompt: {
base_url: "请填写AZURE_OPENAI_ENDPOINT",
other: "请输入默认API版本例如2023-06-01-preview",
other: "请输入默认API版本例如2024-03-01-preview",
},
},
11: {

View File

@@ -65,7 +65,7 @@ const StatisticalLineChartCard = ({ isLoading, title, chartData, todayValue }) =
) : (
<CardWrapper border={false} content={false}>
<Box sx={{ p: 2.25 }}>
<Grid container direction="column">
<Grid>
<Grid item sx={{ mb: 0.75 }}>
<Grid container alignItems="center">
<Grid item xs={6}>

View File

@@ -102,11 +102,11 @@ export default function Log() {
return (
<>
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={5}>
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={2.5}>
<Typography variant="h4">日志</Typography>
</Stack>
<Card>
<Box component="form" onSubmit={searchLogs} noValidate>
<Box component="form" onSubmit={searchLogs} noValidate sx={{marginTop: 2}}>
<TableToolBar filterName={searchKeyword} handleFilterName={handleSearchKeyword} userIsAdmin={userIsAdmin} />
</Box>
<Toolbar
@@ -119,7 +119,7 @@ export default function Log() {
}}
>
<Container>
<ButtonGroup variant="outlined" aria-label="outlined small primary button group">
<ButtonGroup variant="outlined" aria-label="outlined small primary button group" sx={{marginBottom: 2}}>
<Button onClick={handleRefresh} startIcon={<IconRefresh width={'18px'} />}>
刷新/清除搜索条件
</Button>

View File

@@ -141,7 +141,7 @@ export default function Redemption() {
return (
<>
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={5}>
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={2.5}>
<Typography variant="h4">兑换</Typography>
<Button variant="contained" color="primary" startIcon={<IconPlus />} onClick={() => handleOpenModal(0)}>
@@ -149,7 +149,7 @@ export default function Redemption() {
</Button>
</Stack>
<Card>
<Box component="form" onSubmit={searchRedemptions} noValidate>
<Box component="form" onSubmit={searchRedemptions} noValidate sx={{marginTop: 2}}>
<TableToolBar filterName={searchKeyword} handleFilterName={handleSearchKeyword} placeholder={'搜索兑换码的ID和名称...'} />
</Box>
<Toolbar
@@ -162,7 +162,7 @@ export default function Redemption() {
}}
>
<Container>
<ButtonGroup variant="outlined" aria-label="outlined small primary button group">
<ButtonGroup variant="outlined" aria-label="outlined small primary button group" sx={{marginBottom: 2}}>
<Button onClick={handleRefresh} startIcon={<IconRefresh width={'18px'} />}>
刷新
</Button>

View File

@@ -371,7 +371,7 @@ const OperationSetting = () => {
value={inputs.ChannelDisableThreshold}
onChange={handleInputChange}
label="最长响应时间"
placeholder="单位秒,当运行道全部测试时,超过此时间将自动禁用道"
placeholder="单位秒,当运行道全部测试时,超过此时间将自动禁用道"
disabled={loading}
/>
</FormControl>
@@ -392,7 +392,7 @@ const OperationSetting = () => {
</FormControl>
</Stack>
<FormControlLabel
label="失败时自动禁用道"
label="失败时自动禁用道"
control={
<Checkbox
checked={inputs.AutomaticDisableChannelEnabled === "true"}
@@ -402,7 +402,7 @@ const OperationSetting = () => {
}
/>
<FormControlLabel
label="成功时自动启用道"
label="成功时自动启用道"
control={
<Checkbox
checked={inputs.AutomaticEnableChannelEnabled === "true"}

View File

@@ -141,9 +141,8 @@ export default function Token() {
return (
<>
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={5}>
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={2.5}>
<Typography variant="h4">令牌</Typography>
<Button
variant="contained"
color="primary"
@@ -155,13 +154,13 @@ export default function Token() {
新建令牌
</Button>
</Stack>
<Stack mb={5}>
<Stack mb={2}>
<Alert severity="info">
OpenAI API 基础地址 https://api.openai.com 替换为 <b>{siteInfo.server_address}</b>,复制下面的密钥即可使用
</Alert>
</Stack>
<Card>
<Box component="form" onSubmit={searchTokens} noValidate>
<Box component="form" onSubmit={searchTokens} noValidate sx={{marginTop: 2}}>
<TableToolBar filterName={searchKeyword} handleFilterName={handleSearchKeyword} placeholder={'搜索令牌的名称...'} />
</Box>
<Toolbar
@@ -174,7 +173,7 @@ export default function Token() {
}}
>
<Container>
<ButtonGroup variant="outlined" aria-label="outlined small primary button group">
<ButtonGroup variant="outlined" aria-label="outlined small primary button group" sx={{marginBottom: 2}}>
<Button onClick={handleRefresh} startIcon={<IconRefresh width={'18px'} />}>
刷新
</Button>

View File

@@ -139,7 +139,7 @@ export default function Users() {
return (
<>
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={5}>
<Stack direction="row" alignItems="center" justifyContent="space-between" mb={2.5}>
<Typography variant="h4">用户</Typography>
<Button variant="contained" color="primary" startIcon={<IconPlus />} onClick={() => handleOpenModal(0)}>
@@ -147,7 +147,7 @@ export default function Users() {
</Button>
</Stack>
<Card>
<Box component="form" onSubmit={searchUsers} noValidate>
<Box component="form" onSubmit={searchUsers} noValidate sx={{marginTop: 2}}>
<TableToolBar
filterName={searchKeyword}
handleFilterName={handleSearchKeyword}
@@ -164,7 +164,7 @@ export default function Users() {
}}
>
<Container>
<ButtonGroup variant="outlined" aria-label="outlined small primary button group">
<ButtonGroup variant="outlined" aria-label="outlined small primary button group" sx={{marginBottom: 2}}>
<Button onClick={handleRefresh} startIcon={<IconRefresh width={'18px'} />}>
刷新
</Button>

View File

@@ -24,6 +24,7 @@ import EditRedemption from './pages/Redemption/EditRedemption';
import TopUp from './pages/TopUp';
import Log from './pages/Log';
import Chat from './pages/Chat';
import LarkOAuth from './components/LarkOAuth';
const Home = lazy(() => import('./pages/Home'));
const About = lazy(() => import('./pages/About'));
@@ -239,6 +240,14 @@ function App() {
</Suspense>
}
/>
<Route
path='/oauth/lark'
element={
<Suspense fallback={<Loading></Loading>}>
<LarkOAuth />
</Suspense>
}
/>
<Route
path='/setting'
element={

View File

@@ -234,7 +234,7 @@ const ChannelsTable = () => {
newChannels[realIdx].response_time = time * 1000;
newChannels[realIdx].test_time = Date.now() / 1000;
setChannels(newChannels);
showInfo(`${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
showInfo(`${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else {
showError(message);
}
@@ -244,7 +244,7 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/test?scope=${scope}`);
const { success, message } = res.data;
if (success) {
showInfo('已成功开始测试道,请刷新页面查看结果。');
showInfo('已成功开始测试道,请刷新页面查看结果。');
} else {
showError(message);
}
@@ -270,7 +270,7 @@ const ChannelsTable = () => {
newChannels[realIdx].balance = balance;
newChannels[realIdx].balance_updated_time = Date.now() / 1000;
setChannels(newChannels);
showInfo(`${name} 余额更新成功!`);
showInfo(`${name} 余额更新成功!`);
} else {
showError(message);
}
@@ -281,7 +281,7 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/update_balance`);
const { success, message } = res.data;
if (success) {
showInfo('已更新完毕所有已启用道余额!');
showInfo('已更新完毕所有已启用道余额!');
} else {
showError(message);
}
@@ -333,6 +333,8 @@ const ChannelsTable = () => {
setPromptShown("channel-test");
}}>
OpenAI 渠道已经不再支持通过 key 获取余额因此余额显示为 0对于支持的渠道类型请点击余额进行刷新
<br/>
渠道测试仅支持 chat 模型优先使用 gpt-3.5-turbo如果该模型不可用则使用你所配置的模型列表中的第一个模型
</Message>
)
}

View File

@@ -0,0 +1,58 @@
import React, { useContext, useEffect, useState } from 'react';
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { API, showError, showSuccess } from '../helpers';
import { UserContext } from '../context/User';
const LarkOAuth = () => {
const [searchParams, setSearchParams] = useSearchParams();
const [userState, userDispatch] = useContext(UserContext);
const [prompt, setPrompt] = useState('处理中...');
const [processing, setProcessing] = useState(true);
let navigate = useNavigate();
const sendCode = async (code, state, count) => {
const res = await API.get(`/api/oauth/lark?code=${code}&state=${state}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/setting');
} else {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/');
}
} else {
showError(message);
if (count === 0) {
setPrompt(`操作失败,重定向至登录界面中...`);
navigate('/setting'); // in case this is failed to bind lark
return;
}
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, count * 2000));
await sendCode(code, state, count);
}
};
useEffect(() => {
let code = searchParams.get('code');
let state = searchParams.get('state');
sendCode(code, state, 0).then();
}, []);
return (
<Segment style={{ minHeight: '300px' }}>
<Dimmer active inverted>
<Loader size='large'>{prompt}</Loader>
</Dimmer>
</Segment>
);
};
export default LarkOAuth;

View File

@@ -3,7 +3,8 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getLogo, showError, showSuccess, showWarning } from '../helpers';
import { onGitHubOAuthClicked } from './utils';
import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils';
import larkIcon from '../images/lark.svg';
const LoginForm = () => {
const [inputs, setInputs] = useState({
@@ -94,7 +95,7 @@ const LoginForm = () => {
fluid
icon='user'
iconPosition='left'
placeholder='用户名'
placeholder='用户名 / 邮箱地址'
name='username'
value={username}
onChange={handleChange}
@@ -124,7 +125,7 @@ const LoginForm = () => {
点击注册
</Link>
</Message>
{status.github_oauth || status.wechat_login ? (
{status.github_oauth || status.wechat_login || status.lark_client_id ? (
<>
<Divider horizontal>Or</Divider>
{status.github_oauth ? (
@@ -137,6 +138,18 @@ const LoginForm = () => {
) : (
<></>
)}
{status.lark_client_id ? (
<Button
// circular
color=''
onClick={() => onLarkOAuthClicked(status.lark_client_id)}
style={{ padding: 0, width: 36, height: 36 }}
>
<img src={larkIcon} width={36} height={36} />
</Button>
) : (
<></>
)}
{status.wechat_login ? (
<Button
circular

View File

@@ -261,7 +261,7 @@ const OperationSetting = () => {
value={inputs.ChannelDisableThreshold}
type='number'
min='0'
placeholder='单位秒,当运行道全部测试时,超过此时间将自动禁用道'
placeholder='单位秒,当运行道全部测试时,超过此时间将自动禁用道'
/>
<Form.Input
label='额度提醒阈值'
@@ -277,13 +277,13 @@ const OperationSetting = () => {
<Form.Group inline>
<Form.Checkbox
checked={inputs.AutomaticDisableChannelEnabled === 'true'}
label='失败时自动禁用道'
label='失败时自动禁用道'
name='AutomaticDisableChannelEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.AutomaticEnableChannelEnabled === 'true'}
label='成功时自动启用道'
label='成功时自动启用道'
name='AutomaticEnableChannelEnabled'
onChange={handleInputChange}
/>

View File

@@ -4,7 +4,7 @@ import { Link, useNavigate } from 'react-router-dom';
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
import Turnstile from 'react-turnstile';
import { UserContext } from '../context/User';
import { onGitHubOAuthClicked } from './utils';
import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils';
const PersonalSetting = () => {
const [userState, userDispatch] = useContext(UserContext);
@@ -247,6 +247,11 @@ const PersonalSetting = () => {
<Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button>
)
}
{
status.lark_client_id && (
<Button onClick={()=>{onLarkOAuthClicked(status.lark_client_id)}}>绑定飞书账号</Button>
)
}
<Button
onClick={() => {
setShowEmailBindModal(true);

View File

@@ -10,6 +10,8 @@ const SystemSetting = () => {
GitHubOAuthEnabled: '',
GitHubClientId: '',
GitHubClientSecret: '',
LarkClientId: '',
LarkClientSecret: '',
Notice: '',
SMTPServer: '',
SMTPPort: '',
@@ -109,6 +111,8 @@ const SystemSetting = () => {
name === 'ServerAddress' ||
name === 'GitHubClientId' ||
name === 'GitHubClientSecret' ||
name === 'LarkClientId' ||
name === 'LarkClientSecret' ||
name === 'WeChatServerAddress' ||
name === 'WeChatServerToken' ||
name === 'WeChatAccountQRCodeImageURL' ||
@@ -212,6 +216,18 @@ const SystemSetting = () => {
}
};
const submitLarkOAuth = async () => {
if (originInputs['LarkClientId'] !== inputs.LarkClientId) {
await updateOption('LarkClientId', inputs.LarkClientId);
}
if (
originInputs['LarkClientSecret'] !== inputs.LarkClientSecret &&
inputs.LarkClientSecret !== ''
) {
await updateOption('LarkClientSecret', inputs.LarkClientSecret);
}
};
const submitTurnstile = async () => {
if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) {
await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey);
@@ -469,6 +485,44 @@ const SystemSetting = () => {
保存 GitHub OAuth 设置
</Form.Button>
<Divider />
<Header as='h3'>
配置飞书授权登录
<Header.Subheader>
用以支持通过飞书进行登录注册
<a href='https://open.feishu.cn/app' target='_blank'>
点击此处
</a>
管理你的飞书应用
</Header.Subheader>
</Header>
<Message>
主页链接填 <code>{inputs.ServerAddress}</code>
重定向 URL {' '}
<code>{`${inputs.ServerAddress}/oauth/lark`}</code>
</Message>
<Form.Group widths={3}>
<Form.Input
label='App ID'
name='LarkClientId'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.LarkClientId}
placeholder='输入 App ID'
/>
<Form.Input
label='App Secret'
name='LarkClientSecret'
onChange={handleInputChange}
type='password'
autoComplete='new-password'
value={inputs.LarkClientSecret}
placeholder='敏感信息不会发送到前端显示'
/>
</Form.Group>
<Form.Button onClick={submitLarkOAuth}>
保存飞书 OAuth 设置
</Form.Button>
<Divider />
<Header as='h3'>
配置 WeChat Server
<Header.Subheader>

View File

@@ -48,9 +48,10 @@ const TokensTable = () => {
const [searching, setSearching] = useState(false);
const [showTopUpModal, setShowTopUpModal] = useState(false);
const [targetTokenIdx, setTargetTokenIdx] = useState(0);
const [orderBy, setOrderBy] = useState('');
const loadTokens = async (startIdx) => {
const res = await API.get(`/api/token/?p=${startIdx}`);
const res = await API.get(`/api/token/?p=${startIdx}&order=${orderBy}`);
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
@@ -70,7 +71,7 @@ const TokensTable = () => {
(async () => {
if (activePage === Math.ceil(tokens.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
await loadTokens(activePage - 1);
await loadTokens(activePage - 1, orderBy);
}
setActivePage(activePage);
})();
@@ -160,12 +161,12 @@ const TokensTable = () => {
}
useEffect(() => {
loadTokens(0)
loadTokens(0, orderBy)
.then()
.catch((reason) => {
showError(reason);
});
}, []);
}, [orderBy]);
const manageToken = async (id, action, idx) => {
let data = { id };
@@ -205,6 +206,7 @@ const TokensTable = () => {
// if keyword is blank, load files instead.
await loadTokens(0);
setActivePage(1);
setOrderBy('');
return;
}
setSearching(true);
@@ -243,6 +245,11 @@ const TokensTable = () => {
setLoading(false);
};
const handleOrderByChange = (e, { value }) => {
setOrderBy(value);
setActivePage(1);
};
return (
<>
<Form onSubmit={searchTokens}>
@@ -427,6 +434,18 @@ const TokensTable = () => {
添加新的令牌
</Button>
<Button size='small' onClick={refresh} loading={loading}>刷新</Button>
<Dropdown
placeholder='排序方式'
selection
options={[
{ key: '', text: '默认排序', value: '' },
{ key: 'remain_quota', text: '按剩余额度排序', value: 'remain_quota' },
{ key: 'used_quota', text: '按已用额度排序', value: 'used_quota' },
]}
value={orderBy}
onChange={handleOrderByChange}
style={{ marginLeft: '10px' }}
/>
<Pagination
floated='right'
activePage={activePage}

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
import { Button, Form, Label, Pagination, Popup, Table, Dropdown } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, showError, showSuccess } from '../helpers';
@@ -25,9 +25,10 @@ const UsersTable = () => {
const [activePage, setActivePage] = useState(1);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [orderBy, setOrderBy] = useState('');
const loadUsers = async (startIdx) => {
const res = await API.get(`/api/user/?p=${startIdx}`);
const res = await API.get(`/api/user/?p=${startIdx}&order=${orderBy}`);
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
@@ -47,19 +48,19 @@ const UsersTable = () => {
(async () => {
if (activePage === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
await loadUsers(activePage - 1);
await loadUsers(activePage - 1, orderBy);
}
setActivePage(activePage);
})();
};
useEffect(() => {
loadUsers(0)
loadUsers(0, orderBy)
.then()
.catch((reason) => {
showError(reason);
});
}, []);
}, [orderBy]);
const manageUser = (username, action, idx) => {
(async () => {
@@ -110,6 +111,7 @@ const UsersTable = () => {
// if keyword is blank, load files instead.
await loadUsers(0);
setActivePage(1);
setOrderBy('');
return;
}
setSearching(true);
@@ -148,6 +150,11 @@ const UsersTable = () => {
setLoading(false);
};
const handleOrderByChange = (e, { value }) => {
setOrderBy(value);
setActivePage(1);
};
return (
<>
<Form onSubmit={searchUsers}>
@@ -322,6 +329,19 @@ const UsersTable = () => {
<Button size='small' as={Link} to='/user/add' loading={loading}>
添加新的用户
</Button>
<Dropdown
placeholder='排序方式'
selection
options={[
{ key: '', text: '默认排序', value: '' },
{ key: 'quota', text: '按剩余额度排序', value: 'quota' },
{ key: 'used_quota', text: '按已用额度排序', value: 'used_quota' },
{ key: 'request_count', text: '按请求次数排序', value: 'request_count' },
]}
value={orderBy}
onChange={handleOrderByChange}
style={{ marginLeft: '10px' }}
/>
<Pagination
floated='right'
activePage={activePage}

View File

@@ -17,4 +17,13 @@ export async function onGitHubOAuthClicked(github_client_id) {
window.open(
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`
);
}
export async function onLarkOAuthClicked(lark_client_id) {
const state = await getOAuthState();
if (!state) return;
let redirect_uri = `${window.location.origin}/oauth/lark`;
window.open(
`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`
);
}

View File

@@ -17,6 +17,7 @@ export const CHANNEL_OPTIONS = [
{ key: 29, text: 'Groq', value: 29, color: 'orange' },
{ key: 30, text: 'Ollama', value: 30, color: 'black' },
{ key: 31, text: '零一万物', value: 31, color: 'green' },
{ key: 31, text: '阶跃星辰', value: 32, color: 'blue' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 5.4 KiB

View File

@@ -83,6 +83,7 @@ const EditChannel = () => {
data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
}
setInputs(data);
setBasicModels(getChannelModels(data.type));
} else {
showError(message);
}
@@ -99,9 +100,6 @@ const EditChannel = () => {
}));
setOriginModelOptions(localModelOptions);
setFullModels(res.data.data.map((model) => model.id));
setBasicModels(res.data.data.filter((model) => {
return model.id.startsWith('gpt-3') || model.id.startsWith('text-');
}).map((model) => model.id));
} catch (error) {
showError(error.message);
}
@@ -137,6 +135,9 @@ const EditChannel = () => {
useEffect(() => {
if (isEdit) {
loadChannel().then();
} else {
let localModels = getChannelModels(inputs.type);
setBasicModels(localModels);
}
fetchModels().then();
fetchGroups().then();
@@ -160,7 +161,7 @@ const EditChannel = () => {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
}
if (localInputs.type === 3 && localInputs.other === '') {
localInputs.other = '2023-06-01-preview';
localInputs.other = '2024-03-01-preview';
}
if (localInputs.type === 18 && localInputs.other === '') {
localInputs.other = 'v2.1';
@@ -242,7 +243,7 @@ const EditChannel = () => {
<Form.Input
label='默认 API 版本'
name='other'
placeholder={'请输入默认 API 版本例如2023-06-01-preview该配置可以被实际的请求查询参数所覆盖'}
placeholder={'请输入默认 API 版本例如2024-03-01-preview该配置可以被实际的请求查询参数所覆盖'}
onChange={handleInputChange}
value={inputs.other}
autoComplete='new-password'
@@ -355,7 +356,7 @@ const EditChannel = () => {
<div style={{ lineHeight: '40px', marginBottom: '12px' }}>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: basicModels });
}}>填入基础模型</Button>
}}>填入相关模型</Button>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: fullModels });
}}>填入所有模型</Button>

View File

@@ -1,19 +1,22 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
import { useParams, useNavigate } from 'react-router-dom';
import { API, showError, showSuccess, timestamp2string } from '../../helpers';
import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render';
import { useNavigate, useParams } from 'react-router-dom';
import { API, copy, showError, showSuccess, timestamp2string } from '../../helpers';
import { renderQuotaWithPrompt } from '../../helpers/render';
const EditToken = () => {
const params = useParams();
const tokenId = params.id;
const isEdit = tokenId !== undefined;
const [loading, setLoading] = useState(isEdit);
const [modelOptions, setModelOptions] = useState([]);
const originInputs = {
name: '',
remain_quota: isEdit ? 0 : 500000,
expired_time: -1,
unlimited_quota: false
unlimited_quota: false,
models: [],
subnet: "",
};
const [inputs, setInputs] = useState(originInputs);
const { name, remain_quota, expired_time, unlimited_quota } = inputs;
@@ -22,8 +25,8 @@ const EditToken = () => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
const handleCancel = () => {
navigate("/token");
}
navigate('/token');
};
const setExpiredTime = (month, day, hour, minute) => {
let now = new Date();
let timestamp = now.getTime() / 1000;
@@ -50,6 +53,11 @@ const EditToken = () => {
if (data.expired_time !== -1) {
data.expired_time = timestamp2string(data.expired_time);
}
if (data.models === '') {
data.models = [];
} else {
data.models = data.models.split(',');
}
setInputs(data);
} else {
showError(message);
@@ -60,8 +68,26 @@ const EditToken = () => {
if (isEdit) {
loadToken().then();
}
loadAvailableModels().then();
}, []);
const loadAvailableModels = async () => {
let res = await API.get(`/api/user/available_models`);
const { success, message, data } = res.data;
if (success) {
let options = data.map((model) => {
return {
key: model,
text: model,
value: model
};
});
setModelOptions(options);
} else {
showError(message);
}
};
const submit = async () => {
if (!isEdit && inputs.name === '') return;
let localInputs = inputs;
@@ -74,6 +100,7 @@ const EditToken = () => {
}
localInputs.expired_time = Math.ceil(time / 1000);
}
localInputs.models = localInputs.models.join(',');
let res;
if (isEdit) {
res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) });
@@ -109,6 +136,34 @@ const EditToken = () => {
required={!isEdit}
/>
</Form.Field>
<Form.Field>
<Form.Dropdown
label='模型范围'
placeholder={'请选择允许使用的模型,留空则不进行限制'}
name='models'
fluid
multiple
search
onLabelClick={(e, { value }) => {
copy(value).then();
}}
selection
onChange={handleInputChange}
value={inputs.models}
autoComplete='new-password'
options={modelOptions}
/>
</Form.Field>
<Form.Field>
<Form.Input
label='IP 限制'
name='subnet'
placeholder={'请输入允许访问的网段例如192.168.0.0/24'}
onChange={handleInputChange}
value={inputs.subnet}
autoComplete='new-password'
/>
</Form.Field>
<Form.Field>
<Form.Input
label='过期时间'

Some files were not shown because too many files have changed in this diff Show More