Compare commits

...

51 Commits

Author SHA1 Message Date
JustSong
463b0b3c51 fix: no need to check turnstile when process deletion 2023-08-06 22:28:07 +08:00
JustSong
c3d85a28d4 chore: update i18n 2023-08-06 22:09:05 +08:00
JustSong
7422b0d051 chore: update prompt 2023-08-06 22:07:31 +08:00
Yolo°
5a62357c93 feat: add chat button for each token (#363)
* fork

* fork

* chore: update style

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-08-06 22:02:58 +08:00
JustSong
b464e2907a chore: update domain 2023-08-06 18:14:13 +08:00
JustSong
d96cf2e84d fix: fix stream mode determine related logic (close #360) 2023-08-06 18:09:00 +08:00
glzjin
446337c329 fix: calculate usage if not given in non-stream mode (#352) 2023-08-06 17:40:31 +08:00
Miniers
1dfa190e79 feat: able to copy scheme of ama, opencat & chatgpt next web (#343)
* Token Adds Option to Quickly Copy AMA and OpenCat URL Scheme

* feat: add ChatGPT Next Web

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-08-06 13:56:59 +08:00
glzjin
2d49ca6a07 fix: fix SparkDesk not billed (#344) 2023-08-06 13:24:49 +08:00
a497625414
89bcaaf989 docs: update readme (#359)
* update-deploy-on-sealos

* update-deploy-on-sealos
2023-08-06 13:19:54 +08:00
a497625414
afcd1bd27b docs: update deploy-on-sealos (#351) 2023-08-02 19:11:21 +08:00
glzjin
c2c455c980 fix: fix zhipu streaming (#349)
* Fix #348

* chore: update implementation

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-08-01 23:51:28 +08:00
JustSong
30a7f1a1c7 docs: update README 2023-07-30 22:24:07 +08:00
JustSong
c9d2e42a9e fix: fix sse not ending properly in some case 2023-07-30 22:20:42 +08:00
ckt
3fca6ff534 feat: support email domain whitelist (#337)
* feat: support email domain restriction

* fix(SMTPToken): disable password auto complete

* chore: update implementation

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-07-30 12:44:41 +08:00
JustSong
8cbbeb784f fix: ignore data if not have proper prefix 2023-07-30 12:03:06 +08:00
JustSong
ec88c0c240 fix: prompt user that channel test is unavailable 2023-07-29 23:54:09 +08:00
JustSong
065147b440 fix: close connection when response ended 2023-07-29 23:52:18 +08:00
JustSong
fe8f216dd9 refactor: update billing related code 2023-07-29 22:32:05 +08:00
JustSong
b7d0616ae0 chore: update title for xunfei 2023-07-29 22:09:10 +08:00
JustSong
ce9c8024a6 chore: update prompt for xunfei 2023-07-29 22:05:15 +08:00
JustSong
8a866078b2 feat: support xunfei's llm (close #206) 2023-07-29 21:55:57 +08:00
JustSong
3e81d8af45 chore: update i18n 2023-07-29 19:50:29 +08:00
JustSong
b8cb86c2c1 chore: adjust ui 2023-07-29 19:32:06 +08:00
JustSong
f45d586400 fix: fix model mapping cannot be deleted 2023-07-29 19:17:26 +08:00
JustSong
50dec03ff3 fix: fix model mapping cannot be deleted 2023-07-29 19:16:42 +08:00
JustSong
f31d400b6f chore: automatically add related models when switch type 2023-07-29 12:24:23 +08:00
JustSong
130e6bfd83 feat: support baidu's embedding model (close #324) 2023-07-29 12:15:07 +08:00
JustSong
d1335ebc01 docs: update README 2023-07-28 23:47:36 +08:00
JustSong
e92da7928b feat: support ali's llm (close #326) 2023-07-28 23:45:08 +08:00
JustSong
d1b6f492b6 fix: convert system message to user message for zhipu 2023-07-27 23:32:00 +08:00
JustSong
b9f6461dd4 fix: convert system message to user message for claude 2023-07-27 23:26:56 +08:00
JustSong
0a39521a3d fix: convert system message to user message (close #328) 2023-07-27 23:16:11 +08:00
JustSong
c134604cee fix: use channel type to determine api type (close #321) 2023-07-24 23:34:14 +08:00
mrhaoji
929e43ef81 fix: baseURL not working in APITypePaLM (#317)
* fix: baseURL not working in APITypePaLM

* chore: use the same logic as claude

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-07-24 22:37:57 +08:00
mrhaoji
dce8bbe1ca fix: relay router typo for List models (#320)
via: https://platform.openai.com/docs/api-reference/models/list
2023-07-24 22:28:16 +08:00
JustSong
bc2f48b1f2 chore: send a notice to user indicating that the current channel type does not support testing 2023-07-23 23:23:56 +08:00
JustSong
889af8b2db fix: fix redemption can be used multiple times in some cases 2023-07-23 19:34:23 +08:00
JustSong
4eea096654 chore: do not hardcode cache time (close #302) 2023-07-23 19:26:37 +08:00
JustSong
4ab3211c0e fix: fix blank screen after refresh 2023-07-23 17:59:30 +08:00
JustSong
3da119efba perf: reuse http client to reduce delay 2023-07-23 15:18:58 +08:00
ckt
dccd66b852 feat: add access_until field for subscription api (#295)
* Support returning the date from the billing channel endpoint

* fix: fix wrong git base

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-07-23 13:49:09 +08:00
ckt
2fcd6852e0 feat: able to delete account by self (#294)
* feat: support account deletion

* chore: update style

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-07-23 13:37:32 +08:00
Yolo°
9b4d1964d4 chore: optimize frontend (#293)
* main

* chore: update style

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-07-23 13:25:28 +08:00
JustSong
806bf8241c chore: update prompts of channel config page 2023-07-23 12:46:41 +08:00
JustSong
ce93c9b6b2 chore: adjust channel config page 2023-07-23 12:20:42 +08:00
JustSong
4ec4289565 docs: update README 2023-07-23 12:01:49 +08:00
JustSong
3dc5a0f91d docs: update README 2023-07-23 11:56:51 +08:00
JustSong
80a846673a docs: update README 2023-07-23 11:55:33 +08:00
JustSong
26c6719ea3 feat: support zhipu's ChatGLM (close #289) 2023-07-23 11:51:44 +08:00
JustSong
c87e05bfc2 chore: add more context in error message 2023-07-23 09:49:46 +08:00
45 changed files with 2042 additions and 365 deletions

View File

@@ -10,7 +10,7 @@
 # One API
-_✨ An OpenAI key management & redistribution system, easy to deploy & use ✨_
+_✨ Access all LLM through the standard OpenAI API format, easy to deploy & use ✨_
 </div>
@@ -57,15 +57,13 @@ _✨ An OpenAI key management & redistribution system, easy to deploy & use ✨_
 > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability.
 ## Features
-1. Supports multiple API access channels:
-  + [x] Official OpenAI channel (support proxy configuration)
-  + [x] **Azure OpenAI API**
-  + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
-  + [x] [OpenAI-SB](https://openai-sb.com)
-  + [x] [API2D](https://api2d.com/r/197971)
-  + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
-  + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`)
-  + [x] Custom channel: Various third-party proxy services not included in the list
+1. Support for multiple large models:
+  + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
+  + [x] [Anthropic Claude Series Models](https://anthropic.com)
+  + [x] [Google PaLM2 Series Models](https://developers.generativeai.google)
+  + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+  + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html)
+  + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn)
 2. Supports access to multiple channels through **load balancing**.
 3. Supports **stream mode** that enables typewriter-like effect through stream transmission.
 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details.
@@ -139,7 +137,7 @@ The initial account username is `root` and password is `123456`.
 cd one-api/web
 npm install
 npm run build
 # Build the backend
 cd ..
 go mod download
@@ -175,7 +173,12 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
 <summary><strong>Deploy on Sealos</strong></summary>
 <div>
-Please refer to [this tutorial](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md).
+> Sealos supports high concurrency, dynamic scaling, and stable operations for millions of users.
+> Click the button below to deploy with one click.👇
+[![](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
 </div>
 </details>

View File

@@ -11,7 +11,7 @@
 # One API
-_✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用 ✨_
+_✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_
 </div>
@@ -58,45 +58,45 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 > **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。
 ## 功能
-1. 支持多种 API 访问渠道:
-  + [x] OpenAI 官方通道(支持配置镜像)
-  + [x] **Azure OpenAI API**
+1. 支持多种大模型:
+  + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
   + [x] [Anthropic Claude 系列模型](https://anthropic.com)
   + [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
   + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
-  + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
+  + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
+  + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
+  + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
+2. 支持配置镜像以及众多第三方代理服务:
   + [x] [OpenAI-SB](https://openai-sb.com)
   + [x] [API2D](https://api2d.com/r/197971)
   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
   + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI)(邀请码:`OneAPI`)
   + [x] [CloseAI](https://console.closeai-asia.com/r/2412)
   + [x] 自定义渠道:例如各种未收录的第三方代理服务
-2. 支持通过**负载均衡**的方式访问多个渠道。
+3. 支持通过**负载均衡**的方式访问多个渠道。
-3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
+4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
-4. 支持**多机部署**,[详见此处](#多机部署)。
+5. 支持**多机部署**,[详见此处](#多机部署)。
-5. 支持**令牌管理**,设置令牌的过期时间和额度。
+6. 支持**令牌管理**,设置令牌的过期时间和额度。
-6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
+7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
-7. 支持**通道管理**,批量创建通道。
+8. 支持**通道管理**,批量创建通道。
-8. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
+9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
-9. 支持渠道**设置模型列表**。
+10. 支持渠道**设置模型列表**。
-10. 支持**查看额度明细**。
+11. 支持**查看额度明细**。
-11. 支持**用户邀请奖励**。
+12. 支持**用户邀请奖励**。
-12. 支持以美元为单位显示额度。
+13. 支持以美元为单位显示额度。
-13. 支持发布公告,设置充值链接,设置新用户初始额度。
+14. 支持发布公告,设置充值链接,设置新用户初始额度。
-14. 支持模型映射,重定向用户的请求模型。
+15. 支持模型映射,重定向用户的请求模型。
-15. 支持失败自动重试。
+16. 支持失败自动重试。
-16. 支持绘图接口。
+17. 支持绘图接口。
-17. 支持丰富的**自定义**设置,
+18. 支持丰富的**自定义**设置,
   1. 支持自定义系统名称、logo 以及页脚。
   2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
-18. 支持通过系统访问令牌访问管理 API。
+19. 支持通过系统访问令牌访问管理 API。
-19. 支持 Cloudflare Turnstile 用户校验。
+20. 支持 Cloudflare Turnstile 用户校验。
-20. 支持用户管理,支持**多种用户登录注册方式**:
+21. 支持用户管理,支持**多种用户登录注册方式**:
-  + 邮箱登录注册以及通过邮箱进行密码重置。
+  + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
   + [GitHub 开放授权](https://github.com/settings/applications/new)。
   + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
-21. 支持 [ChatGLM](https://github.com/THUDM/ChatGLM2-6B)。
-22. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
 ## 部署
 ### 基于 Docker 进行部署
@@ -153,7 +153,7 @@ sudo service nginx restart
 cd one-api/web
 npm install
 npm run build
 # 构建后端
 cd ..
 go mod download
@@ -211,9 +211,11 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 <summary><strong>部署到 Sealos </strong></summary>
 <div>
-> Sealos 可视化部署,仅需 1 分钟
+> Sealos 的服务器在国外,不需要额外处理网络问题,支持高并发 & 动态伸缩
-参考这个[教程](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md)中 1~5 步。
+点击以下按钮一键部署:
+[![Deploy-on-Sealos.svg](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
 </div>
 </details>
@@ -314,6 +316,7 @@ https://openai.justsong.cn
  + 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率)
  + 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。
  + 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。
++ 注意:One API 的默认倍率就是官方倍率,是已经调整过的。
 2. 账户额度足够,为什么提示额度不足?
  + 请检查你的令牌额度是否足够,这个和账户额度是分开的。
  + 令牌额度仅供用户设置最大使用量,用户可自由设置。
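The quota formula quoted in the FAQ above is the same arithmetic the relay applies after every request: group ratio × model ratio × (prompt tokens + completion tokens × completion ratio). A standalone sketch of that calculation follows; it is not code from this repository, the group ratio of 1 is an assumed default, while the GPT-4 model ratio of 15 and completion ratio of 2 come from this diff.

```go
package main

import "fmt"

// quota mirrors the FAQ formula:
// 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率)
func quota(groupRatio, modelRatio, completionRatio float64, promptTokens, completionTokens int) int {
	return int(groupRatio * modelRatio * (float64(promptTokens) + float64(completionTokens)*completionRatio))
}

func main() {
	// gpt-4: ModelRatio 15 (see common/model-ratio.go below), completion ratio 2.
	fmt.Println(quota(1, 15, 2, 100, 200)) // 15 * (100 + 200*2) = 7500
}
```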

View File

@@ -42,6 +42,19 @@ var WeChatAuthEnabled = false
 var TurnstileCheckEnabled = false
 var RegisterEnabled = true
+var EmailDomainRestrictionEnabled = false
+var EmailDomainWhitelist = []string{
+	"gmail.com",
+	"163.com",
+	"126.com",
+	"qq.com",
+	"outlook.com",
+	"hotmail.com",
+	"icloud.com",
+	"yahoo.com",
+	"foxmail.com",
+}
 var LogConsumeEnabled = true
 var SMTPServer = ""
@@ -77,6 +90,8 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
 var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
 var RequestInterval = time.Duration(requestInterval) * time.Second
+var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
 const (
 	RoleGuestUser  = 0
 	RoleCommonUser = 1
@@ -153,23 +168,29 @@ const (
 	ChannelTypeAIGC2D    = 13
 	ChannelTypeAnthropic = 14
 	ChannelTypeBaidu     = 15
+	ChannelTypeZhipu     = 16
+	ChannelTypeAli       = 17
+	ChannelTypeXunfei    = 18
 )
 var ChannelBaseURLs = []string{
 	"",                               // 0
 	"https://api.openai.com",         // 1
 	"https://oa.api2d.net",           // 2
 	"",                               // 3
 	"https://api.closeai-proxy.xyz",  // 4
 	"https://api.openai-sb.com",      // 5
 	"https://api.openaimax.com",      // 6
 	"https://api.ohmygpt.com",        // 7
 	"",                               // 8
 	"https://api.caipacity.com",      // 9
 	"https://api.aiproxy.io",         // 10
 	"",                               // 11
 	"https://api.api2gpt.com",        // 12
 	"https://api.aigc2d.com",         // 13
 	"https://api.anthropic.com",      // 14
 	"https://aip.baidubce.com",       // 15
+	"https://open.bigmodel.cn",       // 16
+	"https://dashscope.aliyuncs.com", // 17
+	"",                               // 18
 }
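The three new channel-type constants double as indices into `ChannelBaseURLs`, so both lists must grow in lockstep. A tiny standalone sketch of that lookup pattern, using local stand-ins rather than the real `common` package:

```go
package main

import "fmt"

// Local stand-ins for the constants and table added above; indices must match.
const (
	ChannelTypeZhipu  = 16
	ChannelTypeAli    = 17
	ChannelTypeXunfei = 18
)

var channelBaseURLs = map[int]string{
	ChannelTypeZhipu:  "https://open.bigmodel.cn",
	ChannelTypeAli:    "https://dashscope.aliyuncs.com",
	ChannelTypeXunfei: "", // Xunfei is reached over WebSocket, so it has no HTTP base URL
}

func main() {
	fmt.Println(channelBaseURLs[ChannelTypeZhipu]) // https://open.bigmodel.cn
}
```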

View File

@@ -8,6 +8,7 @@ import "encoding/json"
 // https://openai.com/pricing
 // TODO: when a new api is enabled, check the pricing here
 // 1 === $0.002 / 1K tokens
+// 1 === ¥0.014 / 1k tokens
 var ModelRatio = map[string]float64{
 	"gpt-4":      15,
 	"gpt-4-0314": 15,
@@ -39,9 +40,16 @@ var ModelRatio = map[string]float64{
 	"dall-e":           8,
 	"claude-instant-1": 0.75,
 	"claude-2":         30,
-	"ERNIE-Bot":        1,    // 0.012元/千tokens
-	"ERNIE-Bot-turbo":  0.67, // 0.008元/千tokens
+	"ERNIE-Bot":        0.8572, // 0.012 / 1k tokens
+	"ERNIE-Bot-turbo":  0.5715, // 0.008 / 1k tokens
+	"Embedding-V1":     0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":           1,
+	"chatglm_pro":      0.7143, // ¥0.01 / 1k tokens
+	"chatglm_std":      0.3572, // ¥0.005 / 1k tokens
+	"chatglm_lite":     0.1429, // ¥0.002 / 1k tokens
+	"qwen-v1":          0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
+	"qwen-plus-v1":     0.5715, // Same as above
+	"SparkDesk":        0.8572, // TBD
 }
 func ModelRatio2JSONString() string {
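The updated ratios follow from the two base rates in the comments above (1 ratio unit = $0.002 ≈ ¥0.014 per 1K tokens), i.e. a CNY price per 1K tokens divided by 0.014. A quick sanity check as a standalone sketch; `ratioFromCNY` is not a function in the repository:

```go
package main

import "fmt"

// ratioFromCNY converts a CNY price per 1K tokens into a model ratio,
// using the base rate documented above: 1 ratio unit == ¥0.014 / 1K tokens.
func ratioFromCNY(pricePer1K float64) float64 {
	return pricePer1K / 0.014
}

func main() {
	fmt.Printf("ERNIE-Bot:    %.4f\n", ratioFromCNY(0.012)) // ≈ 0.8571, listed as 0.8572
	fmt.Printf("chatglm_std:  %.4f\n", ratioFromCNY(0.005)) // ≈ 0.3571, listed as 0.3572
	fmt.Printf("Embedding-V1: %.4f\n", ratioFromCNY(0.002)) // ≈ 0.1429
}
```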

View File

@@ -11,9 +11,11 @@ func GetSubscription(c *gin.Context) {
 	var usedQuota int
 	var err error
 	var token *model.Token
+	var expiredTime int64
 	if common.DisplayTokenStatEnabled {
 		tokenId := c.GetInt("token_id")
 		token, err = model.GetTokenById(tokenId)
+		expiredTime = token.ExpiredTime
 		remainQuota = token.RemainQuota
 		usedQuota = token.UsedQuota
 	} else {
@@ -21,6 +23,9 @@ func GetSubscription(c *gin.Context) {
 		remainQuota, err = model.GetUserQuota(userId)
 		usedQuota, err = model.GetUserUsedQuota(userId)
 	}
+	if expiredTime <= 0 {
+		expiredTime = 0
+	}
 	if err != nil {
 		openAIError := OpenAIError{
 			Message: err.Error(),
@@ -45,6 +50,7 @@ func GetSubscription(c *gin.Context) {
 		SoftLimitUSD:       amount,
 		HardLimitUSD:       amount,
 		SystemHardLimitUSD: amount,
+		AccessUntil:        expiredTime,
 	}
 	c.JSON(200, subscription)
 	return

View File

@@ -22,6 +22,7 @@ type OpenAISubscriptionResponse struct {
 	SoftLimitUSD       float64 `json:"soft_limit_usd"`
 	HardLimitUSD       float64 `json:"hard_limit_usd"`
 	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
+	AccessUntil        int64   `json:"access_until"`
 }
 type OpenAIUsageDailyCost struct {
@@ -84,7 +85,6 @@ func GetAuthHeader(token string) http.Header {
 }
 func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
-	client := &http.Client{}
 	req, err := http.NewRequest(method, url, nil)
 	if err != nil {
 		return nil, err
@@ -92,10 +92,13 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
 	for k := range headers {
 		req.Header.Add(k, headers.Get(k))
 	}
-	res, err := client.Do(req)
+	res, err := httpClient.Do(req)
 	if err != nil {
 		return nil, err
 	}
+	if res.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("status code: %d", res.StatusCode)
+	}
 	body, err := io.ReadAll(res.Body)
 	if err != nil {
 		return nil, err
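This hunk, like the matching ones in other handlers further down, swaps a per-call `&http.Client{}` for the shared `httpClient`, so keep-alive connections get pooled and reused instead of re-dialed on every request, which is what the "reuse http client to reduce delay" commit is about. A minimal standalone illustration of the pattern; the 30-second timeout is an assumption for the sketch, the diff itself constructs the client with defaults:

```go
package main

import (
	"io"
	"net/http"
	"time"
)

// A single package-level client is safe for concurrent use and pools
// TCP/TLS connections through its Transport, so repeated calls skip
// the connection setup cost.
var httpClient = &http.Client{Timeout: 30 * time.Second}

func fetch(url string) ([]byte, error) {
	resp, err := httpClient.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	return io.ReadAll(resp.Body)
}

func main() {
	body, err := fetch("https://example.com")
	if err != nil {
		panic(err)
	}
	_ = body
}
```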

View File

@@ -16,6 +16,16 @@ import (
 func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
 	switch channel.Type {
+	case common.ChannelTypePaLM:
+		fallthrough
+	case common.ChannelTypeAnthropic:
+		fallthrough
+	case common.ChannelTypeBaidu:
+		fallthrough
+	case common.ChannelTypeZhipu:
+		fallthrough
+	case common.ChannelTypeXunfei:
+		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
 	case common.ChannelTypeAzure:
 		request.Model = "gpt-35-turbo"
 	default:
@@ -45,8 +55,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
 		req.Header.Set("Authorization", "Bearer "+channel.Key)
 	}
 	req.Header.Set("Content-Type", "application/json")
-	client := &http.Client{}
-	resp, err := client.Do(req)
+	resp, err := httpClient.Do(req)
 	if err != nil {
 		return err, nil
 	}

View File

@@ -3,10 +3,12 @@ package controller
 import (
 	"encoding/json"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
+	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 func GetStatus(c *gin.Context) {
@@ -78,6 +80,22 @@ func SendEmailVerification(c *gin.Context) {
 		})
 		return
 	}
+	if common.EmailDomainRestrictionEnabled {
+		allowed := false
+		for _, domain := range common.EmailDomainWhitelist {
+			if strings.HasSuffix(email, "@"+domain) {
+				allowed = true
+				break
+			}
+		}
+		if !allowed {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
+			})
+			return
+		}
+	}
 	if model.IsEmailAlreadyTaken(email) {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -127,8 +145,9 @@ func SendPasswordResetEmail(c *gin.Context) {
 	link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
 	subject := fmt.Sprintf("%s密码重置", common.SystemName)
 	content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
-		"<p>点击<a href='%s'>此处</a>进行密码重置。</p>"+
-		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, common.VerificationValidMinutes)
+		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
+		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
+		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
 	err := common.SendEmail(subject, email, content)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
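One detail worth noting in the new check: it matches the literal `"@" + domain` suffix, so only exact whitelisted domains pass and subdomains do not. A small standalone sketch of the same `strings.HasSuffix` logic with made-up addresses:

```go
package main

import (
	"fmt"
	"strings"
)

var emailDomainWhitelist = []string{"gmail.com", "qq.com"}

// allowedEmail mirrors the check added to SendEmailVerification:
// the address must end with "@" followed by a whitelisted domain.
func allowedEmail(email string) bool {
	for _, domain := range emailDomainWhitelist {
		if strings.HasSuffix(email, "@"+domain) {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(allowedEmail("alice@gmail.com"))   // true
	fmt.Println(allowedEmail("bob@mail.qq.com"))   // false: subdomains are not whitelisted
	fmt.Println(allowedEmail("carol@example.com")) // false: domain not in the whitelist
}
```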

View File

@@ -252,24 +252,6 @@ func init() {
Root: "code-davinci-edit-001", Root: "code-davinci-edit-001",
Parent: nil, Parent: nil,
}, },
{
Id: "ChatGLM",
Object: "model",
Created: 1677649963,
OwnedBy: "thudm",
Permission: permission,
Root: "ChatGLM",
Parent: nil,
},
{
Id: "ChatGLM2",
Object: "model",
Created: 1677649963,
OwnedBy: "thudm",
Permission: permission,
Root: "ChatGLM2",
Parent: nil,
},
{ {
Id: "claude-instant-1", Id: "claude-instant-1",
Object: "model", Object: "model",
@@ -306,6 +288,15 @@ func init() {
Root: "ERNIE-Bot-turbo", Root: "ERNIE-Bot-turbo",
Parent: nil, Parent: nil,
}, },
{
Id: "Embedding-V1",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "Embedding-V1",
Parent: nil,
},
{ {
Id: "PaLM-2", Id: "PaLM-2",
Object: "model", Object: "model",
@@ -315,6 +306,60 @@ func init() {
Root: "PaLM-2", Root: "PaLM-2",
Parent: nil, Parent: nil,
}, },
{
Id: "chatglm_pro",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_pro",
Parent: nil,
},
{
Id: "chatglm_std",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_std",
Parent: nil,
},
{
Id: "chatglm_lite",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_lite",
Parent: nil,
},
{
Id: "qwen-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-v1",
Parent: nil,
},
{
Id: "qwen-plus-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-plus-v1",
Parent: nil,
},
{
Id: "SparkDesk",
Object: "model",
Created: 1677649963,
OwnedBy: "xunfei",
Permission: permission,
Root: "SparkDesk",
Parent: nil,
},
} }
openAIModelsMap = make(map[string]OpenAIModels) openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels { for _, model := range openAIModels {

View File

@@ -2,11 +2,12 @@ package controller
 import (
 	"encoding/json"
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 func GetOptions(c *gin.Context) {
@@ -49,6 +50,14 @@ func UpdateOption(c *gin.Context) {
 			})
 			return
 		}
+	case "EmailDomainRestrictionEnabled":
+		if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
+			})
+			return
+		}
 	case "WeChatAuthEnabled":
 		if option.Value == "true" && common.WeChatServerAddress == "" {
 			c.JSON(http.StatusOK, gin.H{

controller/relay-ali.go (new file, 240 lines)
View File

@@ -0,0 +1,240 @@
package controller
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: message.Content,
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.Content
break
}
messages = append(messages, AliMessage{
User: message.Content,
Bot: request.Messages[i+1].Content,
})
i++
}
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
Prompt: prompt,
History: messages,
},
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// TopP: request.TopP,
// TopK: 50,
// //Seed: 0,
// //EnableSearch: false,
//},
}
}
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := OpenAITextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
choice.FinishReason = aliResponse.Output.FinishReason
response := ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
usage.PromptTokens += aliResponse.Usage.InputTokens
usage.CompletionTokens += aliResponse.Usage.OutputTokens
usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
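The core of the new adaptor is `requestOpenAI2Ali` above: Ali's text-generation API takes a final `prompt` plus a `history` of user/bot pairs, so the converter turns a system message into a `(content, "Okay")` pair, folds each user/assistant pair into one history entry, and keeps the last message as the prompt. A standalone sketch of that pairing logic with simplified local types (not the repository's structs):

```go
package main

import "fmt"

type msg struct{ Role, Content string }
type pair struct{ User, Bot string }

// toAli mirrors requestOpenAI2Ali: system messages become a (content, "Okay")
// pair, consecutive user/assistant messages fold into one pair, and the last
// message becomes the prompt.
func toAli(messages []msg) (history []pair, prompt string) {
	for i := 0; i < len(messages); i++ {
		m := messages[i]
		if m.Role == "system" {
			history = append(history, pair{User: m.Content, Bot: "Okay"})
			continue
		}
		if i == len(messages)-1 {
			prompt = m.Content
			break
		}
		history = append(history, pair{User: m.Content, Bot: messages[i+1].Content})
		i++ // skip the assistant reply that was just consumed
	}
	return history, prompt
}

func main() {
	history, prompt := toAli([]msg{
		{"system", "You are helpful."},
		{"user", "Hi"},
		{"assistant", "Hello!"},
		{"user", "What's the weather?"},
	})
	fmt.Println(history) // [{You are helpful. Okay} {Hi Hello!}]
	fmt.Println(prompt)  // What's the weather?
}
```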

View File

@@ -54,13 +54,43 @@ type BaiduChatStreamResponse struct {
IsEnd bool `json:"is_end"` IsEnd bool `json:"is_end"`
} }
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage Usage `json:"usage"`
BaiduError
}
 func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 	messages := make([]BaiduMessage, 0, len(request.Messages))
 	for _, message := range request.Messages {
-		messages = append(messages, BaiduMessage{
-			Role:    message.Role,
-			Content: message.Content,
-		})
+		if message.Role == "system" {
+			messages = append(messages, BaiduMessage{
+				Role:    "user",
+				Content: message.Content,
+			})
+			messages = append(messages, BaiduMessage{
+				Role:    "assistant",
+				Content: "Okay",
+			})
+		} else {
+			messages = append(messages, BaiduMessage{
+				Role:    message.Role,
+				Content: message.Content,
+			})
+		}
 	}
 	return &BaiduChatRequest{
 		Messages: messages,
@@ -101,6 +131,36 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
return &response return &response
} }
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
baiduEmbeddingRequest := BaiduEmbeddingRequest{
Input: nil,
}
switch request.Input.(type) {
case string:
baiduEmbeddingRequest.Input = []string{request.Input.(string)}
case []string:
baiduEmbeddingRequest.Input = request.Input.([]string)
}
return &baiduEmbeddingRequest
}
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage var usage Usage
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
@@ -201,3 +261,39 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
_, err = c.Writer.Write(jsonResponse) _, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage return nil, &fullTextResponse.Usage
} }
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var baiduResponse BaiduEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -69,11 +69,11 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
 			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
 		} else if message.Role == "assistant" {
 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
-		} else {
-			// ignore other roles
+		} else if message.Role == "system" {
+			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
 		}
-		prompt += "\n\nAssistant:"
 	}
+	prompt += "\n\nAssistant:"
 	claudeRequest.Prompt = prompt
 	return &claudeRequest
 }

View File

@@ -109,8 +109,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-	client := &http.Client{}
-	resp, err := client.Do(req)
+	resp, err := httpClient.Do(req)
 	if err != nil {
 		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}

View File

@@ -34,6 +34,9 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 			if len(data) < 6 { // ignore blank line or wrong format
 				continue
 			}
+			if data[:6] != "data: " && data[:6] != "[DONE]" {
+				continue
+			}
 			dataChan <- data
 			data = data[6:]
 			if !strings.HasPrefix(data, "[DONE]") {
@@ -43,7 +46,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 					err := json.Unmarshal([]byte(data), &streamResponse)
 					if err != nil {
 						common.SysError("error unmarshalling stream response: " + err.Error())
-						return
+						continue // just ignore the error
 					}
 					for _, choice := range streamResponse.Choices {
 						responseText += choice.Delta.Content
@@ -53,7 +56,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 					err := json.Unmarshal([]byte(data), &streamResponse)
 					if err != nil {
 						common.SysError("error unmarshalling stream response: " + err.Error())
-						return
+						continue
 					}
 					for _, choice := range streamResponse.Choices {
 						responseText += choice.Text
@@ -89,7 +92,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 	return nil, responseText
 }
-func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) {
+func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 	var textResponse TextResponse
 	if consumeQuota {
 		responseBody, err := io.ReadAll(resp.Body)
@@ -115,7 +118,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*Ope
 	}
 	// We shouldn't set the header before we parse the response body, because the parse part may fail.
 	// And then we will have to send an error response, but in this case, the header has already been set.
-	// So the client will be confused by the response.
+	// So the httpClient will be confused by the response.
 	// For example, Postman will report error, and we cannot check the response at all.
 	for k, v := range resp.Header {
 		c.Writer.Header().Set(k, v[0])
@@ -129,5 +132,17 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*Ope
 	if err != nil {
 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
+	if textResponse.Usage.TotalTokens == 0 {
+		completionTokens := 0
+		for _, choice := range textResponse.Choices {
+			completionTokens += countTokenText(choice.Message.Content, model)
+		}
+		textResponse.Usage = Usage{
+			PromptTokens:     promptTokens,
+			CompletionTokens: completionTokens,
+			TotalTokens:      promptTokens + completionTokens,
+		}
+	}
 	return nil, &textResponse.Usage
 }

View File

@@ -19,8 +19,17 @@ const (
 	APITypeClaude
 	APITypePaLM
 	APITypeBaidu
+	APITypeZhipu
+	APITypeAli
+	APITypeXunfei
 )
+
+var httpClient *http.Client
+
+func init() {
+	httpClient = &http.Client{}
+}
 func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
@@ -66,7 +75,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	// map model name
 	modelMapping := c.GetString("model_mapping")
 	isModelMapped := false
-	if modelMapping != "" {
+	if modelMapping != "" && modelMapping != "{}" {
 		modelMap := make(map[string]string)
 		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 		if err != nil {
@@ -78,12 +87,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 	}
 	apiType := APITypeOpenAI
-	if strings.HasPrefix(textRequest.Model, "claude") {
+	switch channelType {
+	case common.ChannelTypeAnthropic:
 		apiType = APITypeClaude
-	} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
+	case common.ChannelTypeBaidu:
 		apiType = APITypeBaidu
-	} else if strings.HasPrefix(textRequest.Model, "PaLM") {
+	case common.ChannelTypePaLM:
 		apiType = APITypePaLM
+	case common.ChannelTypeZhipu:
+		apiType = APITypeZhipu
+	case common.ChannelTypeAli:
+		apiType = APITypeAli
+	case common.ChannelTypeXunfei:
+		apiType = APITypeXunfei
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
@@ -125,15 +141,28 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B": case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
} }
apiKey := c.Request.Header.Get("Authorization") apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ") apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
case APITypePaLM: case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
}
apiKey := c.Request.Header.Get("Authorization") apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ") apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey fullRequestURL += "?key=" + apiKey
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
} }
var promptTokens int var promptTokens int
var completionTokens int var completionTokens int
@@ -187,12 +216,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
 	case APITypeBaidu:
-		baiduRequest := requestOpenAI2Baidu(textRequest)
-		jsonStr, err := json.Marshal(baiduRequest)
+		var jsonData []byte
+		var err error
+		switch relayMode {
+		case RelayModeEmbeddings:
+			baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
+			jsonData, err = json.Marshal(baiduEmbeddingRequest)
+		default:
+			baiduRequest := requestOpenAI2Baidu(textRequest)
+			jsonData, err = json.Marshal(baiduRequest)
+		}
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
-		requestBody = bytes.NewBuffer(jsonStr)
+		requestBody = bytes.NewBuffer(jsonData)
 	case APITypePaLM:
 		palmRequest := requestOpenAI2PaLM(textRequest)
 		jsonStr, err := json.Marshal(palmRequest)
@@ -200,47 +237,75 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
} }
requestBody = bytes.NewBuffer(jsonStr) requestBody = bytes.NewBuffer(jsonStr)
} case APITypeZhipu:
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) zhipuRequest := requestOpenAI2Zhipu(textRequest)
if err != nil { jsonStr, err := json.Marshal(zhipuRequest)
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) if err != nil {
} return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
req.Header.Set("api-key", apiKey)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
} }
case APITypeClaude: requestBody = bytes.NewBuffer(jsonStr)
req.Header.Set("x-api-key", apiKey) case APITypeAli:
anthropicVersion := c.Request.Header.Get("anthropic-version") aliRequest := requestOpenAI2Ali(textRequest)
if anthropicVersion == "" { jsonStr, err := json.Marshal(aliRequest)
anthropicVersion = "2023-06-01" if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
} }
req.Header.Set("anthropic-version", anthropicVersion) requestBody = bytes.NewBuffer(jsonStr)
} }
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) var req *http.Request
//req.Header.Set("Connection", c.Request.Header.Get("Connection")) var resp *http.Response
client := &http.Client{} isStream := textRequest.Stream
resp, err := client.Do(req)
if err != nil { if apiType != APITypeXunfei { // cause xunfei use websocket
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
} if err != nil {
err = req.Body.Close() return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
if err != nil { }
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) apiKey := c.Request.Header.Get("Authorization")
} apiKey = strings.TrimPrefix(apiKey, "Bearer ")
err = c.Request.Body.Close() switch apiType {
if err != nil { case APITypeOpenAI:
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) if channelType == common.ChannelTypeAzure {
req.Header.Set("api-key", apiKey)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
}
case APITypeClaude:
req.Header.Set("x-api-key", apiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
case APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey)
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err = httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
} }
var textResponse TextResponse var textResponse TextResponse
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
var streamResponseText string
defer func() { defer func() {
if consumeQuota { if consumeQuota {
@@ -252,12 +317,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if strings.HasPrefix(textRequest.Model, "gpt-4") { if strings.HasPrefix(textRequest.Model, "gpt-4") {
completionRatio = 2 completionRatio = 2
} }
if isStream && apiType != APITypeBaidu {
completionTokens = countTokenText(streamResponseText, textRequest.Model) promptTokens = textResponse.Usage.PromptTokens
} else { completionTokens = textResponse.Usage.CompletionTokens
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
}
quota = promptTokens + int(float64(completionTokens)*completionRatio) quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio) quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 { if ratio != 0 && quota <= 0 {
@@ -295,14 +358,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil { if err != nil {
return err return err
} }
streamResponseText = responseText textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil return nil
} else { } else {
err, usage := openaiHandler(c, resp, consumeQuota) err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
if err != nil { if err != nil {
return err return err
} }
textResponse.Usage = *usage if usage != nil {
textResponse.Usage = *usage
}
return nil return nil
} }
case APITypeClaude: case APITypeClaude:
@@ -311,14 +377,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil { if err != nil {
return err return err
} }
streamResponseText = responseText textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil return nil
} else { } else {
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
if err != nil { if err != nil {
return err return err
} }
textResponse.Usage = *usage if usage != nil {
textResponse.Usage = *usage
}
return nil return nil
} }
case APITypeBaidu: case APITypeBaidu:
@@ -327,14 +396,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil { if err != nil {
return err return err
} }
textResponse.Usage = *usage if usage != nil {
textResponse.Usage = *usage
}
return nil return nil
} else { } else {
err, usage := baiduHandler(c, resp) var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
default:
err, usage = baiduHandler(c, resp)
}
if err != nil { if err != nil {
return err return err
} }
textResponse.Usage = *usage if usage != nil {
textResponse.Usage = *usage
}
return nil return nil
} }
case APITypePaLM: case APITypePaLM:
@@ -343,16 +423,82 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil { if err != nil {
return err return err
} }
streamResponseText = responseText textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil return nil
} else { } else {
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
if err != nil { if err != nil {
return err return err
} }
textResponse.Usage = *usage if usage != nil {
textResponse.Usage = *usage
}
return nil return nil
} }
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
} else {
err, usage := zhipuHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
}
case APITypeAli:
if isStream {
err, usage := aliStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := aliHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeXunfei:
if isStream {
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
}
default: default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
} }

controller/relay-xunfei.go (new file, 278 lines)
View File

@@ -0,0 +1,278 @@
package controller
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"net/http"
"net/url"
"one-api/common"
"strings"
"time"
)
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html
type XunfeiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type XunfeiChatRequest struct {
Header struct {
AppId string `json:"app_id"`
} `json:"header"`
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`
Payload struct {
Message struct {
Text []XunfeiMessage `json:"text"`
} `json:"message"`
} `json:"payload"`
}
type XunfeiChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
}
type XunfeiChatResponse struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []XunfeiChatResponseTextItem `json:"text"`
} `json:"choices"`
Usage struct {
//Text struct {
// QuestionTokens string `json:"question_tokens"`
// PromptTokens string `json:"prompt_tokens"`
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest {
messages := make([]XunfeiMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, XunfeiMessage{
Role: "user",
Content: message.Content,
})
messages = append(messages, XunfeiMessage{
Role: "assistant",
Content: "Okay",
})
} else {
messages = append(messages, XunfeiMessage{
Role: message.Role,
Content: message.Content,
})
}
}
xunfeiRequest := XunfeiChatRequest{}
xunfeiRequest.Header.AppId = xunfeiAppId
xunfeiRequest.Parameter.Chat.Domain = "general"
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages
return &xunfeiRequest
}
func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
}
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
}
fullTextResponse := OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: response.Payload.Usage.Text,
}
return &fullTextResponse
}
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
}
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "SparkDesk",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
HmacWithShaToBase64 := func(algorithm, data, key string) string {
mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(data))
encodeData := mac.Sum(nil)
return base64.StdEncoding.EncodeToString(encodeData)
}
ul, err := url.Parse(hostUrl)
if err != nil {
fmt.Println(err)
}
date := time.Now().UTC().Format(time.RFC1123)
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
sign := strings.Join(signString, "\n")
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
"hmac-sha256", "host date request-line", sha)
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
v := url.Values{}
v.Add("host", ul.Host)
v.Add("date", date)
v.Add("authorization", authorization)
callUrl := hostUrl + "?" + v.Encode()
return callUrl
}
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
hostUrl := "wss://aichat.xf-yun.com/v1/chat"
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
if err != nil || resp.StatusCode != 101 {
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
}
data := requestOpenAI2Xunfei(textRequest, appId)
err = conn.WriteJSON(data)
if err != nil {
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
}
dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool)
go func() {
for {
_, msg, err := conn.ReadMessage()
if err != nil {
common.SysError("error reading stream response: " + err.Error())
break
}
var response XunfeiChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
common.SysError("error closing websocket connection: " + err.Error())
}
break
}
}
stopChan <- true
}()
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, &usage
}
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var xunfeiResponse XunfeiChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &xunfeiResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if xunfeiResponse.Header.Code != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: xunfeiResponse.Header.Message,
Type: "xunfei_error",
Param: "",
Code: xunfeiResponse.Header.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
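For reference, `buildXunfeiAuthUrl` above follows the Spark WebSocket auth scheme: an HMAC-SHA256 signature over the host, an RFC1123 date, and the request line, wrapped into a base64-encoded `authorization` query parameter. A condensed, self-contained restatement of the same signing flow (credentials are placeholders):

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"net/url"
	"time"
)

func main() {
	hostUrl := "wss://aichat.xf-yun.com/v1/chat"
	apiKey, apiSecret := "my-api-key", "my-api-secret" // placeholders, not real keys

	u, _ := url.Parse(hostUrl)
	date := time.Now().UTC().Format(time.RFC1123)
	// Sign "host", "date" and the request line, exactly as buildXunfeiAuthUrl does.
	signString := "host: " + u.Host + "\ndate: " + date + "\nGET " + u.Path + " HTTP/1.1"
	mac := hmac.New(sha256.New, []byte(apiSecret))
	mac.Write([]byte(signString))
	signature := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	authorization := fmt.Sprintf(
		"hmac username=\"%s\", algorithm=\"hmac-sha256\", headers=\"host date request-line\", signature=\"%s\"",
		apiKey, signature)

	v := url.Values{}
	v.Add("host", u.Host)
	v.Add("date", date)
	v.Add("authorization", base64.StdEncoding.EncodeToString([]byte(authorization)))
	fmt.Println(hostUrl + "?" + v.Encode()) // the URL handed to the websocket dialer
}
```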

controller/relay-zhipu.go (new file)

@@ -0,0 +1,306 @@
package controller
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"io"
"net/http"
"one-api/common"
"strings"
"sync"
"time"
)
// https://open.bigmodel.cn/doc/api#chatglm_std
// chatglm_std, chatglm_lite
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
type ZhipuMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ZhipuRequest struct {
Prompt []ZhipuMessage `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
}
type ZhipuResponseData struct {
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Choices []ZhipuMessage `json:"choices"`
Usage `json:"usage"`
}
type ZhipuResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Success bool `json:"success"`
Data ZhipuResponseData `json:"data"`
}
type ZhipuStreamMetaResponse struct {
RequestId string `json:"request_id"`
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
Usage `json:"usage"`
}
type zhipuTokenData struct {
Token string
ExpiryTime time.Time
}
var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600
func getZhipuToken(apikey string) string {
data, ok := zhipuTokens.Load(apikey)
if ok {
tokenData := data.(zhipuTokenData)
if time.Now().Before(tokenData.ExpiryTime) {
return tokenData.Token
}
}
split := strings.Split(apikey, ".")
if len(split) != 2 {
common.SysError("invalid zhipu key: " + apikey)
return ""
}
id := split[0]
secret := split[1]
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
timestamp := time.Now().UnixNano() / 1e6
payload := jwt.MapClaims{
"api_key": id,
"exp": expMillis,
"timestamp": timestamp,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
token.Header["alg"] = "HS256"
token.Header["sign_type"] = "SIGN"
tokenString, err := token.SignedString([]byte(secret))
if err != nil {
return ""
}
zhipuTokens.Store(apikey, zhipuTokenData{
Token: tokenString,
ExpiryTime: expiryTime,
})
return tokenString
}
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
messages := make([]ZhipuMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, ZhipuMessage{
Role: "system",
Content: message.Content,
})
messages = append(messages, ZhipuMessage{
Role: "user",
Content: "Okay",
})
} else {
messages = append(messages, ZhipuMessage{
Role: message.Role,
Content: message.Content,
})
}
}
return &ZhipuRequest{
Prompt: messages,
Temperature: request.Temperature,
TopP: request.TopP,
Incremental: false,
}
}
func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Id: response.Data.TaskId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
Usage: response.Data.Usage,
}
for i, choice := range response.Data.Choices {
openaiChoice := OpenAITextResponseChoice{
Index: i,
Message: Message{
Role: choice.Role,
Content: strings.Trim(choice.Content, "\""),
},
FinishReason: "",
}
if i == len(response.Data.Choices)-1 {
openaiChoice.FinishReason = "stop"
}
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
}
return &fullTextResponse
}
func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse
choice.FinishReason = ""
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "chatglm",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = ""
choice.FinishReason = "stop"
response := ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "chatglm",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response, &zhipuResponse.Usage
}
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage *Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
metaChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
lines := strings.Split(data, "\n")
for i, line := range lines {
if len(line) < 5 {
continue
}
if line[:5] == "data:" {
dataChan <- line[5:]
if i != len(lines)-1 {
dataChan <- "\n"
}
} else if line[:5] == "meta:" {
metaChan <- line[5:]
}
}
}
stopChan <- true
}()
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case data := <-metaChan:
var zhipuResponse ZhipuStreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}
func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var zhipuResponse ZhipuResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if !zhipuResponse.Success {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",
Code: zhipuResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
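The trickiest part of `zhipuStreamHandler` above is its custom `bufio.SplitFunc`, which cuts the response body at blank lines so that each `data:`/`meta:` block arrives as one token. A small self-contained sketch that runs the same split logic over a fabricated payload (the payload is illustrative, not real API output):

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// Fabricated Zhipu-style SSE body: two data blocks followed by a meta block.
	payload := "data: Hello\n\ndata: world\n\nmeta: {\"task_status\":\"SUCCESS\"}\n\n"

	scanner := bufio.NewScanner(strings.NewReader(payload))
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		// Same rule as the handler above: a blank line ends a block.
		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Contains(string(data), ":") {
			return i + 2, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})

	for scanner.Scan() {
		fmt.Printf("block: %q\n", scanner.Text()) // "data: Hello", "data: world", "meta: {...}"
	}
}
```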


@@ -81,8 +81,9 @@ type OpenAIErrorWithStatusCode struct {
}
type TextResponse struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type OpenAITextResponseChoice struct {
@@ -99,6 +100,19 @@ type OpenAITextResponse struct {
Usage `json:"usage"`
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {


@@ -3,12 +3,13 @@ package controller
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"strconv"

"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type LoginRequest struct {
@@ -477,6 +478,16 @@ func DeleteUser(c *gin.Context) {
func DeleteSelf(c *gin.Context) {
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
if user.Role == common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不能删除超级管理员账户",
})
return
}
err := model.DeleteUserById(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{

go.mod

@@ -11,7 +11,9 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.14.0
github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.1
golang.org/x/crypto v0.9.0
gorm.io/driver/mysql v1.4.3
@@ -20,7 +22,6 @@ require (
)
require (
-github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
@@ -32,7 +33,6 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
-github.com/gomodule/redigo v2.0.0+incompatible // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect

go.sum

@@ -1,5 +1,3 @@
-github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff h1:RmdPFa+slIr4SCBg4st/l/vZWVe9QJKMXGO60Bxbe04=
-github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -54,10 +52,10 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
-github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0=
-github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -67,9 +65,10 @@ github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
-github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=


@@ -3,6 +3,11 @@
"%d 点额度": "%d point quota",
"尚未实现": "Not yet implemented",
"余额不足": "Insufficient balance",
"危险操作": "Hazardous operations",
"输入你的账户名": "Enter your account name",
"确认删除": "Confirm Delete",
"确认绑定": "Confirm Binding",
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.",
"\"通道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
"通道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s",
"测试已在运行中": "Test is already running",
@@ -427,7 +432,7 @@
"一分钟后过期": "Expires after one minute",
"创建新的令牌": "Create New Token",
"注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.",
"设为无限额度": "Set to unlimited quota",
"更新令牌信息": "Update Token Information",
"请输入充值码!": "Please enter the recharge code!",
"请输入名称": "Please enter a name",
@@ -493,6 +498,7 @@
"参数替换为你的部署名称(模型名称中的点会被剔除)": "Replace the parameter with your deployment name (dots in the model name will be removed)",
"模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!",
"取消无限额度": "Cancel unlimited quota",
"取消": "Cancel",
"请输入新的剩余额度": "Please enter the new remaining quota",
"请输入单个兑换码中包含的额度": "Please enter the quota included in a single redemption code",
"请输入用户名": "Please enter username",
@@ -503,5 +509,14 @@
"请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT",
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
"Homepage URL 填": "Fill in the Homepage URL",
"Authorization callback URL 填": "Fill in the Authorization callback URL",
"请为通道命名": "Please name the channel",
"此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
"模型重定向": "Model redirection",
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",
"注意,": "Note that, ",
",图片演示。": "related image demo.",
"令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!",
"代理": "Proxy",
"此项可选,用于通过代理站来进行 API 调用请输入代理站地址格式为https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com"
}


@@ -54,6 +54,7 @@ func main() {
if err != nil {
common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
}
common.SyncFrequency = frequency
go model.SyncOptions(frequency)
if common.RedisEnabled {
go model.SyncChannelCache(frequency)


@@ -86,7 +86,7 @@ func Distribute() func(c *gin.Context) {
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"


@@ -12,11 +12,11 @@ import (
"time" "time"
)
var (
TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
@@ -35,7 +35,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
if err != nil {
return nil, err
}
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set token error: " + err.Error())
}
@@ -55,7 +55,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
if err != nil {
return "", err
}
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user group error: " + err.Error())
}
@@ -73,7 +73,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
if err != nil {
return 0, err
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user quota error: " + err.Error())
}
@@ -91,7 +91,7 @@ func CacheUpdateUserQuota(id int) error {
if err != nil {
return err
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
return err
}
@@ -106,7 +106,7 @@ func CacheIsUserEnabled(userId int) bool {
status = common.UserStatusEnabled
}
enabled = fmt.Sprintf("%d", status)
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user enabled error: " + err.Error())
}
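These cache TTLs were previously fixed constants (60 * 60 seconds for the token, group and status caches, 10 * 60 for the quota cache); they now follow `common.SyncFrequency`. Because they become ordinary `int` variables rather than untyped constants, every `RedisSet` call has to convert explicitly, which is why the `time.Duration(...)` wrappers appear. A tiny sketch of the underlying Go rule:

```go
package main

import (
	"fmt"
	"time"
)

const constSeconds = 60 * 60 // untyped constant: converts to time.Duration implicitly
var varSeconds = 60 * 60     // plain int variable: needs an explicit conversion

func main() {
	a := constSeconds * time.Second // compiles: constant expression
	// b := varSeconds * time.Second          // would not compile: int * time.Duration
	b := time.Duration(varSeconds) * time.Second // the form used in the cache code above
	fmt.Println(a, b)
}
```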


@@ -39,6 +39,8 @@ func InitOptionMap() {
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
@@ -141,6 +143,8 @@ func updateOptionMap(key string, value string) (err error) {
common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
common.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
case "ApproximateTokenEnabled":
@@ -154,6 +158,8 @@ func updateOptionMap(key string, value string) (err error) {
}
}
switch key {
case "EmailDomainWhitelist":
common.EmailDomainWhitelist = strings.Split(value, ",")
case "SMTPServer":
common.SMTPServer = value
case "SMTPPort":
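With `EmailDomainRestrictionEnabled` and `EmailDomainWhitelist` persisted as options, the registration and binding paths can reject addresses whose domain is not whitelisted. That check lives in a part of the diff not shown here; the sketch below only illustrates the idea, with a hypothetical helper name that is not the project's actual function:

```go
package main

import (
	"fmt"
	"strings"
)

// emailDomainAllowed is a hypothetical helper: it returns true when the
// address's domain appears in the configured whitelist.
func emailDomainAllowed(email string, whitelist []string) bool {
	parts := strings.Split(email, "@")
	if len(parts) != 2 {
		return false
	}
	for _, domain := range whitelist {
		if strings.EqualFold(parts[1], domain) {
			return true
		}
	}
	return false
}

func main() {
	whitelist := []string{"gmail.com", "outlook.com"}
	fmt.Println(emailDomainAllowed("user@gmail.com", whitelist))    // true
	fmt.Println(emailDomainAllowed("user@tempmail.xyz", whitelist)) // false
}
```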


@@ -51,20 +51,21 @@ func Redeem(key string, userId int) (quota int, err error) {
redemption := &Redemption{}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
if err != nil {
return errors.New("无效的兑换码")
}
if redemption.Status != common.RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用")
}
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
if err != nil {
return err
}
redemption.RedeemedTime = common.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error
return err
})
if err != nil {
return 0, errors.New("兑换失败," + err.Error())
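The rewritten `Redeem` above now runs every statement on the transaction handle `tx` and takes a `FOR UPDATE` lock on the redemption row, so two concurrent requests for the same code serialize instead of both passing the status check; previously the lookup and the quota update went through the global `DB` handle and the code was persisted via `SelectUpdate()`. A compile-level sketch of the same locking pattern, using hypothetical models rather than the project's structs:

```go
package model

import (
	"errors"

	"gorm.io/gorm"
)

// Hypothetical models for illustration only.
type Voucher struct {
	Id    int
	Code  string
	Quota int
	Used  bool
}

type Account struct {
	Id    int
	Quota int
}

// claimVoucher mirrors the call pattern of the rewritten Redeem: every query
// runs on tx, and the voucher row is locked with FOR UPDATE so concurrent
// claims block until the first transaction commits or rolls back.
func claimVoucher(db *gorm.DB, code string, userId int) error {
	return db.Transaction(func(tx *gorm.DB) error {
		var v Voucher
		if err := tx.Set("gorm:query_option", "FOR UPDATE").
			Where("code = ?", code).First(&v).Error; err != nil {
			return errors.New("invalid voucher")
		}
		if v.Used {
			return errors.New("voucher already used")
		}
		if err := tx.Model(&Account{}).Where("id = ?", userId).
			Update("quota", gorm.Expr("quota + ?", v.Quota)).Error; err != nil {
			return err
		}
		v.Used = true
		return tx.Save(&v).Error
	})
}
```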


@@ -12,7 +12,7 @@ func SetRelayRouter(router *gin.Engine) {
modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth())
{
modelsRouter.GET("", controller.ListModels)
modelsRouter.GET("/:model", controller.RetrieveModel)
}
relayV1Router := router.Group("/v1")


@@ -363,9 +363,12 @@ const ChannelsTable = () => {
</Table.Cell>
<Table.Cell>
<Popup
trigger={<span onClick={() => {
updateChannelBalance(channel.id, channel.name, idx);
}} style={{ cursor: 'pointer' }}>
{renderBalance(channel.type, channel.balance)}
</span>}
content="点击更新"
basic
/>
</Table.Cell>
@@ -380,16 +383,16 @@ const ChannelsTable = () => {
>
测试
</Button>
{/*<Button*/}
{/* size={'small'}*/}
{/* positive*/}
{/* loading={updatingBalance}*/}
{/* onClick={() => {*/}
{/* updateChannelBalance(channel.id, channel.name, idx);*/}
{/* }}*/}
{/*>*/}
{/* 更新余额*/}
{/*</Button>*/}
<Popup
trigger={
<Button size='small' negative>


@@ -1,36 +1,25 @@
import React, { useContext, useEffect, useState } from 'react'; import React, { useContext, useEffect, useState } from 'react';
import { import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react';
Button,
Divider,
Form,
Grid,
Header,
Image,
Message,
Modal,
Segment,
} from 'semantic-ui-react';
import { Link, useNavigate, useSearchParams } from 'react-router-dom'; import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User'; import { UserContext } from '../context/User';
import { API, getLogo, showError, showSuccess, showInfo } from '../helpers'; import { API, getLogo, showError, showSuccess } from '../helpers';
const LoginForm = () => { const LoginForm = () => {
const [inputs, setInputs] = useState({ const [inputs, setInputs] = useState({
username: '', username: '',
password: '', password: '',
wechat_verification_code: '', wechat_verification_code: ''
}); });
const [searchParams, setSearchParams] = useSearchParams(); const [searchParams, setSearchParams] = useSearchParams();
const [submitted, setSubmitted] = useState(false); const [submitted, setSubmitted] = useState(false);
const { username, password } = inputs; const { username, password } = inputs;
const [userState, userDispatch] = useContext(UserContext); const [userState, userDispatch] = useContext(UserContext);
let navigate = useNavigate(); let navigate = useNavigate();
const [status, setStatus] = useState({}); const [status, setStatus] = useState({});
const logo = getLogo(); const logo = getLogo();
useEffect(() => { useEffect(() => {
if (searchParams.get("expired")) { if (searchParams.get('expired')) {
showError('未登录或登录已过期,请重新登录!'); showError('未登录或登录已过期,请重新登录!');
} }
let status = localStorage.getItem('status'); let status = localStorage.getItem('status');
@@ -78,7 +67,7 @@ const LoginForm = () => {
if (username && password) { if (username && password) {
const res = await API.post(`/api/user/login`, { const res = await API.post(`/api/user/login`, {
username, username,
password, password
}); });
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
@@ -93,44 +82,44 @@ const LoginForm = () => {
} }
return ( return (
<Grid textAlign="center" style={{ marginTop: '48px' }}> <Grid textAlign='center' style={{ marginTop: '48px' }}>
<Grid.Column style={{ maxWidth: 450 }}> <Grid.Column style={{ maxWidth: 450 }}>
<Header as="h2" color="" textAlign="center"> <Header as='h2' color='' textAlign='center'>
<Image src={logo} /> 用户登录 <Image src={logo} /> 用户登录
</Header> </Header>
<Form size="large"> <Form size='large'>
<Segment> <Segment>
<Form.Input <Form.Input
fluid fluid
icon="user" icon='user'
iconPosition="left" iconPosition='left'
placeholder="用户名" placeholder='用户名'
name="username" name='username'
value={username} value={username}
onChange={handleChange} onChange={handleChange}
/> />
<Form.Input <Form.Input
fluid fluid
icon="lock" icon='lock'
iconPosition="left" iconPosition='left'
placeholder="密码" placeholder='密码'
name="password" name='password'
type="password" type='password'
value={password} value={password}
onChange={handleChange} onChange={handleChange}
/> />
<Button color="" fluid size="large" onClick={handleSubmit}> <Button color='green' fluid size='large' onClick={handleSubmit}>
登录 登录
</Button> </Button>
</Segment> </Segment>
</Form> </Form>
<Message> <Message>
忘记密码 忘记密码
<Link to="/reset" className="btn btn-link"> <Link to='/reset' className='btn btn-link'>
点击重置 点击重置
</Link> </Link>
没有账户 没有账户
<Link to="/register" className="btn btn-link"> <Link to='/register' className='btn btn-link'>
点击注册 点击注册
</Link> </Link>
</Message> </Message>
@@ -140,8 +129,8 @@ const LoginForm = () => {
{status.github_oauth ? ( {status.github_oauth ? (
<Button <Button
circular circular
color="black" color='black'
icon="github" icon='github'
onClick={onGitHubOAuthClicked} onClick={onGitHubOAuthClicked}
/> />
) : ( ) : (
@@ -150,8 +139,8 @@ const LoginForm = () => {
{status.wechat_login ? ( {status.wechat_login ? (
<Button <Button
circular circular
color="green" color='green'
icon="wechat" icon='wechat'
onClick={onWeChatLoginClicked} onClick={onWeChatLoginClicked}
/> />
) : ( ) : (
@@ -175,18 +164,18 @@ const LoginForm = () => {
微信扫码关注公众号输入验证码获取验证码三分钟内有效 微信扫码关注公众号输入验证码获取验证码三分钟内有效
</p> </p>
</div> </div>
<Form size="large"> <Form size='large'>
<Form.Input <Form.Input
fluid fluid
placeholder="验证码" placeholder='验证码'
name="wechat_verification_code" name='wechat_verification_code'
value={inputs.wechat_verification_code} value={inputs.wechat_verification_code}
onChange={handleChange} onChange={handleChange}
/> />
<Button <Button
color="" color=''
fluid fluid
size="large" size='large'
onClick={onSubmitWeChatVerificationCode} onClick={onSubmitWeChatVerificationCode}
> >
登录 登录


@@ -12,6 +12,11 @@ const PasswordResetConfirm = () => {
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const [disableButton, setDisableButton] = useState(false);
const [countdown, setCountdown] = useState(30);
const [newPassword, setNewPassword] = useState('');
const [searchParams, setSearchParams] = useSearchParams(); const [searchParams, setSearchParams] = useSearchParams();
useEffect(() => { useEffect(() => {
let token = searchParams.get('token'); let token = searchParams.get('token');
@@ -22,7 +27,21 @@ const PasswordResetConfirm = () => {
}); });
}, []); }, []);
useEffect(() => {
let countdownInterval = null;
if (disableButton && countdown > 0) {
countdownInterval = setInterval(() => {
setCountdown(countdown - 1);
}, 1000);
} else if (countdown === 0) {
setDisableButton(false);
setCountdown(30);
}
return () => clearInterval(countdownInterval);
}, [disableButton, countdown]);
async function handleSubmit(e) { async function handleSubmit(e) {
setDisableButton(true);
if (!email) return; if (!email) return;
setLoading(true); setLoading(true);
const res = await API.post(`/api/user/reset`, { const res = await API.post(`/api/user/reset`, {
@@ -32,14 +51,15 @@ const PasswordResetConfirm = () => {
const { success, message } = res.data; const { success, message } = res.data;
if (success) { if (success) {
let password = res.data.data; let password = res.data.data;
setNewPassword(password);
await copy(password); await copy(password);
showNotice(`密码已重置并已复制到剪贴板:${password}`); showNotice(`密码已复制到剪贴板:${password}`);
} else { } else {
showError(message); showError(message);
} }
setLoading(false); setLoading(false);
} }
return ( return (
<Grid textAlign='center' style={{ marginTop: '48px' }}> <Grid textAlign='center' style={{ marginTop: '48px' }}>
<Grid.Column style={{ maxWidth: 450 }}> <Grid.Column style={{ maxWidth: 450 }}>
@@ -57,20 +77,37 @@ const PasswordResetConfirm = () => {
value={email} value={email}
readOnly readOnly
/> />
{newPassword && (
<Form.Input
fluid
icon='lock'
iconPosition='left'
placeholder='新密码'
name='newPassword'
value={newPassword}
readOnly
onClick={(e) => {
e.target.select();
navigator.clipboard.writeText(newPassword);
showNotice(`密码已复制到剪贴板:${newPassword}`);
}}
/>
)}
<Button <Button
color='' color='green'
fluid fluid
size='large' size='large'
onClick={handleSubmit} onClick={handleSubmit}
loading={loading} loading={loading}
disabled={disableButton}
> >
提交 {disableButton ? `密码重置完成` : '提交'}
</Button> </Button>
</Segment> </Segment>
</Form> </Form>
</Grid.Column> </Grid.Column>
</Grid> </Grid>
); );
}; };
export default PasswordResetConfirm; export default PasswordResetConfirm;


@@ -5,7 +5,7 @@ import Turnstile from 'react-turnstile';
const PasswordResetForm = () => {
const [inputs, setInputs] = useState({
email: ''
});
const { email } = inputs;
@@ -13,24 +13,29 @@ const PasswordResetForm = () => {
const [turnstileEnabled, setTurnstileEnabled] = useState(false);
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
const [turnstileToken, setTurnstileToken] = useState('');
const [disableButton, setDisableButton] = useState(false);
const [countdown, setCountdown] = useState(30);
useEffect(() => {
let countdownInterval = null;
if (disableButton && countdown > 0) {
countdownInterval = setInterval(() => {
setCountdown(countdown - 1);
}, 1000);
} else if (countdown === 0) {
setDisableButton(false);
setCountdown(30);
}
return () => clearInterval(countdownInterval);
}, [disableButton, countdown]);
function handleChange(e) {
const { name, value } = e.target;
setInputs(inputs => ({ ...inputs, [name]: value }));
}
async function handleSubmit(e) {
setDisableButton(true);
if (!email) return;
if (turnstileEnabled && turnstileToken === '') {
showInfo('请稍后几秒重试Turnstile 正在检查用户环境!');
@@ -78,13 +83,14 @@ const PasswordResetForm = () => {
<></>
)}
<Button
color='green'
fluid
size='large'
onClick={handleSubmit}
loading={loading}
disabled={disableButton}
>
{disableButton ? `重试 (${countdown})` : '提交'}
</Button>
</Segment>
</Form>


@@ -1,22 +1,32 @@
import React, { useEffect, useState } from 'react'; import React, { useContext, useEffect, useState } from 'react';
import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react'; import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react';
import { Link } from 'react-router-dom'; import { Link, useNavigate } from 'react-router-dom';
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
import Turnstile from 'react-turnstile'; import Turnstile from 'react-turnstile';
import { UserContext } from '../context/User';
const PersonalSetting = () => { const PersonalSetting = () => {
const [userState, userDispatch] = useContext(UserContext);
let navigate = useNavigate();
const [inputs, setInputs] = useState({ const [inputs, setInputs] = useState({
wechat_verification_code: '', wechat_verification_code: '',
email_verification_code: '', email_verification_code: '',
email: '', email: '',
self_account_deletion_confirmation: ''
}); });
const [status, setStatus] = useState({}); const [status, setStatus] = useState({});
const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); const [showWeChatBindModal, setShowWeChatBindModal] = useState(false);
const [showEmailBindModal, setShowEmailBindModal] = useState(false); const [showEmailBindModal, setShowEmailBindModal] = useState(false);
const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false);
const [turnstileEnabled, setTurnstileEnabled] = useState(false); const [turnstileEnabled, setTurnstileEnabled] = useState(false);
const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
const [turnstileToken, setTurnstileToken] = useState(''); const [turnstileToken, setTurnstileToken] = useState('');
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const [disableButton, setDisableButton] = useState(false);
const [countdown, setCountdown] = useState(30);
const [affLink, setAffLink] = useState("");
const [systemToken, setSystemToken] = useState("");
useEffect(() => { useEffect(() => {
let status = localStorage.getItem('status'); let status = localStorage.getItem('status');
@@ -30,6 +40,19 @@ const PersonalSetting = () => {
} }
}, []); }, []);
useEffect(() => {
let countdownInterval = null;
if (disableButton && countdown > 0) {
countdownInterval = setInterval(() => {
setCountdown(countdown - 1);
}, 1000);
} else if (countdown === 0) {
setDisableButton(false);
setCountdown(30);
}
return () => clearInterval(countdownInterval); // Clean up on unmount
}, [disableButton, countdown]);
const handleInputChange = (e, { name, value }) => { const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value })); setInputs((inputs) => ({ ...inputs, [name]: value }));
}; };
@@ -38,8 +61,10 @@ const PersonalSetting = () => {
const res = await API.get('/api/user/token'); const res = await API.get('/api/user/token');
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
setSystemToken(data);
setAffLink("");
await copy(data); await copy(data);
showSuccess(`令牌已重置并已复制到剪贴板${data}`); showSuccess(`令牌已重置并已复制到剪贴板`);
} else { } else {
showError(message); showError(message);
} }
@@ -50,8 +75,42 @@ const PersonalSetting = () => {
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
let link = `${window.location.origin}/register?aff=${data}`; let link = `${window.location.origin}/register?aff=${data}`;
setAffLink(link);
setSystemToken("");
await copy(link); await copy(link);
showNotice(`邀请链接已复制到剪切板${link}`); showSuccess(`邀请链接已复制到剪切板`);
} else {
showError(message);
}
};
const handleAffLinkClick = async (e) => {
e.target.select();
await copy(e.target.value);
showSuccess(`邀请链接已复制到剪切板`);
};
const handleSystemTokenClick = async (e) => {
e.target.select();
await copy(e.target.value);
showSuccess(`系统令牌已复制到剪切板`);
};
const deleteAccount = async () => {
if (inputs.self_account_deletion_confirmation !== userState.user.username) {
showError('请输入你的账户名以确认删除!');
return;
}
const res = await API.delete('/api/user/self');
const { success, message } = res.data;
if (success) {
showSuccess('账户已删除!');
await API.get('/api/user/logout');
userDispatch({ type: 'logout' });
localStorage.removeItem('user');
navigate('/login');
} else { } else {
showError(message); showError(message);
} }
@@ -78,6 +137,7 @@ const PersonalSetting = () => {
}; };
const sendVerificationCode = async () => { const sendVerificationCode = async () => {
setDisableButton(true);
if (inputs.email === '') return; if (inputs.email === '') return;
if (turnstileEnabled && turnstileToken === '') { if (turnstileEnabled && turnstileToken === '') {
showInfo('请稍后几秒重试Turnstile 正在检查用户环境!'); showInfo('请稍后几秒重试Turnstile 正在检查用户环境!');
@@ -123,6 +183,28 @@ const PersonalSetting = () => {
</Button> </Button>
<Button onClick={generateAccessToken}>生成系统访问令牌</Button> <Button onClick={generateAccessToken}>生成系统访问令牌</Button>
<Button onClick={getAffLink}>复制邀请链接</Button> <Button onClick={getAffLink}>复制邀请链接</Button>
<Button onClick={() => {
setShowAccountDeleteModal(true);
}}>删除个人账户</Button>
{systemToken && (
<Form.Input
fluid
readOnly
value={systemToken}
onClick={handleSystemTokenClick}
style={{ marginTop: '10px' }}
/>
)}
{affLink && (
<Form.Input
fluid
readOnly
value={affLink}
onClick={handleAffLinkClick}
style={{ marginTop: '10px' }}
/>
)}
<Divider /> <Divider />
<Header as='h3'>账号绑定</Header> <Header as='h3'>账号绑定</Header>
{ {
@@ -195,8 +277,8 @@ const PersonalSetting = () => {
name='email' name='email'
type='email' type='email'
action={ action={
<Button onClick={sendVerificationCode} disabled={loading}> <Button onClick={sendVerificationCode} disabled={disableButton || loading}>
获取验证码 {disableButton ? `重新发送(${countdown})` : '获取验证码'}
</Button> </Button>
} }
/> />
@@ -217,6 +299,7 @@ const PersonalSetting = () => {
) : ( ) : (
<></> <></>
)} )}
<div style={{ display: 'flex', justifyContent: 'space-between', marginTop: '1rem' }}>
<Button <Button
color='' color=''
fluid fluid
@@ -224,8 +307,69 @@ const PersonalSetting = () => {
onClick={bindEmail} onClick={bindEmail}
loading={loading} loading={loading}
> >
绑定 确认绑定
</Button> </Button>
<div style={{ width: '1rem' }}></div>
<Button
fluid
size='large'
onClick={() => setShowEmailBindModal(false)}
>
取消
</Button>
</div>
</Form>
</Modal.Description>
</Modal.Content>
</Modal>
<Modal
onClose={() => setShowAccountDeleteModal(false)}
onOpen={() => setShowAccountDeleteModal(true)}
open={showAccountDeleteModal}
size={'tiny'}
style={{ maxWidth: '450px' }}
>
<Modal.Header>危险操作</Modal.Header>
<Modal.Content>
<Message>您正在删除自己的帐户将清空所有数据且不可恢复</Message>
<Modal.Description>
<Form size='large'>
<Form.Input
fluid
placeholder={`输入你的账户名 ${userState?.user?.username} 以确认删除`}
name='self_account_deletion_confirmation'
value={inputs.self_account_deletion_confirmation}
onChange={handleInputChange}
/>
{turnstileEnabled ? (
<Turnstile
sitekey={turnstileSiteKey}
onVerify={(token) => {
setTurnstileToken(token);
}}
/>
) : (
<></>
)}
<div style={{ display: 'flex', justifyContent: 'space-between', marginTop: '1rem' }}>
<Button
color='red'
fluid
size='large'
onClick={deleteAccount}
loading={loading}
>
确认删除
</Button>
<div style={{ width: '1rem' }}></div>
<Button
fluid
size='large'
onClick={() => setShowAccountDeleteModal(false)}
>
取消
</Button>
</div>
</Form> </Form>
</Modal.Description> </Modal.Description>
</Modal.Content> </Modal.Content>


@@ -1,13 +1,5 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Grid, Header, Image, Message, Segment } from 'semantic-ui-react';
import { Link, useNavigate } from 'react-router-dom';
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
import Turnstile from 'react-turnstile';
@@ -18,7 +10,7 @@ const RegisterForm = () => {
password: '',
password2: '',
email: '',
verification_code: ''
});
const { username, password, password2 } = inputs;
const [showEmailVerification, setShowEmailVerification] = useState(false);
@@ -178,7 +170,7 @@ const RegisterForm = () => {
<></>
)}
<Button
color='green'
fluid
size='large'
onClick={handleSubmit}


@@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { Divider, Form, Grid, Header, Message } from 'semantic-ui-react'; import { Button, Divider, Form, Grid, Header, Input, Message } from 'semantic-ui-react';
import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers'; import { API, removeTrailingSlash, showError } from '../helpers';
const SystemSetting = () => { const SystemSetting = () => {
let [inputs, setInputs] = useState({ let [inputs, setInputs] = useState({
@@ -26,9 +26,13 @@ const SystemSetting = () => {
TurnstileSiteKey: '', TurnstileSiteKey: '',
TurnstileSecretKey: '', TurnstileSecretKey: '',
RegisterEnabled: '', RegisterEnabled: '',
EmailDomainRestrictionEnabled: '',
EmailDomainWhitelist: ''
}); });
const [originInputs, setOriginInputs] = useState({}); const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false); let [loading, setLoading] = useState(false);
const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]);
const [restrictedDomainInput, setRestrictedDomainInput] = useState('');
const getOptions = async () => { const getOptions = async () => {
const res = await API.get('/api/option/'); const res = await API.get('/api/option/');
@@ -38,8 +42,15 @@ const SystemSetting = () => {
data.forEach((item) => {
newInputs[item.key] = item.value;
});
setInputs(newInputs);
setInputs({
...newInputs,
EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(',')
});
setOriginInputs(newInputs);
setEmailDomainWhitelist(newInputs.EmailDomainWhitelist.split(',').map((item) => {
return { key: item, text: item, value: item };
}));
} else {
showError(message);
}
@@ -58,6 +69,7 @@ const SystemSetting = () => {
case 'GitHubOAuthEnabled':
case 'WeChatAuthEnabled':
case 'TurnstileCheckEnabled':
case 'EmailDomainRestrictionEnabled':
case 'RegisterEnabled':
value = inputs[key] === 'true' ? 'false' : 'true';
break;
@@ -70,7 +82,12 @@ const SystemSetting = () => {
});
const { success, message } = res.data;
if (success) {
setInputs((inputs) => ({ ...inputs, [key]: value }));
if (key === 'EmailDomainWhitelist') {
value = value.split(',');
}
setInputs((inputs) => ({
...inputs, [key]: value
}));
} else {
showError(message);
}
@@ -88,7 +105,8 @@ const SystemSetting = () => {
name === 'WeChatServerToken' ||
name === 'WeChatAccountQRCodeImageURL' ||
name === 'TurnstileSiteKey' ||
name === 'TurnstileSecretKey'
name === 'TurnstileSecretKey' ||
name === 'EmailDomainWhitelist'
) {
setInputs((inputs) => ({ ...inputs, [name]: value }));
} else {
@@ -125,6 +143,16 @@ const SystemSetting = () => {
}
};
const submitEmailDomainWhitelist = async () => {
if (
originInputs['EmailDomainWhitelist'] !== inputs.EmailDomainWhitelist.join(',') &&
inputs.SMTPToken !== ''
) {
await updateOption('EmailDomainWhitelist', inputs.EmailDomainWhitelist.join(','));
}
};
const submitWeChat = async () => {
if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) {
await updateOption(
@@ -173,6 +201,22 @@ const SystemSetting = () => {
}
};
const submitNewRestrictedDomain = () => {
const localDomainList = inputs.EmailDomainWhitelist;
if (restrictedDomainInput !== '' && !localDomainList.includes(restrictedDomainInput)) {
setRestrictedDomainInput('');
setInputs({
...inputs,
EmailDomainWhitelist: [...localDomainList, restrictedDomainInput],
});
setEmailDomainWhitelist([...EmailDomainWhitelist, {
key: restrictedDomainInput,
text: restrictedDomainInput,
value: restrictedDomainInput,
}]);
}
}
return (
<Grid columns={1}>
<Grid.Column>
@@ -239,6 +283,54 @@ const SystemSetting = () => {
/>
</Form.Group>
<Divider />
<Header as='h3'>
配置邮箱域名白名单
<Header.Subheader>用以防止恶意用户利用临时邮箱批量注册</Header.Subheader>
</Header>
<Form.Group widths={3}>
<Form.Checkbox
label='启用邮箱域名白名单'
name='EmailDomainRestrictionEnabled'
onChange={handleInputChange}
checked={inputs.EmailDomainRestrictionEnabled === 'true'}
/>
</Form.Group>
<Form.Group widths={2}>
<Form.Dropdown
label='允许的邮箱域名'
placeholder='允许的邮箱域名'
name='EmailDomainWhitelist'
required
fluid
multiple
selection
onChange={handleInputChange}
value={inputs.EmailDomainWhitelist}
autoComplete='new-password'
options={EmailDomainWhitelist}
/>
<Form.Input
label='添加新的允许的邮箱域名'
action={
<Button type='button' onClick={() => {
submitNewRestrictedDomain();
}}>填入</Button>
}
onKeyDown={(e) => {
if (e.key === 'Enter') {
submitNewRestrictedDomain();
}
}}
autoComplete='new-password'
placeholder='输入新的允许的邮箱域名'
value={restrictedDomainInput}
onChange={(e, { value }) => {
setRestrictedDomainInput(value);
}}
/>
</Form.Group>
<Form.Button onClick={submitEmailDomainWhitelist}>保存邮箱域名白名单设置</Form.Button>
<Divider />
<Header as='h3'>
配置 SMTP
<Header.Subheader>用以支持系统的邮件发送</Header.Subheader>
@@ -284,7 +376,7 @@ const SystemSetting = () => {
onChange={handleInputChange}
type='password'
autoComplete='new-password'
value={inputs.SMTPToken}
checked={inputs.RegisterEnabled === 'true'}
placeholder='敏感信息不会发送到前端显示'
/>
</Form.Group>
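The EmailDomainWhitelist option travels as a single comma-separated string: getOptions splits it into an array for the multi-select dropdown, and submitEmailDomainWhitelist joins it back before calling updateOption. A minimal sketch of that round-trip and of the kind of check the whitelist enables; isEmailAllowed is a hypothetical helper shown only for illustration, since the actual enforcement happens in the backend and is not part of this diff:

// Option string <-> dropdown array, as in getOptions / submitEmailDomainWhitelist above.
const optionValue = 'gmail.com,outlook.com,qq.com';
const whitelist = optionValue.split(',');   // ['gmail.com', 'outlook.com', 'qq.com']
const saved = whitelist.join(',');          // 'gmail.com,outlook.com,qq.com'

// Hypothetical check illustrating what the whitelist is for; real validation is server-side.
function isEmailAllowed(email, whitelist) {
  const domain = email.split('@').pop().toLowerCase();
  return whitelist.includes(domain);
}

isEmailAllowed('user@gmail.com', whitelist);        // true
isEmailAllowed('user@tempmail.invalid', whitelist); // false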

View File

@@ -1,11 +1,22 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Label, Modal, Pagination, Popup, Table } from 'semantic-ui-react';
import { Button, Dropdown, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, copy, showError, showSuccess, showWarning, timestamp2string } from '../helpers';
import { ITEMS_PER_PAGE } from '../constants';
import { renderQuota } from '../helpers/render';
const COPY_OPTIONS = [
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
{ key: 'ama', text: 'AMA 问天', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
];
const OPEN_LINK_OPTIONS = [
{ key: 'ama', text: 'AMA 问天', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
];
function renderTimestamp(timestamp) {
return (
<>
@@ -68,6 +79,84 @@ const TokensTable = () => {
const refresh = async () => {
setLoading(true);
await loadTokens(activePage - 1);
};
const onCopy = async (type, key) => {
let status = localStorage.getItem('status');
let serverAddress = '';
if (status) {
status = JSON.parse(status);
serverAddress = status.server_address;
}
if (serverAddress === '') {
serverAddress = window.location.origin;
}
let encodedServerAddress = encodeURIComponent(serverAddress);
const nextLink = localStorage.getItem('chat_link');
let nextUrl;
if (nextLink) {
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
} else {
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
}
let url;
switch (type) {
case 'ama':
url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`;
break;
case 'opencat':
url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`;
break;
case 'next':
url = nextUrl;
break;
default:
url = `sk-${key}`;
}
if (await copy(url)) {
showSuccess('已复制到剪贴板!');
} else {
showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。');
setSearchKeyword(url);
}
};
const onOpenLink = async (type, key) => {
let status = localStorage.getItem('status');
let serverAddress = '';
if (status) {
status = JSON.parse(status);
serverAddress = status.server_address;
}
if (serverAddress === '') {
serverAddress = window.location.origin;
}
let encodedServerAddress = encodeURIComponent(serverAddress);
const chatLink = localStorage.getItem('chat_link');
let defaultUrl;
if (chatLink) {
defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`;
} else {
defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
}
let url;
switch (type) {
case 'ama':
url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`;
break;
case 'opencat':
url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`;
break;
default:
url = defaultUrl;
}
window.open(url, '_blank');
}
useEffect(() => {
@@ -235,21 +324,51 @@ const TokensTable = () => {
<Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell>
<Table.Cell>
<div>
<Button
size={'small'}
positive
onClick={async () => {
let key = "sk-" + token.key;
if (await copy(key)) {
showSuccess('已复制到剪贴板!');
} else {
showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。');
setSearchKeyword(key);
}
}}
>
复制
</Button>
<Button.Group color='green' size={'small'}>
<Button
size={'small'}
positive
onClick={async () => {
await onCopy('', token.key);
}}
>
复制
</Button>
<Dropdown
className='button icon'
floating
options={COPY_OPTIONS.map(option => ({
...option,
onClick: async () => {
await onCopy(option.value, token.key);
}
}))}
trigger={<></>}
/>
</Button.Group>
{' '}
<Button.Group color='blue' size={'small'}>
<Button
size={'small'}
positive
onClick={() => {
onOpenLink('', token.key);
}}>
聊天
</Button>
<Dropdown
className="button icon"
floating
options={OPEN_LINK_OPTIONS.map(option => ({
...option,
onClick: async () => {
await onOpenLink(option.value, token.key);
}
}))}
trigger={<></>}
/>
</Button.Group>
{' '}
<Popup
trigger={
<Button size='small' negative>
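onCopy and onOpenLink build the same URL schemes from the token key and the configured server address, with a chat_link value in localStorage overriding the default ChatGPT Next Web endpoint. With a placeholder key and server address (both hypothetical), the generated values look like this:

// Placeholder inputs; real values come from the token row and the status object cached in localStorage.
const key = 'abcd1234';
const serverAddress = 'https://one-api.example.com';
const encodedServerAddress = encodeURIComponent(serverAddress);

const amaUrl = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`;
const opencatUrl = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`;
const nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;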

View File

@@ -227,7 +227,7 @@ const UsersTable = () => {
content={user.email ? user.email : '未绑定邮箱地址'}
key={user.username}
header={user.display_name ? user.display_name : user.username}
trigger={<span>{renderText(user.username, 10)}</span>}
trigger={<span>{renderText(user.username, 15)}</span>}
hoverable
/>
</Table.Cell>

View File

@@ -1,17 +1,20 @@
export const CHANNEL_OPTIONS = [
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
{ key: 14, text: 'Anthropic', value: 14, color: 'black' },
{ key: 8, text: '自定义', value: 8, color: 'pink' },
{ key: 3, text: 'Azure', value: 3, color: 'olive' },
{ key: 11, text: 'PaLM', value: 11, color: 'orange' },
{ key: 15, text: 'Baidu', value: 15, color: 'blue' },
{ key: 2, text: 'API2D', value: 2, color: 'blue' },
{ key: 4, text: 'CloseAI', value: 4, color: 'teal' },
{ key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
{ key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
{ key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
{ key: 9, text: 'AI.LS', value: 9, color: 'yellow' },
{ key: 10, text: 'AI Proxy', value: 10, color: 'purple' },
{ key: 12, text: 'API2GPT', value: 12, color: 'blue' },
{ key: 13, text: 'AIGC2D', value: 13, color: 'purple' }
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 2, text: '代理API2D', value: 2, color: 'blue' },
{ key: 5, text: '代理OpenAI-SB', value: 5, color: 'brown' },
{ key: 7, text: '代理OhMyGPT', value: 7, color: 'purple' },
{ key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' },
{ key: 4, text: '代理CloseAI', value: 4, color: 'teal' },
{ key: 6, text: '代理OpenAI Max', value: 6, color: 'violet' },
{ key: 9, text: '代理AI.LS', value: 9, color: 'yellow' },
{ key: 12, text: '代理API2GPT', value: 12, color: 'blue' },
{ key: 13, text: '代理AIGC2D', value: 13, color: 'purple' }
];
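Each entry's value is the numeric channel type used throughout the UI (for example by the type switch and the type-specific key placeholder in EditChannel), so the display label can be looked up by that number. A small sketch:

// e.g. channelText(18) -> '讯飞星火认知'
const channelText = (type) =>
  (CHANNEL_OPTIONS.find((option) => option.value === type) || {}).text;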

View File

@@ -1,5 +1,5 @@
export const toastConstants = {
SUCCESS_TIMEOUT: 500,
SUCCESS_TIMEOUT: 1500,
INFO_TIMEOUT: 3000,
ERROR_TIMEOUT: 5000,
WARNING_TIMEOUT: 10000,

View File

@@ -46,9 +46,7 @@ const About = () => {
about.startsWith('https://') ? <iframe
src={about}
style={{ width: '100%', height: '100vh', border: 'none' }}
/> : <Segment>
<div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div>
</Segment>
/> : <div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div>
}
</>
}

View File

@@ -35,6 +35,30 @@ const EditChannel = () => {
const [customModel, setCustomModel] = useState('');
const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
if (name === 'type' && inputs.models.length === 0) {
let localModels = [];
switch (value) {
case 14:
localModels = ['claude-instant-1', 'claude-2'];
break;
case 11:
localModels = ['PaLM-2'];
break;
case 15:
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
break;
case 17:
localModels = ['qwen-v1', 'qwen-plus-v1'];
break;
case 16:
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break;
case 18:
localModels = ['SparkDesk'];
break;
}
setInputs((inputs) => ({ ...inputs, models: localModels }));
}
};
const loadChannel = async () => {
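The new branch in handleInputChange only pre-fills the model list when the channel type changes and no models have been selected yet. The same data can be read as a plain lookup table; a minimal sketch with the values copied from the switch above:

// Same defaults as the switch in handleInputChange, expressed as a lookup table.
const DEFAULT_MODELS_BY_TYPE = {
  14: ['claude-instant-1', 'claude-2'],
  11: ['PaLM-2'],
  15: ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'],
  17: ['qwen-v1', 'qwen-plus-v1'],
  16: ['chatglm_pro', 'chatglm_std', 'chatglm_lite'],
  18: ['SparkDesk']
};
const defaultModelsFor = (type) => DEFAULT_MODELS_BY_TYPE[type] || [];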
@@ -132,7 +156,10 @@ const EditChannel = () => {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
}
if (localInputs.type === 3 && localInputs.other === '') {
localInputs.other = '2023-03-15-preview';
localInputs.other = '2023-06-01-preview';
}
if (localInputs.model_mapping === '') {
localInputs.model_mapping = '{}';
}
let res;
localInputs.models = localInputs.models.join(',');
@@ -192,7 +219,7 @@ const EditChannel = () => {
<Form.Input
label='默认 API 版本'
name='other'
placeholder={'请输入默认 API 版本例如2023-03-15-preview该配置可以被实际的请求查询参数所覆盖'}
placeholder={'请输入默认 API 版本例如2023-06-01-preview该配置可以被实际的请求查询参数所覆盖'}
onChange={handleInputChange}
value={inputs.other}
autoComplete='new-password'
@@ -215,26 +242,12 @@ const EditChannel = () => {
</Form.Field>
)
}
{
inputs.type !== 3 && inputs.type !== 8 && (
<Form.Field>
<Form.Input
label='镜像'
name='base_url'
placeholder={'此项可选输入镜像站地址格式为https://domain.com'}
onChange={handleInputChange}
value={inputs.base_url}
autoComplete='new-password'
/>
</Form.Field>
)
}
<Form.Field>
<Form.Input
label='名称'
required
name='name'
placeholder={'请输入名称'}
placeholder={'请为渠道命名'}
onChange={handleInputChange}
value={inputs.name}
autoComplete='new-password'
@@ -243,7 +256,7 @@ const EditChannel = () => {
<Form.Field>
<Form.Dropdown
label='分组'
placeholder={'请选择分组'}
placeholder={'请选择可以使用该渠道的分组'}
name='groups'
required
fluid
@@ -260,7 +273,7 @@ const EditChannel = () => {
<Form.Field>
<Form.Dropdown
label='模型'
placeholder={'请选择该渠道所支持的模型'}
name='models'
required
fluid
@@ -284,8 +297,8 @@ const EditChannel = () => {
}}>清除所有模型</Button>
<Input
action={
<Button type={'button'} onClick={()=>{
if (customModel.trim() === "") return;
<Button type={'button'} onClick={() => {
if (customModel.trim() === '') return;
if (inputs.models.includes(customModel)) return;
let localModels = [...inputs.models];
localModels.push(customModel);
@@ -293,9 +306,9 @@ const EditChannel = () => {
localModelOptions.push({
key: customModel,
text: customModel,
value: customModel,
value: customModel
});
setModelOptions(modelOptions=>{
setModelOptions(modelOptions => {
return [...modelOptions, ...localModelOptions];
});
setCustomModel('');
@@ -311,8 +324,8 @@ const EditChannel = () => {
</div>
<Form.Field>
<Form.TextArea
label='模型映射'
label='模型重定向'
placeholder={`此项可选,为一个 JSON 文本,键为用户请求模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
name='model_mapping'
onChange={handleInputChange}
value={inputs.model_mapping}
@@ -337,7 +350,7 @@ const EditChannel = () => {
label='密钥'
name='key'
required
placeholder={inputs.type === 15 ? "请输入 access token当前版本暂不支持自动刷新请每 30 天更新一次" : '请输入密钥'}
placeholder={inputs.type === 15 ? '请输入 access token当前版本暂不支持自动刷新请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
onChange={handleInputChange}
value={inputs.key}
autoComplete='new-password'
@@ -354,7 +367,21 @@ const EditChannel = () => {
/>
)
}
<Button type={isEdit ? "button" : "submit"} positive onClick={submit}>提交</Button>
{
inputs.type !== 3 && inputs.type !== 8 && (
<Form.Field>
<Form.Input
label='代理'
name='base_url'
placeholder={'此项可选,用于通过代理站来进行 API 调用请输入代理站地址格式为https://domain.com'}
onChange={handleInputChange}
value={inputs.base_url}
autoComplete='new-password'
/>
</Form.Field>
)
}
<Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>
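model_mapping, now labelled 模型重定向, is a JSON string whose keys are the model names clients request and whose values are the names actually sent to the channel; an empty field is normalized to '{}' before submit. The MODEL_MAPPING_EXAMPLE constant itself is not shown in this diff, so the object below only illustrates the expected shape:

// Illustrative mapping only; the real MODEL_MAPPING_EXAMPLE lives elsewhere in the repo.
const exampleMapping = {
  'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
  'gpt-4-0314': 'gpt-4'
};
const model_mapping = JSON.stringify(exampleMapping); // what the form field ends up holding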

View File

@@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
import { useParams } from 'react-router-dom';
import { useParams, useNavigate } from 'react-router-dom';
import { API, showError, showSuccess, timestamp2string } from '../../helpers';
import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render';
@@ -17,11 +17,13 @@ const EditToken = () => {
};
const [inputs, setInputs] = useState(originInputs);
const { name, remain_quota, expired_time, unlimited_quota } = inputs;
const navigate = useNavigate();
const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
const handleCancel = () => {
navigate("/token");
}
const setExpiredTime = (month, day, hour, minute) => {
let now = new Date();
let timestamp = now.getTime() / 1000;
@@ -83,7 +85,7 @@ const EditToken = () => {
if (isEdit) {
showSuccess('令牌更新成功!');
} else {
showSuccess('令牌创建成功!');
showSuccess('令牌创建成功,请在列表页面点击复制获取令牌');
setInputs(originInputs);
}
} else {
@@ -150,8 +152,9 @@ const EditToken = () => {
</Form.Field>
<Button type={'button'} onClick={() => {
setUnlimitedQuota();
}}>{unlimited_quota ? '取消无限额度' : '设为无限额度'}</Button>
<Button positive onClick={submit}>提交</Button>
<Button floated='right' positive onClick={submit}>提交</Button>
<Button floated='right' onClick={handleCancel}>取消</Button>
</Form>
</Segment>
</>

View File

@@ -7,24 +7,32 @@ const TopUp = () => {
const [redemptionCode, setRedemptionCode] = useState('');
const [topUpLink, setTopUpLink] = useState('');
const [userQuota, setUserQuota] = useState(0);
const [isSubmitting, setIsSubmitting] = useState(false);
const topUp = async () => {
if (redemptionCode === '') {
showInfo('请输入充值码!')
return;
}
const res = await API.post('/api/user/topup', {
key: redemptionCode
});
const { success, message, data } = res.data;
if (success) {
showSuccess('充值成功!');
setUserQuota((quota) => {
return quota + data;
});
setRedemptionCode('');
} else {
showError(message);
setIsSubmitting(true);
try {
const res = await API.post('/api/user/topup', {
key: redemptionCode
});
const { success, message, data } = res.data;
if (success) {
showSuccess('充值成功!');
setUserQuota((quota) => {
return quota + data;
});
setRedemptionCode('');
} else {
showError(message);
}
} catch (err) {
showError('请求失败');
} finally {
setIsSubmitting(false);
}
};
@@ -74,8 +82,8 @@ const TopUp = () => {
<Button color='green' onClick={openTopUpLink}>
获取兑换码
</Button>
<Button color='yellow' onClick={topUp}>
充值
<Button color='yellow' onClick={topUp} disabled={isSubmitting}>
{isSubmitting ? '兑换中...' : '兑换'}
</Button>
</Form>
</Grid.Column>
@@ -92,5 +100,4 @@ const TopUp = () => {
);
};
export default TopUp;
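The reworked topUp wraps the request in try/catch/finally so that isSubmitting is always cleared, and the 兑换 button is disabled while a redemption is in flight. The same guard pattern in isolation, with a hypothetical submit callback standing in for the API.post call:

// Generic double-submit guard mirroring the topUp change; `submit` is a stand-in for the real request.
async function guardedSubmit(submit, setIsSubmitting) {
  setIsSubmitting(true);
  try {
    return await submit();
  } finally {
    setIsSubmitting(false); // always re-enable the button, even if the request throws
  }
}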

View File

@@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Segment } from 'semantic-ui-react';
import { useParams } from 'react-router-dom';
import { useParams, useNavigate } from 'react-router-dom';
import { API, showError, showSuccess } from '../../helpers';
import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render';
@@ -36,7 +36,10 @@ const EditUser = () => {
showError(error.message);
}
};
const navigate = useNavigate();
const handleCancel = () => {
navigate("/setting");
}
const loadUser = async () => {
let res = undefined;
if (userId) {
@@ -176,6 +179,7 @@ const EditUser = () => {
readOnly
/>
</Form.Field>
<Button onClick={handleCancel}>取消</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>