mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 07:43:41 +08:00 
			
		
		
		
	Compare commits
	
		
			81 Commits
		
	
	
		
			v0.5.5-alp
			...
			v0.5.9-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					8f5b83562b | ||
| 
						 | 
					b7570d5c77 | ||
| 
						 | 
					0e73418cdf | ||
| 
						 | 
					9889377f0e | ||
| 
						 | 
					b273464e77 | ||
| 
						 | 
					b4e43d97fd | ||
| 
						 | 
					3347a44023 | ||
| 
						 | 
					923e24534b | ||
| 
						 | 
					b4d67ca614 | ||
| 
						 | 
					d85e356b6e | ||
| 
						 | 
					495fc628e4 | ||
| 
						 | 
					76f9288c34 | ||
| 
						 | 
					915d13fdd4 | ||
| 
						 | 
					969f539777 | ||
| 
						 | 
					54e5f8ecd2 | ||
| 
						 | 
					34d517cfa2 | ||
| 
						 | 
					ddcaf95f5f | ||
| 
						 | 
					1d15157f7d | ||
| 
						 | 
					de7b9710a5 | ||
| 
						 | 
					58bb3ab6f6 | ||
| 
						 | 
					d306cb5229 | ||
| 
						 | 
					6c5307d0c4 | ||
| 
						 | 
					7c4505bdfc | ||
| 
						 | 
					9d43ec57d8 | ||
| 
						 | 
					e5311892d1 | ||
| 
						 | 
					bc7c9105f4 | ||
| 
						 | 
					3fe76c8af7 | ||
| 
						 | 
					c70c614018 | ||
| 
						 | 
					0d87de697c | ||
| 
						 | 
					aec343dc38 | ||
| 
						 | 
					89d458b9cf | ||
| 
						 | 
					63fafba112 | ||
| 
						 | 
					a398f35968 | ||
| 
						 | 
					57aa637c77 | ||
| 
						 | 
					3b483639a4 | ||
| 
						 | 
					22980b4c44 | ||
| 
						 | 
					64cdb7eafb | ||
| 
						 | 
					824444244b | ||
| 
						 | 
					fbe9985f57 | ||
| 
						 | 
					a27a5bcc06 | ||
| 
						 | 
					e28d4b1741 | ||
| 
						 | 
					f073592d39 | ||
| 
						 | 
					fa41ca9805 | ||
| 
						 | 
					e338de45b6 | ||
| 
						 | 
					114587b46f | ||
| 
						 | 
					b4b4acc288 | ||
| 
						 | 
					d663de3e3a | ||
| 
						 | 
					a85ecace2e | ||
| 
						 | 
					fbdea91ea1 | ||
| 
						 | 
					8d34b7a77e | ||
| 
						 | 
					cbd62011b8 | ||
| 
						 | 
					4701897e2e | ||
| 
						 | 
					0f6c132a80 | ||
| 
						 | 
					3cac45dc85 | ||
| 
						 | 
					47c08c72ce | ||
| 
						 | 
					53b2cace0b | ||
| 
						 | 
					f0fc991b44 | ||
| 
						 | 
					594f06e7b0 | ||
| 
						 | 
					197d1d7a9d | ||
| 
						 | 
					f9b748c2ca | ||
| 
						 | 
					fd98463611 | ||
| 
						 | 
					f5a1cd3463 | ||
| 
						 | 
					8651451e53 | ||
| 
						 | 
					1c5bb97a42 | ||
| 
						 | 
					de868e4e4e | ||
| 
						 | 
					1d258cc898 | ||
| 
						 | 
					37e09d764c | ||
| 
						 | 
					159b9e3369 | ||
| 
						 | 
					92001986db | ||
| 
						 | 
					a5647b1ea7 | ||
| 
						 | 
					215e54fc96 | ||
| 
						 | 
					ecf8a6d875 | ||
| 
						 | 
					24df3e5f62 | ||
| 
						 | 
					12ef9679a7 | ||
| 
						 | 
					328aa68255 | ||
| 
						 | 
					4335f005a6 | ||
| 
						 | 
					fe26a1448d | ||
| 
						 | 
					42451d9d02 | ||
| 
						 | 
					25c4c111ab | ||
| 
						 | 
					0d50ad4b2b | ||
| 
						 | 
					959bcdef88 | 
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -5,3 +5,5 @@ upload
 | 
			
		||||
*.db
 | 
			
		||||
build
 | 
			
		||||
*.db-journal
 | 
			
		||||
logs
 | 
			
		||||
data
 | 
			
		||||
@@ -189,6 +189,8 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
 | 
			
		||||
 | 
			
		||||
> Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage.
 | 
			
		||||
 | 
			
		||||
[](https://zeabur.com/templates/7Q0KO3)
 | 
			
		||||
 | 
			
		||||
1. First, fork the code.
 | 
			
		||||
2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console.
 | 
			
		||||
3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port).
 | 
			
		||||
 
 | 
			
		||||
@@ -190,6 +190,8 @@ Please refer to the [environment variables](#environment-variables) section for
 | 
			
		||||
 | 
			
		||||
> Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。
 | 
			
		||||
 | 
			
		||||
[](https://zeabur.com/templates/7Q0KO3)
 | 
			
		||||
 | 
			
		||||
1. まず、コードをフォークする。
 | 
			
		||||
2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。
 | 
			
		||||
3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										100
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								README.md
									
									
									
									
									
								
							@@ -51,14 +51,17 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
  <a href="https://iamazing.cn/page/reward">赞赏支持</a>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
> **Note**
 | 
			
		||||
> [!NOTE]
 | 
			
		||||
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
 | 
			
		||||
> 
 | 
			
		||||
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
 | 
			
		||||
 | 
			
		||||
> **Warning**
 | 
			
		||||
> [!WARNING]
 | 
			
		||||
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
 | 
			
		||||
 | 
			
		||||
> [!WARNING]
 | 
			
		||||
> 使用 root 用户初次登录系统后,务必修改默认密码 `123456`!
 | 
			
		||||
 | 
			
		||||
## 功能
 | 
			
		||||
1. 支持多种大模型:
 | 
			
		||||
   + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
 | 
			
		||||
@@ -69,9 +72,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
   + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
 | 
			
		||||
   + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
 | 
			
		||||
   + [x] [360 智脑](https://ai.360.cn)
 | 
			
		||||
   + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
 | 
			
		||||
2. 支持配置镜像以及众多第三方代理服务:
 | 
			
		||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
   + [x] [CloseAI](https://console.closeai-asia.com/r/2412)
 | 
			
		||||
   + [x] [CloseAI](https://referer.shadowai.xyz/r/2412)
 | 
			
		||||
   + [x] [API2D](https://api2d.com/r/197971)
 | 
			
		||||
   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
 | 
			
		||||
   + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
 | 
			
		||||
@@ -88,26 +92,33 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
12. 支持**用户邀请奖励**。
 | 
			
		||||
13. 支持以美元为单位显示额度。
 | 
			
		||||
14. 支持发布公告,设置充值链接,设置新用户初始额度。
 | 
			
		||||
15. 支持模型映射,重定向用户的请求模型。
 | 
			
		||||
15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。
 | 
			
		||||
16. 支持失败自动重试。
 | 
			
		||||
17. 支持绘图接口。
 | 
			
		||||
18. 支持丰富的**自定义**设置,
 | 
			
		||||
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
 | 
			
		||||
19. 支持丰富的**自定义**设置,
 | 
			
		||||
    1. 支持自定义系统名称,logo 以及页脚。
 | 
			
		||||
    2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
 | 
			
		||||
19. 支持通过系统访问令牌访问管理 API。
 | 
			
		||||
20. 支持 Cloudflare Turnstile 用户校验。
 | 
			
		||||
21. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。
 | 
			
		||||
21. 支持 Cloudflare Turnstile 用户校验。
 | 
			
		||||
22. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
    + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
 | 
			
		||||
    + [GitHub 开放授权](https://github.com/settings/applications/new)。
 | 
			
		||||
    + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
 | 
			
		||||
 | 
			
		||||
## 部署
 | 
			
		||||
### 基于 Docker 进行部署
 | 
			
		||||
部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api`
 | 
			
		||||
```shell
 | 
			
		||||
# 使用 SQLite 的部署命令:
 | 
			
		||||
docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
 | 
			
		||||
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。
 | 
			
		||||
# 例如:
 | 
			
		||||
docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
 | 
			
		||||
 | 
			
		||||
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
 | 
			
		||||
数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
 | 
			
		||||
 | 
			
		||||
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
 | 
			
		||||
 | 
			
		||||
@@ -149,6 +160,19 @@ sudo service nginx restart
 | 
			
		||||
 | 
			
		||||
初始账号用户名为 `root`,密码为 `123456`。
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### 基于 Docker Compose 进行部署
 | 
			
		||||
 | 
			
		||||
> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分
 | 
			
		||||
 | 
			
		||||
```shell
 | 
			
		||||
# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内
 | 
			
		||||
docker-compose up -d
 | 
			
		||||
 | 
			
		||||
# 查看部署状态
 | 
			
		||||
docker-compose ps
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 手动部署
 | 
			
		||||
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
 | 
			
		||||
   ```shell
 | 
			
		||||
@@ -236,7 +260,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 | 
			
		||||
<summary><strong>部署到 Zeabur</strong></summary>
 | 
			
		||||
<div>
 | 
			
		||||
 | 
			
		||||
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用。
 | 
			
		||||
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用
 | 
			
		||||
 | 
			
		||||
[](https://zeabur.com/templates/7Q0KO3)
 | 
			
		||||
 | 
			
		||||
1. 首先 fork 一份代码。
 | 
			
		||||
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。
 | 
			
		||||
@@ -251,6 +277,17 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 | 
			
		||||
</div>
 | 
			
		||||
</details>
 | 
			
		||||
 | 
			
		||||
<details>
 | 
			
		||||
<summary><strong>部署到 Render</strong></summary>
 | 
			
		||||
<div>
 | 
			
		||||
 | 
			
		||||
> Render 提供免费额度,绑卡后可以进一步提升额度
 | 
			
		||||
 | 
			
		||||
Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashboard.render.com
 | 
			
		||||
 | 
			
		||||
</div>
 | 
			
		||||
</details>
 | 
			
		||||
 | 
			
		||||
## 配置
 | 
			
		||||
系统本身开箱即用。
 | 
			
		||||
 | 
			
		||||
@@ -269,13 +306,20 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 | 
			
		||||
 | 
			
		||||
注意,具体的 API Base 的格式取决于你所使用的客户端。
 | 
			
		||||
 | 
			
		||||
例如对于 OpenAI 的官方库:
 | 
			
		||||
```bash
 | 
			
		||||
OPENAI_API_KEY="sk-xxxxxx"
 | 
			
		||||
OPENAI_API_BASE="https://<HOST>:<PORT>/v1" 
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```mermaid
 | 
			
		||||
graph LR
 | 
			
		||||
    A(用户)
 | 
			
		||||
    A --->|请求| B(One API)
 | 
			
		||||
    A --->|使用 One API 分发的 key 进行请求| B(One API)
 | 
			
		||||
    B -->|中继请求| C(OpenAI)
 | 
			
		||||
    B -->|中继请求| D(Azure)
 | 
			
		||||
    B -->|中继请求| E(其他下游渠道)
 | 
			
		||||
    B -->|中继请求| E(其他 OpenAI API 格式下游渠道)
 | 
			
		||||
    B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。
 | 
			
		||||
@@ -303,29 +347,35 @@ graph LR
 | 
			
		||||
     + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
 | 
			
		||||
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
 | 
			
		||||
   + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
 | 
			
		||||
5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。
 | 
			
		||||
5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
 | 
			
		||||
   + 例子:`MEMORY_CACHE_ENABLED=true`
 | 
			
		||||
6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
 | 
			
		||||
   + 例子:`SYNC_FREQUENCY=60`
 | 
			
		||||
6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
 | 
			
		||||
7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
 | 
			
		||||
   + 例子:`NODE_TYPE=slave`
 | 
			
		||||
7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
 | 
			
		||||
8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
 | 
			
		||||
   + 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
 | 
			
		||||
8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
 | 
			
		||||
9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
 | 
			
		||||
   + 例子:`CHANNEL_TEST_FREQUENCY=1440`
 | 
			
		||||
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
 | 
			
		||||
10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
 | 
			
		||||
    + 例子:`POLLING_INTERVAL=5`
 | 
			
		||||
10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
 | 
			
		||||
11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
 | 
			
		||||
    + 例子:`BATCH_UPDATE_ENABLED=true`
 | 
			
		||||
    + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
 | 
			
		||||
11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
 | 
			
		||||
12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
 | 
			
		||||
    + 例子:`BATCH_UPDATE_INTERVAL=5`
 | 
			
		||||
12. 请求频率限制:
 | 
			
		||||
13. 请求频率限制:
 | 
			
		||||
    + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
 | 
			
		||||
    + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
 | 
			
		||||
14. 编码器缓存设置:
 | 
			
		||||
    + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
 | 
			
		||||
    + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
 | 
			
		||||
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
 | 
			
		||||
 | 
			
		||||
### 命令行参数
 | 
			
		||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 | 
			
		||||
   + 例子:`--port 3000`
 | 
			
		||||
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存。
 | 
			
		||||
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。
 | 
			
		||||
   + 例子:`--log-dir ./logs`
 | 
			
		||||
3. `--version`: 打印系统版本号并退出。
 | 
			
		||||
4. `--help`: 查看命令的使用帮助和参数说明。
 | 
			
		||||
@@ -360,6 +410,12 @@ https://openai.justsong.cn
 | 
			
		||||
   + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
 | 
			
		||||
6. 报错:`当前分组负载已饱和,请稍后再试`
 | 
			
		||||
   + 上游通道 429 了。
 | 
			
		||||
7. 升级之后我的数据会丢失吗?
 | 
			
		||||
   + 如果使用 MySQL,不会。
 | 
			
		||||
   + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
 | 
			
		||||
8. 升级之前数据库需要做变更吗?
 | 
			
		||||
   + 一般情况下不需要,系统将在初始化的时候自动调整。
 | 
			
		||||
   + 如果需要的话,我会在更新日志中说明,并给出脚本。
 | 
			
		||||
 | 
			
		||||
## 相关项目
 | 
			
		||||
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
 | 
			
		||||
 
 | 
			
		||||
@@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
 | 
			
		||||
var DisplayInCurrencyEnabled = true
 | 
			
		||||
var DisplayTokenStatEnabled = true
 | 
			
		||||
 | 
			
		||||
var UsingSQLite = false
 | 
			
		||||
 | 
			
		||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
 | 
			
		||||
 | 
			
		||||
var SessionSecret = uuid.New().String()
 | 
			
		||||
var SQLitePath = "one-api.db"
 | 
			
		||||
 | 
			
		||||
var OptionMap map[string]string
 | 
			
		||||
var OptionMapRWMutex sync.RWMutex
 | 
			
		||||
@@ -56,6 +53,7 @@ var EmailDomainWhitelist = []string{
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var DebugEnabled = os.Getenv("DEBUG") == "true"
 | 
			
		||||
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
 | 
			
		||||
 | 
			
		||||
var LogConsumeEnabled = true
 | 
			
		||||
 | 
			
		||||
@@ -92,11 +90,17 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
 | 
			
		||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
 | 
			
		||||
var RequestInterval = time.Duration(requestInterval) * time.Second
 | 
			
		||||
 | 
			
		||||
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
 | 
			
		||||
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
 | 
			
		||||
 | 
			
		||||
var BatchUpdateEnabled = false
 | 
			
		||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 | 
			
		||||
 | 
			
		||||
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RequestIdKey = "X-Oneapi-Request-Id"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RoleGuestUser  = 0
 | 
			
		||||
	RoleCommonUser = 1
 | 
			
		||||
@@ -153,7 +157,8 @@ const (
 | 
			
		||||
const (
 | 
			
		||||
	ChannelStatusUnknown          = 0
 | 
			
		||||
	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value!
 | 
			
		||||
	ChannelStatusDisabled = 2 // also don't use 0
 | 
			
		||||
	ChannelStatusManuallyDisabled = 2 // also don't use 0
 | 
			
		||||
	ChannelStatusAutoDisabled     = 3
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -180,6 +185,7 @@ const (
 | 
			
		||||
	ChannelTypeOpenRouter     = 20
 | 
			
		||||
	ChannelTypeAIProxyLibrary = 21
 | 
			
		||||
	ChannelTypeFastGPT        = 22
 | 
			
		||||
	ChannelTypeTencent        = 23
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ChannelBaseURLs = []string{
 | 
			
		||||
@@ -206,4 +212,5 @@ var ChannelBaseURLs = []string{
 | 
			
		||||
	"https://openrouter.ai/api",         // 20
 | 
			
		||||
	"https://api.aiproxy.io",            // 21
 | 
			
		||||
	"https://fastgpt.run/api/openapi",   // 22
 | 
			
		||||
	"https://hunyuan.cloud.tencent.com", //23
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								common/database.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								common/database.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
			
		||||
package common
 | 
			
		||||
 | 
			
		||||
var UsingSQLite = false
 | 
			
		||||
var UsingPostgreSQL = false
 | 
			
		||||
 | 
			
		||||
var SQLitePath = "one-api.db"
 | 
			
		||||
@@ -1,11 +1,13 @@
 | 
			
		||||
package common
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/smtp"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SendEmail(subject string, receiver string, content string) error {
 | 
			
		||||
@@ -13,15 +15,32 @@ func SendEmail(subject string, receiver string, content string) error {
 | 
			
		||||
		SMTPFrom = SMTPAccount
 | 
			
		||||
	}
 | 
			
		||||
	encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
 | 
			
		||||
 | 
			
		||||
	// Extract domain from SMTPFrom
 | 
			
		||||
	parts := strings.Split(SMTPFrom, "@")
 | 
			
		||||
	var domain string
 | 
			
		||||
	if len(parts) > 1 {
 | 
			
		||||
		domain = parts[1]
 | 
			
		||||
	}
 | 
			
		||||
	// Generate a unique Message-ID
 | 
			
		||||
	buf := make([]byte, 16)
 | 
			
		||||
	_, err := rand.Read(buf)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	messageId := fmt.Sprintf("<%x@%s>", buf, domain)
 | 
			
		||||
 | 
			
		||||
	mail := []byte(fmt.Sprintf("To: %s\r\n"+
 | 
			
		||||
		"From: %s<%s>\r\n"+
 | 
			
		||||
		"Subject: %s\r\n"+
 | 
			
		||||
		"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
 | 
			
		||||
		"Date: %s\r\n"+
 | 
			
		||||
		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
 | 
			
		||||
		receiver, SystemName, SMTPFrom, encodedSubject, content))
 | 
			
		||||
		receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
 | 
			
		||||
	auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
 | 
			
		||||
	addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
 | 
			
		||||
	to := strings.Split(receiver, ";")
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	if SMTPPort == 465 {
 | 
			
		||||
		tlsConfig := &tls.Config{
 | 
			
		||||
			InsecureSkipVerify: true,
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
 | 
			
		||||
@@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	contentType := c.Request.Header.Get("Content-Type")
 | 
			
		||||
	if strings.HasPrefix(contentType, "application/json") {
 | 
			
		||||
		err = json.Unmarshal(requestBody, &v)
 | 
			
		||||
	} else {
 | 
			
		||||
		// skip for now
 | 
			
		||||
		// TODO: someday non json request have variant model, we will need to implementation this
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,7 @@ var (
 | 
			
		||||
	Port         = flag.Int("port", 3000, "the listening port")
 | 
			
		||||
	PrintVersion = flag.Bool("version", false, "print version and exit")
 | 
			
		||||
	PrintHelp    = flag.Bool("help", false, "print help and exit")
 | 
			
		||||
	LogDir       = flag.String("log-dir", "", "specify the log directory")
 | 
			
		||||
	LogDir       = flag.String("log-dir", "./logs", "specify the log directory")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func printHelp() {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,29 +1,47 @@
 | 
			
		||||
package common
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetupGinLog() {
 | 
			
		||||
const (
 | 
			
		||||
	loggerINFO  = "INFO"
 | 
			
		||||
	loggerWarn  = "WARN"
 | 
			
		||||
	loggerError = "ERR"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const maxLogCount = 1000000
 | 
			
		||||
 | 
			
		||||
var logCount int
 | 
			
		||||
var setupLogLock sync.Mutex
 | 
			
		||||
var setupLogWorking bool
 | 
			
		||||
 | 
			
		||||
func SetupLogger() {
 | 
			
		||||
	if *LogDir != "" {
 | 
			
		||||
		commonLogPath := filepath.Join(*LogDir, "common.log")
 | 
			
		||||
		errorLogPath := filepath.Join(*LogDir, "error.log")
 | 
			
		||||
		commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
			
		||||
		ok := setupLogLock.TryLock()
 | 
			
		||||
		if !ok {
 | 
			
		||||
			log.Println("setup log is already working")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		defer func() {
 | 
			
		||||
			setupLogLock.Unlock()
 | 
			
		||||
			setupLogWorking = false
 | 
			
		||||
		}()
 | 
			
		||||
		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
 | 
			
		||||
		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Fatal("failed to open log file")
 | 
			
		||||
		}
 | 
			
		||||
		errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Fatal("failed to open log file")
 | 
			
		||||
		}
 | 
			
		||||
		gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
 | 
			
		||||
		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
 | 
			
		||||
		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
 | 
			
		||||
		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -37,6 +55,36 @@ func SysError(s string) {
 | 
			
		||||
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LogInfo(ctx context.Context, msg string) {
 | 
			
		||||
	logHelper(ctx, loggerINFO, msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LogWarn(ctx context.Context, msg string) {
 | 
			
		||||
	logHelper(ctx, loggerWarn, msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LogError(ctx context.Context, msg string) {
 | 
			
		||||
	logHelper(ctx, loggerError, msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func logHelper(ctx context.Context, level string, msg string) {
 | 
			
		||||
	writer := gin.DefaultErrorWriter
 | 
			
		||||
	if level == loggerINFO {
 | 
			
		||||
		writer = gin.DefaultWriter
 | 
			
		||||
	}
 | 
			
		||||
	id := ctx.Value(RequestIdKey)
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
 | 
			
		||||
	logCount++ // we don't need accurate count, so no lock here
 | 
			
		||||
	if logCount > maxLogCount && !setupLogWorking {
 | 
			
		||||
		logCount = 0
 | 
			
		||||
		setupLogWorking = true
 | 
			
		||||
		go func() {
 | 
			
		||||
			SetupLogger()
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FatalLog(v ...any) {
 | 
			
		||||
	t := time.Now()
 | 
			
		||||
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,8 +3,32 @@ package common
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var DalleSizeRatios = map[string]map[string]float64{
 | 
			
		||||
	"dall-e-2": {
 | 
			
		||||
		"256x256":   1,
 | 
			
		||||
		"512x512":   1.125,
 | 
			
		||||
		"1024x1024": 1.25,
 | 
			
		||||
	},
 | 
			
		||||
	"dall-e-3": {
 | 
			
		||||
		"1024x1024": 1,
 | 
			
		||||
		"1024x1792": 2,
 | 
			
		||||
		"1792x1024": 2,
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var DalleGenerationImageAmounts = map[string][2]int{
 | 
			
		||||
	"dall-e-2": {1, 10},
 | 
			
		||||
	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var DalleImagePromptLengthLimitations = map[string]int{
 | 
			
		||||
	"dall-e-2": 1000,
 | 
			
		||||
	"dall-e-3": 4000,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ModelRatio
 | 
			
		||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
 | 
			
		||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
 | 
			
		||||
@@ -19,11 +43,15 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"gpt-4-32k":                 30,
 | 
			
		||||
	"gpt-4-32k-0314":            30,
 | 
			
		||||
	"gpt-4-32k-0613":            30,
 | 
			
		||||
	"gpt-4-1106-preview":        5,    // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-4-vision-preview":      5,    // $0.01 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo":             0.75, // $0.0015 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-0301":        0.75,
 | 
			
		||||
	"gpt-3.5-turbo-0613":        0.75,
 | 
			
		||||
	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-16k-0613":    1.5,
 | 
			
		||||
	"gpt-3.5-turbo-instruct":    0.75, // $0.0015 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-1106":        0.5,  // $0.001 / 1K tokens
 | 
			
		||||
	"text-ada-001":              0.2,
 | 
			
		||||
	"text-babbage-001":          0.25,
 | 
			
		||||
	"text-curie-001":            1,
 | 
			
		||||
@@ -32,6 +60,10 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"text-davinci-edit-001":     10,
 | 
			
		||||
	"code-davinci-edit-001":     10,
 | 
			
		||||
	"whisper-1":                 15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
 | 
			
		||||
	"tts-1":                     7.5, // $0.015 / 1K characters
 | 
			
		||||
	"tts-1-1106":                7.5,
 | 
			
		||||
	"tts-1-hd":                  15, // $0.030 / 1K characters
 | 
			
		||||
	"tts-1-hd-1106":             15,
 | 
			
		||||
	"davinci":                   10,
 | 
			
		||||
	"curie":                     10,
 | 
			
		||||
	"babbage":                   10,
 | 
			
		||||
@@ -40,25 +72,30 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"text-search-ada-doc-001":   10,
 | 
			
		||||
	"text-moderation-stable":    0.1,
 | 
			
		||||
	"text-moderation-latest":    0.1,
 | 
			
		||||
	"dall-e":                    8,
 | 
			
		||||
	"dall-e-2":                  8,      // $0.016 - $0.020 / image
 | 
			
		||||
	"dall-e-3":                  20,     // $0.040 - $0.120 / image
 | 
			
		||||
	"claude-instant-1":          0.815,  // $1.63 / 1M tokens
 | 
			
		||||
	"claude-2":                  5.51,   // $11.02 / 1M tokens
 | 
			
		||||
	"claude-2.0":                5.51,   // $11.02 / 1M tokens
 | 
			
		||||
	"claude-2.1":                5.51,   // $11.02 / 1M tokens
 | 
			
		||||
	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-4":               8.572,  // ¥0.12 / 1k tokens
 | 
			
		||||
	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"PaLM-2":                    1,
 | 
			
		||||
	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 | 
			
		||||
	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"qwen-v1":                   0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"qwen-plus-v1":              1,      // ¥0.014 / 1k tokens
 | 
			
		||||
	"qwen-turbo":                0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"qwen-plus":                 10,     // ¥0.14 / 1k tokens
 | 
			
		||||
	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
 | 
			
		||||
	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"360GPT_S2_V9.4":            0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ModelRatio2JSONString() string {
 | 
			
		||||
@@ -85,9 +122,24 @@ func GetModelRatio(name string) float64 {
 | 
			
		||||
 | 
			
		||||
func GetCompletionRatio(name string) float64 {
 | 
			
		||||
	if strings.HasPrefix(name, "gpt-3.5") {
 | 
			
		||||
		if strings.HasSuffix(name, "1106") {
 | 
			
		||||
			return 2
 | 
			
		||||
		}
 | 
			
		||||
		if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
 | 
			
		||||
			// TODO: clear this after 2023-12-11
 | 
			
		||||
			now := time.Now()
 | 
			
		||||
			// https://platform.openai.com/docs/models/continuous-model-upgrades
 | 
			
		||||
			// if after 2023-12-11, use 2
 | 
			
		||||
			if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
 | 
			
		||||
				return 2
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return 1.333333
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(name, "gpt-4") {
 | 
			
		||||
		if strings.HasSuffix(name, "preview") {
 | 
			
		||||
			return 3
 | 
			
		||||
		}
 | 
			
		||||
		return 2
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(name, "claude-instant-1") {
 | 
			
		||||
 
 | 
			
		||||
@@ -171,6 +171,11 @@ func GetTimestamp() int64 {
 | 
			
		||||
	return time.Now().Unix()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetTimeString() string {
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Max(a int, b int) int {
 | 
			
		||||
	if a >= b {
 | 
			
		||||
		return a
 | 
			
		||||
@@ -190,3 +195,15 @@ func GetOrDefault(env string, defaultValue int) int {
 | 
			
		||||
	}
 | 
			
		||||
	return num
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func MessageWithRequestId(message string, id string) string {
 | 
			
		||||
	return fmt.Sprintf("%s (request id: %s)", message, id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func String2Int(str string) int {
 | 
			
		||||
	num, err := strconv.Atoi(str)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	return num
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		openAIError := OpenAIError{
 | 
			
		||||
			Message: err.Error(),
 | 
			
		||||
			Type:    "one_api_error",
 | 
			
		||||
			Type:    "upstream_error",
 | 
			
		||||
		}
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
			"error": openAIError,
 | 
			
		||||
 
 | 
			
		||||
@@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
 | 
			
		||||
	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
 | 
			
		||||
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
 | 
			
		||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	if channel.BaseURL == "" {
 | 
			
		||||
		channel.BaseURL = baseURL
 | 
			
		||||
	if channel.GetBaseURL() == "" {
 | 
			
		||||
		channel.BaseURL = &baseURL
 | 
			
		||||
	}
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case common.ChannelTypeOpenAI:
 | 
			
		||||
		if channel.BaseURL != "" {
 | 
			
		||||
			baseURL = channel.BaseURL
 | 
			
		||||
		if channel.GetBaseURL() != "" {
 | 
			
		||||
			baseURL = channel.GetBaseURL()
 | 
			
		||||
		}
 | 
			
		||||
	case common.ChannelTypeAzure:
 | 
			
		||||
		return 0, errors.New("尚未实现")
 | 
			
		||||
	case common.ChannelTypeCustom:
 | 
			
		||||
		baseURL = channel.BaseURL
 | 
			
		||||
		baseURL = channel.GetBaseURL()
 | 
			
		||||
	case common.ChannelTypeCloseAI:
 | 
			
		||||
		return updateChannelCloseAIBalance(channel)
 | 
			
		||||
	case common.ChannelTypeOpenAISB:
 | 
			
		||||
 
 | 
			
		||||
@@ -5,13 +5,15 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
 | 
			
		||||
@@ -42,14 +44,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 | 
			
		||||
	}
 | 
			
		||||
	requestURL := common.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
 | 
			
		||||
		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
 | 
			
		||||
	} else {
 | 
			
		||||
		if channel.BaseURL != "" {
 | 
			
		||||
			requestURL = channel.BaseURL
 | 
			
		||||
		}
 | 
			
		||||
		requestURL += "/v1/chat/completions"
 | 
			
		||||
		if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
 | 
			
		||||
			requestURL = baseURL
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
 | 
			
		||||
	}
 | 
			
		||||
	jsonData, err := json.Marshal(request)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err, nil
 | 
			
		||||
@@ -70,10 +72,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	var response TextResponse
 | 
			
		||||
	err = json.NewDecoder(resp.Body).Decode(&response)
 | 
			
		||||
	body, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err, nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(body, &response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
 | 
			
		||||
	}
 | 
			
		||||
	if response.Usage.CompletionTokens == 0 {
 | 
			
		||||
		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
 | 
			
		||||
	}
 | 
			
		||||
@@ -141,7 +147,7 @@ func disableChannel(channelId int, channelName string, reason string) {
 | 
			
		||||
	if common.RootUserEmail == "" {
 | 
			
		||||
		common.RootUserEmail = model.GetRootUserEmail()
 | 
			
		||||
	}
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
 | 
			
		||||
	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
 | 
			
		||||
	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
 | 
			
		||||
	err := common.SendEmail(subject, common.RootUserEmail, content)
 | 
			
		||||
 
 | 
			
		||||
@@ -127,6 +127,23 @@ func DeleteChannel(c *gin.Context) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteDisabledChannel(c *gin.Context) {
 | 
			
		||||
	rows, err := model.DeleteDisabledChannel()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    rows,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateChannel(c *gin.Context) {
 | 
			
		||||
	channel := model.Channel{}
 | 
			
		||||
	err := c.ShouldBindJSON(&channel)
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -18,19 +19,21 @@ func GetAllLogs(c *gin.Context) {
 | 
			
		||||
	username := c.Query("username")
 | 
			
		||||
	tokenName := c.Query("token_name")
 | 
			
		||||
	modelName := c.Query("model_name")
 | 
			
		||||
	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
 | 
			
		||||
	channel, _ := strconv.Atoi(c.Query("channel"))
 | 
			
		||||
	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserLogs(c *gin.Context) {
 | 
			
		||||
@@ -46,34 +49,36 @@ func GetUserLogs(c *gin.Context) {
 | 
			
		||||
	modelName := c.Query("model_name")
 | 
			
		||||
	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchAllLogs(c *gin.Context) {
 | 
			
		||||
	keyword := c.Query("keyword")
 | 
			
		||||
	logs, err := model.SearchAllLogs(keyword)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchUserLogs(c *gin.Context) {
 | 
			
		||||
@@ -81,17 +86,18 @@ func SearchUserLogs(c *gin.Context) {
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	logs, err := model.SearchUserLogs(userId, keyword)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetLogsStat(c *gin.Context) {
 | 
			
		||||
@@ -101,9 +107,10 @@ func GetLogsStat(c *gin.Context) {
 | 
			
		||||
	tokenName := c.Query("token_name")
 | 
			
		||||
	username := c.Query("username")
 | 
			
		||||
	modelName := c.Query("model_name")
 | 
			
		||||
	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 | 
			
		||||
	channel, _ := strconv.Atoi(c.Query("channel"))
 | 
			
		||||
	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
 | 
			
		||||
	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data": gin.H{
 | 
			
		||||
@@ -111,6 +118,7 @@ func GetLogsStat(c *gin.Context) {
 | 
			
		||||
			//"token": tokenNum,
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetLogsSelfStat(c *gin.Context) {
 | 
			
		||||
@@ -120,9 +128,10 @@ func GetLogsSelfStat(c *gin.Context) {
 | 
			
		||||
	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 | 
			
		||||
	tokenName := c.Query("token_name")
 | 
			
		||||
	modelName := c.Query("model_name")
 | 
			
		||||
	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 | 
			
		||||
	channel, _ := strconv.Atoi(c.Query("channel"))
 | 
			
		||||
	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
 | 
			
		||||
	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data": gin.H{
 | 
			
		||||
@@ -130,4 +139,30 @@ func GetLogsSelfStat(c *gin.Context) {
 | 
			
		||||
			//"token": tokenNum,
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteHistoryLogs(c *gin.Context) {
 | 
			
		||||
	targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
 | 
			
		||||
	if targetTimestamp == 0 {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "target timestamp is required",
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	count, err := model.DeleteOldLog(targetTimestamp)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    count,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -55,12 +55,21 @@ func init() {
 | 
			
		||||
	// https://platform.openai.com/docs/models/model-endpoint-compatibility
 | 
			
		||||
	openAIModels = []OpenAIModels{
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "dall-e",
 | 
			
		||||
			Id:         "dall-e-2",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "dall-e",
 | 
			
		||||
			Root:       "dall-e-2",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "dall-e-3",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "dall-e-3",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
@@ -72,6 +81,42 @@ func init() {
 | 
			
		||||
			Root:       "whisper-1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "tts-1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "tts-1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "tts-1-1106",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "tts-1-1106",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "tts-1-hd",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "tts-1-hd",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "tts-1-hd-1106",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "tts-1-hd-1106",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-3.5-turbo",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -117,6 +162,24 @@ func init() {
 | 
			
		||||
			Root:       "gpt-3.5-turbo-16k-0613",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-3.5-turbo-1106",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1699593571,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-3.5-turbo-1106",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-3.5-turbo-instruct",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-3.5-turbo-instruct",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-4",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -171,6 +234,24 @@ func init() {
 | 
			
		||||
			Root:       "gpt-4-32k-0613",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-4-1106-preview",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1699593571,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-4-1106-preview",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-4-vision-preview",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1699593571,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-4-vision-preview",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "text-embedding-ada-002",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -265,7 +346,7 @@ func init() {
 | 
			
		||||
			Id:         "claude-instant-1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "anturopic",
 | 
			
		||||
			OwnedBy:    "anthropic",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "claude-instant-1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
@@ -274,11 +355,29 @@ func init() {
 | 
			
		||||
			Id:         "claude-2",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "anturopic",
 | 
			
		||||
			OwnedBy:    "anthropic",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "claude-2",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "claude-2.1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "anthropic",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "claude-2.1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "claude-2.0",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "anthropic",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "claude-2.0",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "ERNIE-Bot",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -297,6 +396,15 @@ func init() {
 | 
			
		||||
			Root:       "ERNIE-Bot-turbo",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "ERNIE-Bot-4",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "baidu",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "ERNIE-Bot-4",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "Embedding-V1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -315,6 +423,15 @@ func init() {
 | 
			
		||||
			Root:       "PaLM-2",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "chatglm_turbo",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "zhipu",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "chatglm_turbo",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "chatglm_pro",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -343,21 +460,21 @@ func init() {
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "qwen-v1",
 | 
			
		||||
			Id:         "qwen-turbo",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "ali",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "qwen-v1",
 | 
			
		||||
			Root:       "qwen-turbo",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "qwen-plus-v1",
 | 
			
		||||
			Id:         "qwen-plus",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "ali",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "qwen-plus-v1",
 | 
			
		||||
			Root:       "qwen-plus",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
@@ -415,12 +532,12 @@ func init() {
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "360GPT_S2_V9.4",
 | 
			
		||||
			Id:         "hunyuan",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "360",
 | 
			
		||||
			OwnedBy:    "tencent",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "360GPT_S2_V9.4",
 | 
			
		||||
			Root:       "hunyuan",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -46,7 +46,7 @@ func UpdateOption(c *gin.Context) {
 | 
			
		||||
		if option.Value == "true" && common.GitHubClientId == "" {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!",
 | 
			
		||||
				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct {
 | 
			
		||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
 | 
			
		||||
	query := ""
 | 
			
		||||
	if len(request.Messages) != 0 {
 | 
			
		||||
		query = request.Messages[len(request.Messages)-1].Content
 | 
			
		||||
		query = request.Messages[len(request.Messages)-1].StringContent()
 | 
			
		||||
	}
 | 
			
		||||
	return &AIProxyLibraryRequest{
 | 
			
		||||
		Model:  request.Model,
 | 
			
		||||
 
 | 
			
		||||
@@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 | 
			
		||||
		message := request.Messages[i]
 | 
			
		||||
		if message.Role == "system" {
 | 
			
		||||
			messages = append(messages, AliMessage{
 | 
			
		||||
				User: message.Content,
 | 
			
		||||
				User: message.StringContent(),
 | 
			
		||||
				Bot:  "Okay",
 | 
			
		||||
			})
 | 
			
		||||
			continue
 | 
			
		||||
		} else {
 | 
			
		||||
			if i == len(request.Messages)-1 {
 | 
			
		||||
				prompt = message.Content
 | 
			
		||||
				prompt = message.StringContent()
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			messages = append(messages, AliMessage{
 | 
			
		||||
				User: message.Content,
 | 
			
		||||
				Bot:  request.Messages[i+1].Content,
 | 
			
		||||
				User: message.StringContent(),
 | 
			
		||||
				Bot:  request.Messages[i+1].StringContent(),
 | 
			
		||||
			})
 | 
			
		||||
			i++
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,14 +2,16 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
@@ -17,18 +19,47 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
	tokenName := c.GetString("token_name")
 | 
			
		||||
 | 
			
		||||
	var ttsRequest TextToSpeechRequest
 | 
			
		||||
	if relayMode == RelayModeAudioSpeech {
 | 
			
		||||
		// Read JSON
 | 
			
		||||
		err := common.UnmarshalBodyReusable(c, &ttsRequest)
 | 
			
		||||
		// Check if JSON is valid
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "invalid_json", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		audioModel = ttsRequest.Model
 | 
			
		||||
		// Check if text is too long 4096
 | 
			
		||||
		if len(ttsRequest.Input) > 4096 {
 | 
			
		||||
			return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	preConsumedTokens := common.PreConsumedQuota
 | 
			
		||||
	modelRatio := common.GetModelRatio(audioModel)
 | 
			
		||||
	groupRatio := common.GetGroupRatio(group)
 | 
			
		||||
	ratio := modelRatio * groupRatio
 | 
			
		||||
	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 | 
			
		||||
	var quota int
 | 
			
		||||
	var preConsumedQuota int
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case RelayModeAudioSpeech:
 | 
			
		||||
		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
 | 
			
		||||
		quota = preConsumedQuota
 | 
			
		||||
	default:
 | 
			
		||||
		preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
 | 
			
		||||
	}
 | 
			
		||||
	userQuota, err := model.CacheGetUserQuota(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check if user quota is enough
 | 
			
		||||
	if userQuota-preConsumedQuota < 0 {
 | 
			
		||||
		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
			
		||||
	}
 | 
			
		||||
	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 | 
			
		||||
@@ -60,19 +91,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
 | 
			
		||||
	if c.GetString("base_url") != "" {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
			
		||||
	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 | 
			
		||||
	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 | 
			
		||||
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
 | 
			
		||||
		apiVersion := GetAPIVersion(c)
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	requestBody := c.Request.Body
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
 | 
			
		||||
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
 | 
			
		||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
		req.Header.Set("api-key", apiKey)
 | 
			
		||||
		req.ContentLength = c.Request.ContentLength
 | 
			
		||||
	} else {
 | 
			
		||||
		req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
 | 
			
		||||
@@ -89,33 +134,9 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	var audioResponse AudioResponse
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		go func() {
 | 
			
		||||
			quota := countTokenText(audioResponse.Text, audioModel)
 | 
			
		||||
			quotaDelta := quota - preConsumedQuota
 | 
			
		||||
			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			err = model.CacheUpdateUserQuota(userId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				tokenName := c.GetString("token_name")
 | 
			
		||||
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
				model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
				channelId := c.GetInt("channel_id")
 | 
			
		||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if relayMode != RelayModeAudioSpeech {
 | 
			
		||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
@@ -123,12 +144,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &audioResponse)
 | 
			
		||||
		var whisperResponse WhisperResponse
 | 
			
		||||
		err = json.Unmarshal(responseBody, &whisperResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		quota = countTokenText(whisperResponse.Text, audioModel)
 | 
			
		||||
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
	}
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		if preConsumedQuota > 0 {
 | 
			
		||||
			// we need to roll back the pre-consumed quota
 | 
			
		||||
			defer func(ctx context.Context) {
 | 
			
		||||
				go func() {
 | 
			
		||||
					// negative means add quota back for token & user
 | 
			
		||||
					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
 | 
			
		||||
					}
 | 
			
		||||
				}()
 | 
			
		||||
			}(c.Request.Context())
 | 
			
		||||
		}
 | 
			
		||||
		return relayErrorHandler(resp)
 | 
			
		||||
	}
 | 
			
		||||
	quotaDelta := quota - preConsumedQuota
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
 
 | 
			
		||||
@@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 | 
			
		||||
		if message.Role == "system" {
 | 
			
		||||
			messages = append(messages, BaiduMessage{
 | 
			
		||||
				Role:    "user",
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
				Content: message.StringContent(),
 | 
			
		||||
			})
 | 
			
		||||
			messages = append(messages, BaiduMessage{
 | 
			
		||||
				Role:    "assistant",
 | 
			
		||||
@@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 | 
			
		||||
		} else {
 | 
			
		||||
			messages = append(messages, BaiduMessage{
 | 
			
		||||
				Role:    message.Role,
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
				Content: message.StringContent(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -70,7 +70,9 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
 | 
			
		||||
		} else if message.Role == "assistant" {
 | 
			
		||||
			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
 | 
			
		||||
		} else if message.Role == "system" {
 | 
			
		||||
			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
 | 
			
		||||
			if prompt == "" {
 | 
			
		||||
				prompt = message.StringContent()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	prompt += "\n\nAssistant:"
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -9,40 +10,76 @@ import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func isWithinRange(element string, value int) bool {
 | 
			
		||||
	if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	min := common.DalleGenerationImageAmounts[element][0]
 | 
			
		||||
	max := common.DalleGenerationImageAmounts[element][1]
 | 
			
		||||
 | 
			
		||||
	return value >= min && value <= max
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	imageModel := "dall-e"
 | 
			
		||||
	imageModel := "dall-e-2"
 | 
			
		||||
	imageSize := "1024x1024"
 | 
			
		||||
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	consumeQuota := c.GetBool("consume_quota")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
 | 
			
		||||
	var imageRequest ImageRequest
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
	err := common.UnmarshalBodyReusable(c, &imageRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Size validation
 | 
			
		||||
	if imageRequest.Size != "" {
 | 
			
		||||
		imageSize = imageRequest.Size
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Model validation
 | 
			
		||||
	if imageRequest.Model != "" {
 | 
			
		||||
		imageModel = imageRequest.Model
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
 | 
			
		||||
 | 
			
		||||
	// Check if model is supported
 | 
			
		||||
	if hasValidSize {
 | 
			
		||||
		if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
 | 
			
		||||
			if imageSize == "1024x1024" {
 | 
			
		||||
				imageCostRatio *= 2
 | 
			
		||||
			} else {
 | 
			
		||||
				imageCostRatio *= 1.5
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Prompt validation
 | 
			
		||||
	if imageRequest.Prompt == "" {
 | 
			
		||||
		return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
 | 
			
		||||
		return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Not "256x256", "512x512", or "1024x1024"
 | 
			
		||||
	if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
 | 
			
		||||
		return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
	// Check prompt length
 | 
			
		||||
	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
 | 
			
		||||
		return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// N should between 1 and 10
 | 
			
		||||
	if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
 | 
			
		||||
		return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
	// Number of generated images validation
 | 
			
		||||
	if isWithinRange(imageModel, imageRequest.N) == false {
 | 
			
		||||
		return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// map model name
 | 
			
		||||
@@ -59,18 +96,21 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
			isModelMapped = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
 | 
			
		||||
	if c.GetString("base_url") != "" {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
			
		||||
	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 | 
			
		||||
	if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
 | 
			
		||||
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
 | 
			
		||||
		apiVersion := GetAPIVersion(c)
 | 
			
		||||
		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var requestBody io.Reader
 | 
			
		||||
	if isModelMapped {
 | 
			
		||||
	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
 | 
			
		||||
		jsonStr, err := json.Marshal(imageRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
@@ -85,26 +125,23 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	ratio := modelRatio * groupRatio
 | 
			
		||||
	userQuota, err := model.CacheGetUserQuota(userId)
 | 
			
		||||
 | 
			
		||||
	sizeRatio := 1.0
 | 
			
		||||
	// Size
 | 
			
		||||
	if imageRequest.Size == "256x256" {
 | 
			
		||||
		sizeRatio = 1
 | 
			
		||||
	} else if imageRequest.Size == "512x512" {
 | 
			
		||||
		sizeRatio = 1.125
 | 
			
		||||
	} else if imageRequest.Size == "1024x1024" {
 | 
			
		||||
		sizeRatio = 1.25
 | 
			
		||||
	}
 | 
			
		||||
	quota := int(ratio*sizeRatio*1000) * imageRequest.N
 | 
			
		||||
	quota := int(ratio*imageCostRatio*1000) * imageRequest.N
 | 
			
		||||
 | 
			
		||||
	if consumeQuota && userQuota-quota < 0 {
 | 
			
		||||
		return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden)
 | 
			
		||||
	if userQuota-quota < 0 {
 | 
			
		||||
		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
			
		||||
	token := c.Request.Header.Get("Authorization")
 | 
			
		||||
	if channelType == common.ChannelTypeAzure { // Azure authentication
 | 
			
		||||
		token = strings.TrimPrefix(token, "Bearer ")
 | 
			
		||||
		req.Header.Set("api-key", token)
 | 
			
		||||
	} else {
 | 
			
		||||
		req.Header.Set("Authorization", token)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
@@ -124,8 +161,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	}
 | 
			
		||||
	var textResponse ImageResponse
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		err := model.PostConsumeTokenQuota(tokenId, quota)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
@@ -137,15 +173,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
		if quota != 0 {
 | 
			
		||||
			tokenName := c.GetString("token_name")
 | 
			
		||||
			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
				model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
 | 
			
		||||
			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
 | 
			
		||||
			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
			channelId := c.GetInt("channel_id")
 | 
			
		||||
			model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
		}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -161,7 +195,6 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
 
 | 
			
		||||
@@ -88,9 +88,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var textResponse TextResponse
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
@@ -111,7 +110,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
 | 
			
		||||
	}
 | 
			
		||||
	// Reset response body
 | 
			
		||||
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
			
		||||
	// And then we will have to send an error response, but in this case, the header has already been set.
 | 
			
		||||
	// So the httpClient will be confused by the response.
 | 
			
		||||
@@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err := io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
@@ -132,7 +131,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
 | 
			
		||||
	if textResponse.Usage.TotalTokens == 0 {
 | 
			
		||||
		completionTokens := 0
 | 
			
		||||
		for _, choice := range textResponse.Choices {
 | 
			
		||||
			completionTokens += countTokenText(choice.Message.Content, model)
 | 
			
		||||
			completionTokens += countTokenText(choice.Message.StringContent(), model)
 | 
			
		||||
		}
 | 
			
		||||
		textResponse.Usage = Usage{
 | 
			
		||||
			PromptTokens:     promptTokens,
 | 
			
		||||
 
 | 
			
		||||
@@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
 | 
			
		||||
	}
 | 
			
		||||
	for _, message := range textRequest.Messages {
 | 
			
		||||
		palmMessage := PaLMChatMessage{
 | 
			
		||||
			Content: message.Content,
 | 
			
		||||
			Content: message.StringContent(),
 | 
			
		||||
		}
 | 
			
		||||
		if message.Role == "user" {
 | 
			
		||||
			palmMessage.Author = "0"
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										287
									
								
								controller/relay-tencent.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										287
									
								
								controller/relay-tencent.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,287 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"crypto/hmac"
 | 
			
		||||
	"crypto/sha1"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://cloud.tencent.com/document/product/1729/97732
 | 
			
		||||
 | 
			
		||||
type TencentMessage struct {
 | 
			
		||||
	Role    string `json:"role"`
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TencentChatRequest struct {
 | 
			
		||||
	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID
 | 
			
		||||
	SecretId string `json:"secret_id"` // 官网 SecretId
 | 
			
		||||
	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
 | 
			
		||||
	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
 | 
			
		||||
	Timestamp int64 `json:"timestamp"`
 | 
			
		||||
	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
 | 
			
		||||
	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
 | 
			
		||||
	Expired int64  `json:"expired"`
 | 
			
		||||
	QueryID string `json:"query_id"` //请求 Id,用于问题排查
 | 
			
		||||
	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
 | 
			
		||||
	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
 | 
			
		||||
	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
 | 
			
		||||
	Temperature float64 `json:"temperature"`
 | 
			
		||||
	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
 | 
			
		||||
	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
 | 
			
		||||
	// 建议该参数和 temperature 只设置1个,不要同时更改
 | 
			
		||||
	TopP float64 `json:"top_p"`
 | 
			
		||||
	// Stream 0:同步,1:流式 (默认,协议:SSE)
 | 
			
		||||
	// 同步请求超时:60s,如果内容较长建议使用流式
 | 
			
		||||
	Stream int `json:"stream"`
 | 
			
		||||
	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
 | 
			
		||||
	// 输入 content 总数最大支持 3000 token。
 | 
			
		||||
	Messages []TencentMessage `json:"messages"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TencentError struct {
 | 
			
		||||
	Code    int    `json:"code"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TencentUsage struct {
 | 
			
		||||
	InputTokens  int `json:"input_tokens"`
 | 
			
		||||
	OutputTokens int `json:"output_tokens"`
 | 
			
		||||
	TotalTokens  int `json:"total_tokens"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TencentResponseChoices struct {
 | 
			
		||||
	FinishReason string         `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
 | 
			
		||||
	Messages     TencentMessage `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
 | 
			
		||||
	Delta        TencentMessage `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TencentChatResponse struct {
 | 
			
		||||
	Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
 | 
			
		||||
	Created string                   `json:"created,omitempty"` // unix 时间戳的字符串
 | 
			
		||||
	Id      string                   `json:"id,omitempty"`      // 会话 id
 | 
			
		||||
	Usage   Usage                    `json:"usage,omitempty"`   // token 数量
 | 
			
		||||
	Error   TencentError             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值
 | 
			
		||||
	Note    string                   `json:"note,omitempty"`    // 注释
 | 
			
		||||
	ReqID   string                   `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
 | 
			
		||||
	messages := make([]TencentMessage, 0, len(request.Messages))
 | 
			
		||||
	for i := 0; i < len(request.Messages); i++ {
 | 
			
		||||
		message := request.Messages[i]
 | 
			
		||||
		if message.Role == "system" {
 | 
			
		||||
			messages = append(messages, TencentMessage{
 | 
			
		||||
				Role:    "user",
 | 
			
		||||
				Content: message.StringContent(),
 | 
			
		||||
			})
 | 
			
		||||
			messages = append(messages, TencentMessage{
 | 
			
		||||
				Role:    "assistant",
 | 
			
		||||
				Content: "Okay",
 | 
			
		||||
			})
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		messages = append(messages, TencentMessage{
 | 
			
		||||
			Content: message.StringContent(),
 | 
			
		||||
			Role:    message.Role,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	stream := 0
 | 
			
		||||
	if request.Stream {
 | 
			
		||||
		stream = 1
 | 
			
		||||
	}
 | 
			
		||||
	return &TencentChatRequest{
 | 
			
		||||
		Timestamp:   common.GetTimestamp(),
 | 
			
		||||
		Expired:     common.GetTimestamp() + 24*60*60,
 | 
			
		||||
		QueryID:     common.GetUUID(),
 | 
			
		||||
		Temperature: request.Temperature,
 | 
			
		||||
		TopP:        request.TopP,
 | 
			
		||||
		Stream:      stream,
 | 
			
		||||
		Messages:    messages,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
 | 
			
		||||
	fullTextResponse := OpenAITextResponse{
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Usage:   response.Usage,
 | 
			
		||||
	}
 | 
			
		||||
	if len(response.Choices) > 0 {
 | 
			
		||||
		choice := OpenAITextResponseChoice{
 | 
			
		||||
			Index: 0,
 | 
			
		||||
			Message: Message{
 | 
			
		||||
				Role:    "assistant",
 | 
			
		||||
				Content: response.Choices[0].Messages.Content,
 | 
			
		||||
			},
 | 
			
		||||
			FinishReason: response.Choices[0].FinishReason,
 | 
			
		||||
		}
 | 
			
		||||
		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 | 
			
		||||
	}
 | 
			
		||||
	return &fullTextResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
 | 
			
		||||
	response := ChatCompletionsStreamResponse{
 | 
			
		||||
		Object:  "chat.completion.chunk",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Model:   "tencent-hunyuan",
 | 
			
		||||
	}
 | 
			
		||||
	if len(TencentResponse.Choices) > 0 {
 | 
			
		||||
		var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
 | 
			
		||||
		if TencentResponse.Choices[0].FinishReason == "stop" {
 | 
			
		||||
			choice.FinishReason = &stopFinishReason
 | 
			
		||||
		}
 | 
			
		||||
		response.Choices = append(response.Choices, choice)
 | 
			
		||||
	}
 | 
			
		||||
	return &response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
 | 
			
		||||
	var responseText string
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:5] != "data:" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[5:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var TencentResponse TencentChatResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &TencentResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseTencent2OpenAI(&TencentResponse)
 | 
			
		||||
			if len(response.Choices) != 0 {
 | 
			
		||||
				responseText += response.Choices[0].Delta.Content
 | 
			
		||||
			}
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var TencentResponse TencentChatResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &TencentResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if TencentResponse.Error.Code != 0 {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: TencentResponse.Error.Message,
 | 
			
		||||
				Code:    TencentResponse.Error.Code,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseTencent2OpenAI(&TencentResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
 | 
			
		||||
	parts := strings.Split(config, "|")
 | 
			
		||||
	if len(parts) != 3 {
 | 
			
		||||
		err = errors.New("invalid tencent config")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	appId, err = strconv.ParseInt(parts[0], 10, 64)
 | 
			
		||||
	secretId = parts[1]
 | 
			
		||||
	secretKey = parts[2]
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTencentSign(req TencentChatRequest, secretKey string) string {
 | 
			
		||||
	params := make([]string, 0)
 | 
			
		||||
	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
 | 
			
		||||
	params = append(params, "secret_id="+req.SecretId)
 | 
			
		||||
	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
 | 
			
		||||
	params = append(params, "query_id="+req.QueryID)
 | 
			
		||||
	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
 | 
			
		||||
	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
 | 
			
		||||
	params = append(params, "stream="+strconv.Itoa(req.Stream))
 | 
			
		||||
	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
 | 
			
		||||
 | 
			
		||||
	var messageStr string
 | 
			
		||||
	for _, msg := range req.Messages {
 | 
			
		||||
		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
 | 
			
		||||
	}
 | 
			
		||||
	messageStr = strings.TrimSuffix(messageStr, ",")
 | 
			
		||||
	params = append(params, "messages=["+messageStr+"]")
 | 
			
		||||
 | 
			
		||||
	sort.Sort(sort.StringSlice(params))
 | 
			
		||||
	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
 | 
			
		||||
	mac := hmac.New(sha1.New, []byte(secretKey))
 | 
			
		||||
	signURL := url
 | 
			
		||||
	mac.Write([]byte(signURL))
 | 
			
		||||
	sign := mac.Sum([]byte(nil))
 | 
			
		||||
	return base64.StdEncoding.EncodeToString(sign)
 | 
			
		||||
}
 | 
			
		||||
@@ -2,16 +2,19 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"math"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -23,13 +26,21 @@ const (
 | 
			
		||||
	APITypeAli
 | 
			
		||||
	APITypeXunfei
 | 
			
		||||
	APITypeAIProxyLibrary
 | 
			
		||||
	APITypeTencent
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var httpClient *http.Client
 | 
			
		||||
var impatientHTTPClient *http.Client
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	if common.RelayTimeout == 0 {
 | 
			
		||||
		httpClient = &http.Client{}
 | 
			
		||||
	} else {
 | 
			
		||||
		httpClient = &http.Client{
 | 
			
		||||
			Timeout: time.Duration(common.RelayTimeout) * time.Second,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	impatientHTTPClient = &http.Client{
 | 
			
		||||
		Timeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
@@ -37,17 +48,15 @@ func init() {
 | 
			
		||||
 | 
			
		||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	consumeQuota := c.GetBool("consume_quota")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
	var textRequest GeneralOpenAIRequest
 | 
			
		||||
	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
 | 
			
		||||
	err := common.UnmarshalBodyReusable(c, &textRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	}
 | 
			
		||||
	if relayMode == RelayModeModerations && textRequest.Model == "" {
 | 
			
		||||
		textRequest.Model = "text-moderation-latest"
 | 
			
		||||
	}
 | 
			
		||||
@@ -107,22 +116,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		apiType = APITypeXunfei
 | 
			
		||||
	case common.ChannelTypeAIProxyLibrary:
 | 
			
		||||
		apiType = APITypeAIProxyLibrary
 | 
			
		||||
	case common.ChannelTypeTencent:
 | 
			
		||||
		apiType = APITypeTencent
 | 
			
		||||
	}
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
	if c.GetString("base_url") != "" {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	}
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
			
		||||
	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	case APITypeOpenAI:
 | 
			
		||||
		if channelType == common.ChannelTypeAzure {
 | 
			
		||||
			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 | 
			
		||||
			query := c.Request.URL.Query()
 | 
			
		||||
			apiVersion := query.Get("api-version")
 | 
			
		||||
			if apiVersion == "" {
 | 
			
		||||
				apiVersion = c.GetString("api_version")
 | 
			
		||||
			}
 | 
			
		||||
			apiVersion := GetAPIVersion(c)
 | 
			
		||||
			requestURL := strings.Split(requestURL, "?")[0]
 | 
			
		||||
			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
 | 
			
		||||
			baseURL = c.GetString("base_url")
 | 
			
		||||
@@ -133,7 +140,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			model_ = strings.TrimSuffix(model_, "-0301")
 | 
			
		||||
			model_ = strings.TrimSuffix(model_, "-0314")
 | 
			
		||||
			model_ = strings.TrimSuffix(model_, "-0613")
 | 
			
		||||
			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 | 
			
		||||
 | 
			
		||||
			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
 | 
			
		||||
			fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeClaude:
 | 
			
		||||
		fullRequestURL = "https://api.anthropic.com/v1/complete"
 | 
			
		||||
@@ -146,6 +155,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
 | 
			
		||||
		case "ERNIE-Bot-turbo":
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
 | 
			
		||||
		case "ERNIE-Bot-4":
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
 | 
			
		||||
		case "BLOOMZ-7B":
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
 | 
			
		||||
		case "Embedding-V1":
 | 
			
		||||
@@ -177,6 +188,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		if relayMode == RelayModeEmbeddings {
 | 
			
		||||
			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeTencent:
 | 
			
		||||
		fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
 | 
			
		||||
	case APITypeAIProxyLibrary:
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
 | 
			
		||||
	}
 | 
			
		||||
@@ -202,6 +215,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	if userQuota-preConsumedQuota < 0 {
 | 
			
		||||
		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
			
		||||
	}
 | 
			
		||||
	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 | 
			
		||||
@@ -210,8 +226,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		// in this case, we do not pre-consume quota
 | 
			
		||||
		// because the user has enough quota
 | 
			
		||||
		preConsumedQuota = 0
 | 
			
		||||
		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
 | 
			
		||||
	}
 | 
			
		||||
	if consumeQuota && preConsumedQuota > 0 {
 | 
			
		||||
	if preConsumedQuota > 0 {
 | 
			
		||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 | 
			
		||||
@@ -279,6 +296,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case APITypeTencent:
 | 
			
		||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
		appId, secretId, secretKey, err := parseTencentConfig(apiKey)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		tencentRequest := requestOpenAI2Tencent(textRequest)
 | 
			
		||||
		tencentRequest.AppId = appId
 | 
			
		||||
		tencentRequest.SecretId = secretId
 | 
			
		||||
		jsonStr, err := json.Marshal(tencentRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		sign := getTencentSign(*tencentRequest, secretKey)
 | 
			
		||||
		c.Request.Header.Set("Authorization", sign)
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case APITypeAIProxyLibrary:
 | 
			
		||||
		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
 | 
			
		||||
		aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
 | 
			
		||||
@@ -326,11 +360,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			if textRequest.Stream {
 | 
			
		||||
				req.Header.Set("X-DashScope-SSE", "enable")
 | 
			
		||||
			}
 | 
			
		||||
		case APITypeTencent:
 | 
			
		||||
			req.Header.Set("Authorization", apiKey)
 | 
			
		||||
		case APITypePaLM:
 | 
			
		||||
			// do not set Authorization header
 | 
			
		||||
		default:
 | 
			
		||||
			req.Header.Set("Authorization", "Bearer "+apiKey)
 | 
			
		||||
		}
 | 
			
		||||
		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
		if isStream && c.Request.Header.Get("Accept") == "" {
 | 
			
		||||
			req.Header.Set("Accept", "text/event-stream")
 | 
			
		||||
		}
 | 
			
		||||
		//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
 | 
			
		||||
		resp, err = httpClient.Do(req)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -348,13 +389,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
 | 
			
		||||
		if resp.StatusCode != http.StatusOK {
 | 
			
		||||
			if preConsumedQuota != 0 {
 | 
			
		||||
				go func() {
 | 
			
		||||
				go func(ctx context.Context) {
 | 
			
		||||
					// return pre-consumed quota
 | 
			
		||||
					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.SysError("error return pre-consumed quota: " + err.Error())
 | 
			
		||||
						common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
 | 
			
		||||
					}
 | 
			
		||||
				}()
 | 
			
		||||
				}(c.Request.Context())
 | 
			
		||||
			}
 | 
			
		||||
			return relayErrorHandler(resp)
 | 
			
		||||
		}
 | 
			
		||||
@@ -362,19 +403,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
 | 
			
		||||
	var textResponse TextResponse
 | 
			
		||||
	tokenName := c.GetString("token_name")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		// c.Writer.Flush()
 | 
			
		||||
		go func() {
 | 
			
		||||
			if consumeQuota {
 | 
			
		||||
			quota := 0
 | 
			
		||||
			completionRatio := common.GetCompletionRatio(textRequest.Model)
 | 
			
		||||
			promptTokens = textResponse.Usage.PromptTokens
 | 
			
		||||
			completionTokens = textResponse.Usage.CompletionTokens
 | 
			
		||||
 | 
			
		||||
				quota = promptTokens + int(float64(completionTokens)*completionRatio)
 | 
			
		||||
				quota = int(float64(quota) * ratio)
 | 
			
		||||
			quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
 | 
			
		||||
			if ratio != 0 && quota <= 0 {
 | 
			
		||||
				quota = 1
 | 
			
		||||
			}
 | 
			
		||||
@@ -387,21 +424,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			quotaDelta := quota - preConsumedQuota
 | 
			
		||||
			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
					common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
				common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			err = model.CacheUpdateUserQuota(userId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
					common.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
				common.LogError(ctx, "error update user quota cache: "+err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 | 
			
		||||
				model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
			}
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		}()
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	case APITypeOpenAI:
 | 
			
		||||
		if isStream {
 | 
			
		||||
@@ -413,7 +450,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
 | 
			
		||||
			err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
@@ -539,14 +576,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeXunfei:
 | 
			
		||||
		if isStream {
 | 
			
		||||
		auth := c.Request.Header.Get("Authorization")
 | 
			
		||||
		auth = strings.TrimPrefix(auth, "Bearer ")
 | 
			
		||||
		splits := strings.Split(auth, "|")
 | 
			
		||||
		if len(splits) != 3 {
 | 
			
		||||
			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
			
		||||
		var err *OpenAIErrorWithStatusCode
 | 
			
		||||
		var usage *Usage
 | 
			
		||||
		if isStream {
 | 
			
		||||
			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
@@ -554,9 +596,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			textResponse.Usage = *usage
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeAIProxyLibrary:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			err, usage := aiProxyLibraryStreamHandler(c, resp)
 | 
			
		||||
@@ -577,6 +616,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeTencent:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			err, responseText := tencentStreamHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			textResponse.Usage.PromptTokens = promptTokens
 | 
			
		||||
			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := tencentHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,53 +1,65 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkoukk/tiktoken-go"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkoukk/tiktoken-go"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var stopFinishReason = "stop"
 | 
			
		||||
 | 
			
		||||
// tokenEncoderMap won't grow after initialization
 | 
			
		||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 | 
			
		||||
var defaultTokenEncoder *tiktoken.Tiktoken
 | 
			
		||||
 | 
			
		||||
func InitTokenEncoders() {
 | 
			
		||||
	common.SysLog("initializing token encoders")
 | 
			
		||||
	fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
 | 
			
		||||
	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
 | 
			
		||||
		common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
 | 
			
		||||
	}
 | 
			
		||||
	defaultTokenEncoder = gpt35TokenEncoder
 | 
			
		||||
	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
 | 
			
		||||
	}
 | 
			
		||||
	for model, _ := range common.ModelRatio {
 | 
			
		||||
		tokenEncoder, err := tiktoken.EncodingForModel(model)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
 | 
			
		||||
			tokenEncoderMap[model] = fallbackTokenEncoder
 | 
			
		||||
			continue
 | 
			
		||||
		if strings.HasPrefix(model, "gpt-3.5") {
 | 
			
		||||
			tokenEncoderMap[model] = gpt35TokenEncoder
 | 
			
		||||
		} else if strings.HasPrefix(model, "gpt-4") {
 | 
			
		||||
			tokenEncoderMap[model] = gpt4TokenEncoder
 | 
			
		||||
		} else {
 | 
			
		||||
			tokenEncoderMap[model] = nil
 | 
			
		||||
		}
 | 
			
		||||
		tokenEncoderMap[model] = tokenEncoder
 | 
			
		||||
	}
 | 
			
		||||
	common.SysLog("token encoders initialized")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
 | 
			
		||||
	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
 | 
			
		||||
	tokenEncoder, ok := tokenEncoderMap[model]
 | 
			
		||||
	if ok && tokenEncoder != nil {
 | 
			
		||||
		return tokenEncoder
 | 
			
		||||
	}
 | 
			
		||||
	if ok {
 | 
			
		||||
		tokenEncoder, err := tiktoken.EncodingForModel(model)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
 | 
			
		||||
		tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
 | 
			
		||||
		}
 | 
			
		||||
			tokenEncoder = defaultTokenEncoder
 | 
			
		||||
		}
 | 
			
		||||
		tokenEncoderMap[model] = tokenEncoder
 | 
			
		||||
		return tokenEncoder
 | 
			
		||||
	}
 | 
			
		||||
	return defaultTokenEncoder
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 | 
			
		||||
	if common.ApproximateTokenEnabled {
 | 
			
		||||
@@ -75,7 +87,7 @@ func countTokenMessages(messages []Message, model string) int {
 | 
			
		||||
	tokenNum := 0
 | 
			
		||||
	for _, message := range messages {
 | 
			
		||||
		tokenNum += tokensPerMessage
 | 
			
		||||
		tokenNum += getTokenNum(tokenEncoder, message.Content)
 | 
			
		||||
		tokenNum += getTokenNum(tokenEncoder, message.StringContent())
 | 
			
		||||
		tokenNum += getTokenNum(tokenEncoder, message.Role)
 | 
			
		||||
		if message.Name != nil {
 | 
			
		||||
			tokenNum += tokensPerName
 | 
			
		||||
@@ -146,7 +158,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
 | 
			
		||||
		StatusCode: resp.StatusCode,
 | 
			
		||||
		OpenAIError: OpenAIError{
 | 
			
		||||
			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
 | 
			
		||||
			Type:    "one_api_error",
 | 
			
		||||
			Type:    "upstream_error",
 | 
			
		||||
			Code:    "bad_response_status_code",
 | 
			
		||||
			Param:   strconv.Itoa(resp.StatusCode),
 | 
			
		||||
		},
 | 
			
		||||
@@ -167,3 +179,48 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
 | 
			
		||||
	openAIErrorWithStatusCode.OpenAIError = textResponse.Error
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
			
		||||
 | 
			
		||||
	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
 | 
			
		||||
		switch channelType {
 | 
			
		||||
		case common.ChannelTypeOpenAI:
 | 
			
		||||
			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
 | 
			
		||||
		case common.ChannelTypeAzure:
 | 
			
		||||
			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return fullRequestURL
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
 | 
			
		||||
	// quotaDelta is remaining quota to be consumed
 | 
			
		||||
	err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	err = model.CacheUpdateUserQuota(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	// totalQuota is total quota consumed
 | 
			
		||||
	if totalQuota != 0 {
 | 
			
		||||
		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
 | 
			
		||||
		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
 | 
			
		||||
		model.UpdateChannelUsedQuota(channelId, totalQuota)
 | 
			
		||||
	}
 | 
			
		||||
	if totalQuota <= 0 {
 | 
			
		||||
		common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAPIVersion(c *gin.Context) string {
 | 
			
		||||
	query := c.Request.URL.Query()
 | 
			
		||||
	apiVersion := query.Get("api-version")
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = c.GetString("api_version")
 | 
			
		||||
	}
 | 
			
		||||
	return apiVersion
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
 | 
			
		||||
		if message.Role == "system" {
 | 
			
		||||
			messages = append(messages, XunfeiMessage{
 | 
			
		||||
				Role:    "user",
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
				Content: message.StringContent(),
 | 
			
		||||
			})
 | 
			
		||||
			messages = append(messages, XunfeiMessage{
 | 
			
		||||
				Role:    "assistant",
 | 
			
		||||
@@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
 | 
			
		||||
		} else {
 | 
			
		||||
			messages = append(messages, XunfeiMessage{
 | 
			
		||||
				Role:    message.Role,
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
				Content: message.StringContent(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
 | 
			
		||||
			Role:    "assistant",
 | 
			
		||||
			Content: response.Payload.Choices.Text[0].Content,
 | 
			
		||||
		},
 | 
			
		||||
		FinishReason: stopFinishReason,
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := OpenAITextResponse{
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
@@ -177,33 +178,85 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 | 
			
		||||
	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	query := c.Request.URL.Query()
 | 
			
		||||
	apiVersion := query.Get("api-version")
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = c.GetString("api_version")
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case xunfeiResponse := <-dataChan:
 | 
			
		||||
			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
 | 
			
		||||
			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
 | 
			
		||||
			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
 | 
			
		||||
			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = "v1.1"
 | 
			
		||||
		common.SysLog("api_version not found, use default: " + apiVersion)
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	domain := "general"
 | 
			
		||||
	if apiVersion == "v2.1" {
 | 
			
		||||
		domain = "generalv2"
 | 
			
		||||
	})
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
	hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
 | 
			
		||||
 | 
			
		||||
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 | 
			
		||||
	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	var content string
 | 
			
		||||
	var xunfeiResponse XunfeiChatResponse
 | 
			
		||||
	stop := false
 | 
			
		||||
	for !stop {
 | 
			
		||||
		select {
 | 
			
		||||
		case xunfeiResponse = <-dataChan:
 | 
			
		||||
			if len(xunfeiResponse.Payload.Choices.Text) == 0 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			content += xunfeiResponse.Payload.Choices.Text[0].Content
 | 
			
		||||
			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
 | 
			
		||||
			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
 | 
			
		||||
			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
 | 
			
		||||
		case stop = <-stopChan:
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	xunfeiResponse.Payload.Choices.Text[0].Content = content
 | 
			
		||||
 | 
			
		||||
	response := responseXunfei2OpenAI(&xunfeiResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	_, _ = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
 | 
			
		||||
	d := websocket.Dialer{
 | 
			
		||||
		HandshakeTimeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
 | 
			
		||||
	conn, resp, err := d.Dial(authUrl, nil)
 | 
			
		||||
	if err != nil || resp.StatusCode != 101 {
 | 
			
		||||
		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
	data := requestOpenAI2Xunfei(textRequest, appId, domain)
 | 
			
		||||
	err = conn.WriteJSON(data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dataChan := make(chan XunfeiChatResponse)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
@@ -230,61 +283,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case xunfeiResponse := <-dataChan:
 | 
			
		||||
			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
 | 
			
		||||
			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
 | 
			
		||||
			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
 | 
			
		||||
			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
 | 
			
		||||
	return dataChan, stopChan, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var xunfeiResponse XunfeiChatResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
 | 
			
		||||
	query := c.Request.URL.Query()
 | 
			
		||||
	apiVersion := query.Get("api-version")
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = c.GetString("api_version")
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = "v1.1"
 | 
			
		||||
		common.SysLog("api_version not found, use default: " + apiVersion)
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &xunfeiResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	domain := "general"
 | 
			
		||||
	if apiVersion != "v1.1" {
 | 
			
		||||
		domain += strings.Split(apiVersion, ".")[0]
 | 
			
		||||
	}
 | 
			
		||||
	if xunfeiResponse.Header.Code != 0 {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: xunfeiResponse.Header.Message,
 | 
			
		||||
				Type:    "xunfei_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
				Code:    xunfeiResponse.Header.Code,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
 | 
			
		||||
	return domain, authUrl
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 | 
			
		||||
		if message.Role == "system" {
 | 
			
		||||
			messages = append(messages, ZhipuMessage{
 | 
			
		||||
				Role:    "system",
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
				Content: message.StringContent(),
 | 
			
		||||
			})
 | 
			
		||||
			messages = append(messages, ZhipuMessage{
 | 
			
		||||
				Role:    "user",
 | 
			
		||||
@@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 | 
			
		||||
		} else {
 | 
			
		||||
			messages = append(messages, ZhipuMessage{
 | 
			
		||||
				Role:    message.Role,
 | 
			
		||||
				Content: message.Content,
 | 
			
		||||
				Content: message.StringContent(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -12,10 +12,49 @@ import (
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
	Role    string  `json:"role"`
 | 
			
		||||
	Content string  `json:"content"`
 | 
			
		||||
	Content any     `json:"content"`
 | 
			
		||||
	Name    *string `json:"name,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageURL struct {
 | 
			
		||||
	Url    string `json:"url,omitempty"`
 | 
			
		||||
	Detail string `json:"detail,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TextContent struct {
 | 
			
		||||
	Type string `json:"type,omitempty"`
 | 
			
		||||
	Text string `json:"text,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageContent struct {
 | 
			
		||||
	Type     string    `json:"type,omitempty"`
 | 
			
		||||
	ImageURL *ImageURL `json:"image_url,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Message) StringContent() string {
 | 
			
		||||
	content, ok := m.Content.(string)
 | 
			
		||||
	if ok {
 | 
			
		||||
		return content
 | 
			
		||||
	}
 | 
			
		||||
	contentList, ok := m.Content.([]any)
 | 
			
		||||
	if ok {
 | 
			
		||||
		var contentStr string
 | 
			
		||||
		for _, contentItem := range contentList {
 | 
			
		||||
			contentMap, ok := contentItem.(map[string]any)
 | 
			
		||||
			if !ok {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if contentMap["type"] == "text" {
 | 
			
		||||
				if subStr, ok := contentMap["text"].(string); ok {
 | 
			
		||||
					contentStr += subStr
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return contentStr
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RelayModeUnknown = iota
 | 
			
		||||
	RelayModeChatCompletions
 | 
			
		||||
@@ -24,11 +63,17 @@ const (
 | 
			
		||||
	RelayModeModerations
 | 
			
		||||
	RelayModeImagesGenerations
 | 
			
		||||
	RelayModeEdits
 | 
			
		||||
	RelayModeAudio
 | 
			
		||||
	RelayModeAudioSpeech
 | 
			
		||||
	RelayModeAudioTranscription
 | 
			
		||||
	RelayModeAudioTranslation
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://platform.openai.com/docs/api-reference/chat
 | 
			
		||||
 | 
			
		||||
type ResponseFormat struct {
 | 
			
		||||
	Type string `json:"type,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type GeneralOpenAIRequest struct {
 | 
			
		||||
	Model            string          `json:"model,omitempty"`
 | 
			
		||||
	Messages         []Message       `json:"messages,omitempty"`
 | 
			
		||||
@@ -42,6 +87,13 @@ type GeneralOpenAIRequest struct {
 | 
			
		||||
	Instruction      string          `json:"instruction,omitempty"`
 | 
			
		||||
	Size             string          `json:"size,omitempty"`
 | 
			
		||||
	Functions        any             `json:"functions,omitempty"`
 | 
			
		||||
	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
 | 
			
		||||
	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
 | 
			
		||||
	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
 | 
			
		||||
	Seed             float64         `json:"seed,omitempty"`
 | 
			
		||||
	Tools            any             `json:"tools,omitempty"`
 | 
			
		||||
	ToolChoice       any             `json:"tool_choice,omitempty"`
 | 
			
		||||
	User             string          `json:"user,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r GeneralOpenAIRequest) ParseInput() []string {
 | 
			
		||||
@@ -77,16 +129,30 @@ type TextRequest struct {
 | 
			
		||||
	//Stream   bool      `json:"stream"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
 | 
			
		||||
type ImageRequest struct {
 | 
			
		||||
	Prompt string `json:"prompt"`
 | 
			
		||||
	N      int    `json:"n"`
 | 
			
		||||
	Size   string `json:"size"`
 | 
			
		||||
	Model          string `json:"model"`
 | 
			
		||||
	Prompt         string `json:"prompt" binding:"required"`
 | 
			
		||||
	N              int    `json:"n,omitempty"`
 | 
			
		||||
	Size           string `json:"size,omitempty"`
 | 
			
		||||
	Quality        string `json:"quality,omitempty"`
 | 
			
		||||
	ResponseFormat string `json:"response_format,omitempty"`
 | 
			
		||||
	Style          string `json:"style,omitempty"`
 | 
			
		||||
	User           string `json:"user,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AudioResponse struct {
 | 
			
		||||
type WhisperResponse struct {
 | 
			
		||||
	Text string `json:"text,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TextToSpeechRequest struct {
 | 
			
		||||
	Model          string  `json:"model" binding:"required"`
 | 
			
		||||
	Input          string  `json:"input" binding:"required"`
 | 
			
		||||
	Voice          string  `json:"voice" binding:"required"`
 | 
			
		||||
	Speed          float64 `json:"speed"`
 | 
			
		||||
	ResponseFormat string  `json:"response_format"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Usage struct {
 | 
			
		||||
	PromptTokens     int `json:"prompt_tokens"`
 | 
			
		||||
	CompletionTokens int `json:"completion_tokens"`
 | 
			
		||||
@@ -183,19 +249,28 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		relayMode = RelayModeImagesGenerations
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 | 
			
		||||
		relayMode = RelayModeEdits
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
			
		||||
		relayMode = RelayModeAudio
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
 | 
			
		||||
		relayMode = RelayModeAudioSpeech
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 | 
			
		||||
		relayMode = RelayModeAudioTranscription
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
 | 
			
		||||
		relayMode = RelayModeAudioTranslation
 | 
			
		||||
	}
 | 
			
		||||
	var err *OpenAIErrorWithStatusCode
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case RelayModeImagesGenerations:
 | 
			
		||||
		err = relayImageHelper(c, relayMode)
 | 
			
		||||
	case RelayModeAudio:
 | 
			
		||||
	case RelayModeAudioSpeech:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case RelayModeAudioTranslation:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case RelayModeAudioTranscription:
 | 
			
		||||
		err = relayAudioHelper(c, relayMode)
 | 
			
		||||
	default:
 | 
			
		||||
		err = relayTextHelper(c, relayMode)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		requestId := c.GetString(common.RequestIdKey)
 | 
			
		||||
		retryTimesStr := c.Query("retry")
 | 
			
		||||
		retryTimes, _ := strconv.Atoi(retryTimesStr)
 | 
			
		||||
		if retryTimesStr == "" {
 | 
			
		||||
@@ -207,12 +282,13 @@ func Relay(c *gin.Context) {
 | 
			
		||||
			if err.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
 | 
			
		||||
			}
 | 
			
		||||
			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
 | 
			
		||||
			c.JSON(err.StatusCode, gin.H{
 | 
			
		||||
				"error": err.OpenAIError,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
			
		||||
		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
 | 
			
		||||
			channelId := c.GetInt("channel_id")
 | 
			
		||||
 
 | 
			
		||||
@@ -9,21 +9,21 @@ services:
 | 
			
		||||
    ports:
 | 
			
		||||
      - "3000:3000"
 | 
			
		||||
    volumes:
 | 
			
		||||
      - ./data:/data
 | 
			
		||||
      - ./data/oneapi:/data
 | 
			
		||||
      - ./logs:/app/logs
 | 
			
		||||
    environment:
 | 
			
		||||
      - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api  # 修改此行,或注释掉以使用 SQLite 作为数据库
 | 
			
		||||
      - SQL_DSN=oneapi:123456@tcp(db:3306)/one-api  # 修改此行,或注释掉以使用 SQLite 作为数据库
 | 
			
		||||
      - REDIS_CONN_STRING=redis://redis
 | 
			
		||||
      - SESSION_SECRET=random_string  # 修改为随机字符串
 | 
			
		||||
      - TZ=Asia/Shanghai
 | 
			
		||||
#      - NODE_TYPE=slave  # 多机部署时从节点取消注释该行
 | 
			
		||||
#      - SYNC_FREQUENCY=60  # 需要定期从数据库加载数据时取消注释该行
 | 
			
		||||
#      - FRONTEND_BASE_URL=https://openai.justsong.cn  # 多机部署时从节点取消注释该行
 | 
			
		||||
 | 
			
		||||
    depends_on:
 | 
			
		||||
      - redis
 | 
			
		||||
      - db
 | 
			
		||||
    healthcheck:
 | 
			
		||||
      test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ]
 | 
			
		||||
      test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
 | 
			
		||||
      interval: 30s
 | 
			
		||||
      timeout: 10s
 | 
			
		||||
      retries: 3
 | 
			
		||||
@@ -32,3 +32,18 @@ services:
 | 
			
		||||
    image: redis:latest
 | 
			
		||||
    container_name: redis
 | 
			
		||||
    restart: always
 | 
			
		||||
 | 
			
		||||
  db:
 | 
			
		||||
    image: mysql:8.2.0
 | 
			
		||||
    restart: always
 | 
			
		||||
    container_name: mysql
 | 
			
		||||
    volumes:
 | 
			
		||||
      - ./data/mysql:/var/lib/mysql  # 挂载目录,持久化存储
 | 
			
		||||
    ports:
 | 
			
		||||
      - '3306:3306'
 | 
			
		||||
    environment:
 | 
			
		||||
      TZ: Asia/Shanghai   # 设置时区
 | 
			
		||||
      MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
 | 
			
		||||
      MYSQL_USER: oneapi   # 创建专用用户
 | 
			
		||||
      MYSQL_PASSWORD: '123456'    # 设置专用用户密码
 | 
			
		||||
      MYSQL_DATABASE: one-api   # 自动创建数据库
 | 
			
		||||
							
								
								
									
										10
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								go.mod
									
									
									
									
									
								
							@@ -15,8 +15,9 @@ require (
 | 
			
		||||
	github.com/google/uuid v1.3.0
 | 
			
		||||
	github.com/gorilla/websocket v1.5.0
 | 
			
		||||
	github.com/pkoukk/tiktoken-go v0.1.5
 | 
			
		||||
	golang.org/x/crypto v0.9.0
 | 
			
		||||
	golang.org/x/crypto v0.14.0
 | 
			
		||||
	gorm.io/driver/mysql v1.4.3
 | 
			
		||||
	gorm.io/driver/postgres v1.5.2
 | 
			
		||||
	gorm.io/driver/sqlite v1.4.3
 | 
			
		||||
	gorm.io/gorm v1.25.0
 | 
			
		||||
)
 | 
			
		||||
@@ -52,10 +53,9 @@ require (
 | 
			
		||||
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 | 
			
		||||
	github.com/ugorji/go/codec v1.2.11 // indirect
 | 
			
		||||
	golang.org/x/arch v0.3.0 // indirect
 | 
			
		||||
	golang.org/x/net v0.10.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.8.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.9.0 // indirect
 | 
			
		||||
	golang.org/x/net v0.17.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.13.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.13.0 // indirect
 | 
			
		||||
	google.golang.org/protobuf v1.30.0 // indirect
 | 
			
		||||
	gopkg.in/yaml.v3 v3.0.1 // indirect
 | 
			
		||||
	gorm.io/driver/postgres v1.5.2 // indirect
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										17
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								go.sum
									
									
									
									
									
								
							@@ -150,11 +150,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
 | 
			
		||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
 | 
			
		||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 | 
			
		||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
 | 
			
		||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
 | 
			
		||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
 | 
			
		||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
 | 
			
		||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 | 
			
		||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
 | 
			
		||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 | 
			
		||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
 | 
			
		||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
 | 
			
		||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
@@ -162,14 +162,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
 | 
			
		||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
 | 
			
		||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
 | 
			
		||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 | 
			
		||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 | 
			
		||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 | 
			
		||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 | 
			
		||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
 | 
			
		||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
 | 
			
		||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
 | 
			
		||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 | 
			
		||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 | 
			
		||||
@@ -198,7 +198,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp
 | 
			
		||||
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
 | 
			
		||||
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
 | 
			
		||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
 | 
			
		||||
gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74=
 | 
			
		||||
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
 | 
			
		||||
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
 | 
			
		||||
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										29
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								main.go
									
									
									
									
									
								
							@@ -2,6 +2,7 @@ package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"embed"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-contrib/sessions/cookie"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
@@ -21,7 +22,7 @@ var buildFS embed.FS
 | 
			
		||||
var indexPage []byte
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	common.SetupGinLog()
 | 
			
		||||
	common.SetupLogger()
 | 
			
		||||
	common.SysLog("One API " + common.Version + " started")
 | 
			
		||||
	if os.Getenv("GIN_MODE") != "debug" {
 | 
			
		||||
		gin.SetMode(gin.ReleaseMode)
 | 
			
		||||
@@ -50,18 +51,17 @@ func main() {
 | 
			
		||||
	// Initialize options
 | 
			
		||||
	model.InitOptionMap()
 | 
			
		||||
	if common.RedisEnabled {
 | 
			
		||||
		// for compatibility with old versions
 | 
			
		||||
		common.MemoryCacheEnabled = true
 | 
			
		||||
	}
 | 
			
		||||
	if common.MemoryCacheEnabled {
 | 
			
		||||
		common.SysLog("memory cache enabled")
 | 
			
		||||
		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
 | 
			
		||||
		model.InitChannelCache()
 | 
			
		||||
	}
 | 
			
		||||
	if os.Getenv("SYNC_FREQUENCY") != "" {
 | 
			
		||||
		frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		common.SyncFrequency = frequency
 | 
			
		||||
		go model.SyncOptions(frequency)
 | 
			
		||||
		if common.RedisEnabled {
 | 
			
		||||
			go model.SyncChannelCache(frequency)
 | 
			
		||||
		}
 | 
			
		||||
	if common.MemoryCacheEnabled {
 | 
			
		||||
		go model.SyncOptions(common.SyncFrequency)
 | 
			
		||||
		go model.SyncChannelCache(common.SyncFrequency)
 | 
			
		||||
	}
 | 
			
		||||
	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
 | 
			
		||||
		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
 | 
			
		||||
@@ -85,11 +85,12 @@ func main() {
 | 
			
		||||
	controller.InitTokenEncoders()
 | 
			
		||||
 | 
			
		||||
	// Initialize HTTP server
 | 
			
		||||
	server := gin.Default()
 | 
			
		||||
	server := gin.New()
 | 
			
		||||
	server.Use(gin.Recovery())
 | 
			
		||||
	// This will cause SSE not to work!!!
 | 
			
		||||
	//server.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
			
		||||
	server.Use(middleware.CORS())
 | 
			
		||||
 | 
			
		||||
	server.Use(middleware.RequestId())
 | 
			
		||||
	middleware.SetUpLogger(server)
 | 
			
		||||
	// Initialize session store
 | 
			
		||||
	store := cookie.NewStore([]byte(common.SessionSecret))
 | 
			
		||||
	server.Use(sessions.Sessions("session", store))
 | 
			
		||||
 
 | 
			
		||||
@@ -91,56 +91,26 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
		key = parts[0]
 | 
			
		||||
		token, err := model.ValidateUserToken(key)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusUnauthorized, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
					"type":    "one_api_error",
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			abortWithMessage(c, http.StatusUnauthorized, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		userEnabled, err := model.IsUserEnabled(token.UserId)
 | 
			
		||||
		userEnabled, err := model.CacheIsUserEnabled(token.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
					"type":    "one_api_error",
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			abortWithMessage(c, http.StatusInternalServerError, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if !userEnabled {
 | 
			
		||||
			c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": "用户已被封禁",
 | 
			
		||||
					"type":    "one_api_error",
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("id", token.UserId)
 | 
			
		||||
		c.Set("token_id", token.Id)
 | 
			
		||||
		c.Set("token_name", token.Name)
 | 
			
		||||
		requestURL := c.Request.URL.String()
 | 
			
		||||
		consumeQuota := true
 | 
			
		||||
		if strings.HasPrefix(requestURL, "/v1/models") {
 | 
			
		||||
			consumeQuota = false
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("consume_quota", consumeQuota)
 | 
			
		||||
		if len(parts) > 1 {
 | 
			
		||||
			if model.IsAdmin(token.UserId) {
 | 
			
		||||
				c.Set("channelId", parts[1])
 | 
			
		||||
			} else {
 | 
			
		||||
				c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "普通用户不支持指定渠道",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -25,51 +25,24 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
		if ok {
 | 
			
		||||
			id, err := strconv.Atoi(channelId.(string))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "无效的渠道 ID",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			channel, err = model.GetChannelById(id, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "无效的渠道 ID",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
				c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "该渠道已被禁用",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			// Select a channel for the user
 | 
			
		||||
			var modelRequest ModelRequest
 | 
			
		||||
			var err error
 | 
			
		||||
			if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
			
		||||
				err = common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
			}
 | 
			
		||||
			err := common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "无效的请求",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的请求")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
			
		||||
@@ -84,10 +57,10 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					modelRequest.Model = "dall-e"
 | 
			
		||||
					modelRequest.Model = "dall-e-2"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					modelRequest.Model = "whisper-1"
 | 
			
		||||
				}
 | 
			
		||||
@@ -99,22 +72,16 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 | 
			
		||||
					message = "数据库一致性已被破坏,请联系管理员"
 | 
			
		||||
				}
 | 
			
		||||
				c.JSON(http.StatusServiceUnavailable, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": message,
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusServiceUnavailable, message)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("channel", channel.Type)
 | 
			
		||||
		c.Set("channel_id", channel.Id)
 | 
			
		||||
		c.Set("channel_name", channel.Name)
 | 
			
		||||
		c.Set("model_mapping", channel.ModelMapping)
 | 
			
		||||
		c.Set("model_mapping", channel.GetModelMapping())
 | 
			
		||||
		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
			
		||||
		c.Set("base_url", channel.BaseURL)
 | 
			
		||||
		c.Set("base_url", channel.GetBaseURL())
 | 
			
		||||
		switch channel.Type {
 | 
			
		||||
		case common.ChannelTypeAzure:
 | 
			
		||||
			c.Set("api_version", channel.Other)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetUpLogger(server *gin.Engine) {
 | 
			
		||||
	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
 | 
			
		||||
		var requestID string
 | 
			
		||||
		if param.Keys != nil {
 | 
			
		||||
			requestID = param.Keys[common.RequestIdKey].(string)
 | 
			
		||||
		}
 | 
			
		||||
		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
 | 
			
		||||
			param.TimeStamp.Format("2006/01/02 - 15:04:05"),
 | 
			
		||||
			requestID,
 | 
			
		||||
			param.StatusCode,
 | 
			
		||||
			param.Latency,
 | 
			
		||||
			param.ClientIP,
 | 
			
		||||
			param.Method,
 | 
			
		||||
			param.Path,
 | 
			
		||||
		)
 | 
			
		||||
	}))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func RequestId() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		id := common.GetTimeString() + common.GetRandomString(8)
 | 
			
		||||
		c.Set(common.RequestIdKey, id)
 | 
			
		||||
		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
 | 
			
		||||
		c.Request = c.Request.WithContext(ctx)
 | 
			
		||||
		c.Header(common.RequestIdKey, id)
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
 | 
			
		||||
	c.JSON(statusCode, gin.H{
 | 
			
		||||
		"error": gin.H{
 | 
			
		||||
			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
 | 
			
		||||
			"type":    "one_api_error",
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	c.Abort()
 | 
			
		||||
	common.LogError(c.Request.Context(), message)
 | 
			
		||||
}
 | 
			
		||||
@@ -10,15 +10,25 @@ type Ability struct {
 | 
			
		||||
	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"`
 | 
			
		||||
	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
 | 
			
		||||
	Enabled   bool   `json:"enabled"`
 | 
			
		||||
	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 | 
			
		||||
	ability := Ability{}
 | 
			
		||||
	groupCol := "`group`"
 | 
			
		||||
	trueVal := "1"
 | 
			
		||||
	if common.UsingPostgreSQL {
 | 
			
		||||
		groupCol = `"group"`
 | 
			
		||||
		trueVal = "true"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var err error = nil
 | 
			
		||||
	if common.UsingSQLite {
 | 
			
		||||
		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
 | 
			
		||||
	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
 | 
			
		||||
	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
 | 
			
		||||
	if common.UsingSQLite || common.UsingPostgreSQL {
 | 
			
		||||
		err = channelQuery.Order("RANDOM()").First(&ability).Error
 | 
			
		||||
	} else {
 | 
			
		||||
		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
 | 
			
		||||
		err = channelQuery.Order("RAND()").First(&ability).Error
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -40,6 +50,7 @@ func (channel *Channel) AddAbilities() error {
 | 
			
		||||
				Model:     model,
 | 
			
		||||
				ChannelId: channel.Id,
 | 
			
		||||
				Enabled:   channel.Status == common.ChannelStatusEnabled,
 | 
			
		||||
				Priority:  channel.Priority,
 | 
			
		||||
			}
 | 
			
		||||
			abilities = append(abilities, ability)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
@@ -20,14 +21,18 @@ var (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func CacheGetTokenByKey(key string) (*Token, error) {
 | 
			
		||||
	keyCol := "`key`"
 | 
			
		||||
	if common.UsingPostgreSQL {
 | 
			
		||||
		keyCol = `"key"`
 | 
			
		||||
	}
 | 
			
		||||
	var token Token
 | 
			
		||||
	if !common.RedisEnabled {
 | 
			
		||||
		err := DB.Where("`key` = ?", key).First(&token).Error
 | 
			
		||||
		err := DB.Where(keyCol+" = ?", key).First(&token).Error
 | 
			
		||||
		return &token, err
 | 
			
		||||
	}
 | 
			
		||||
	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err := DB.Where("`key` = ?", key).First(&token).Error
 | 
			
		||||
		err := DB.Where(keyCol+" = ?", key).First(&token).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
@@ -159,6 +164,17 @@ func InitChannelCache() {
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// sort by priority
 | 
			
		||||
	for group, model2channels := range newGroup2model2channels {
 | 
			
		||||
		for model, channels := range model2channels {
 | 
			
		||||
			sort.Slice(channels, func(i, j int) bool {
 | 
			
		||||
				return channels[i].GetPriority() > channels[j].GetPriority()
 | 
			
		||||
			})
 | 
			
		||||
			newGroup2model2channels[group][model] = channels
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	channelSyncLock.Lock()
 | 
			
		||||
	group2model2channels = newGroup2model2channels
 | 
			
		||||
	channelSyncLock.Unlock()
 | 
			
		||||
@@ -174,7 +190,7 @@ func SyncChannelCache(frequency int) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 | 
			
		||||
	if !common.RedisEnabled {
 | 
			
		||||
	if !common.MemoryCacheEnabled {
 | 
			
		||||
		return GetRandomSatisfiedChannel(group, model)
 | 
			
		||||
	}
 | 
			
		||||
	channelSyncLock.RLock()
 | 
			
		||||
@@ -183,6 +199,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 | 
			
		||||
	if len(channels) == 0 {
 | 
			
		||||
		return nil, errors.New("channel not found")
 | 
			
		||||
	}
 | 
			
		||||
	idx := rand.Intn(len(channels))
 | 
			
		||||
	endIdx := len(channels)
 | 
			
		||||
	// choose by priority
 | 
			
		||||
	firstChannel := channels[0]
 | 
			
		||||
	if firstChannel.GetPriority() > 0 {
 | 
			
		||||
		for i := range channels {
 | 
			
		||||
			if channels[i].GetPriority() != firstChannel.GetPriority() {
 | 
			
		||||
				endIdx = i
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	idx := rand.Intn(endIdx)
 | 
			
		||||
	return channels[idx], nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -11,18 +11,19 @@ type Channel struct {
 | 
			
		||||
	Key                string  `json:"key" gorm:"not null;index"`
 | 
			
		||||
	Status             int     `json:"status" gorm:"default:1"`
 | 
			
		||||
	Name               string  `json:"name" gorm:"index"`
 | 
			
		||||
	Weight             int     `json:"weight"`
 | 
			
		||||
	Weight             *uint   `json:"weight" gorm:"default:0"`
 | 
			
		||||
	CreatedTime        int64   `json:"created_time" gorm:"bigint"`
 | 
			
		||||
	TestTime           int64   `json:"test_time" gorm:"bigint"`
 | 
			
		||||
	ResponseTime       int     `json:"response_time"` // in milliseconds
 | 
			
		||||
	BaseURL            string  `json:"base_url" gorm:"column:base_url"`
 | 
			
		||||
	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"`
 | 
			
		||||
	Other              string  `json:"other"`
 | 
			
		||||
	Balance            float64 `json:"balance"` // in USD
 | 
			
		||||
	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
 | 
			
		||||
	Models             string  `json:"models"`
 | 
			
		||||
	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
 | 
			
		||||
	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
 | 
			
		||||
	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 | 
			
		||||
	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 | 
			
		||||
	Priority           *int64  `json:"priority" gorm:"bigint;default:0"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 | 
			
		||||
@@ -37,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchChannels(keyword string) (channels []*Channel, err error) {
 | 
			
		||||
	err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
 | 
			
		||||
	keyCol := "`key`"
 | 
			
		||||
	if common.UsingPostgreSQL {
 | 
			
		||||
		keyCol = `"key"`
 | 
			
		||||
	}
 | 
			
		||||
	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
 | 
			
		||||
	return channels, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -52,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
 | 
			
		||||
	return &channel, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetRandomChannel() (*Channel, error) {
 | 
			
		||||
	channel := Channel{}
 | 
			
		||||
	var err error = nil
 | 
			
		||||
	if common.UsingSQLite {
 | 
			
		||||
		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
 | 
			
		||||
	} else {
 | 
			
		||||
		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
 | 
			
		||||
	}
 | 
			
		||||
	return &channel, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BatchInsertChannels(channels []Channel) error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Create(&channels).Error
 | 
			
		||||
@@ -78,6 +72,27 @@ func BatchInsertChannels(channels []Channel) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) GetPriority() int64 {
 | 
			
		||||
	if channel.Priority == nil {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	return *channel.Priority
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) GetBaseURL() string {
 | 
			
		||||
	if channel.BaseURL == nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	return *channel.BaseURL
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) GetModelMapping() string {
 | 
			
		||||
	if channel.ModelMapping == nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	return *channel.ModelMapping
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) Insert() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Create(channel).Error
 | 
			
		||||
@@ -154,3 +169,13 @@ func updateChannelUsedQuota(id int, quota int) {
 | 
			
		||||
		common.SysError("failed to update channel used quota: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteChannelByStatus(status int64) (int64, error) {
 | 
			
		||||
	result := DB.Where("status = ?", status).Delete(&Channel{})
 | 
			
		||||
	return result.RowsAffected, result.Error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteDisabledChannel() (int64, error) {
 | 
			
		||||
	result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
 | 
			
		||||
	return result.RowsAffected, result.Error
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										40
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										40
									
								
								model/log.go
									
									
									
									
									
								
							@@ -1,22 +1,25 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Log struct {
 | 
			
		||||
	Id               int    `json:"id"`
 | 
			
		||||
	UserId           int    `json:"user_id"`
 | 
			
		||||
	CreatedAt        int64  `json:"created_at" gorm:"bigint;index"`
 | 
			
		||||
	Type             int    `json:"type" gorm:"index"`
 | 
			
		||||
	Id               int    `json:"id;index:idx_created_at_id,priority:1"`
 | 
			
		||||
	UserId           int    `json:"user_id" gorm:"index"`
 | 
			
		||||
	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
 | 
			
		||||
	Type             int    `json:"type" gorm:"index:idx_created_at_type"`
 | 
			
		||||
	Content          string `json:"content"`
 | 
			
		||||
	Username         string `json:"username" gorm:"index;default:''"`
 | 
			
		||||
	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
 | 
			
		||||
	TokenName        string `json:"token_name" gorm:"index;default:''"`
 | 
			
		||||
	ModelName        string `json:"model_name" gorm:"index;default:''"`
 | 
			
		||||
	ModelName        string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
 | 
			
		||||
	Quota            int    `json:"quota" gorm:"default:0"`
 | 
			
		||||
	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"`
 | 
			
		||||
	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"`
 | 
			
		||||
	ChannelId        int    `json:"channel" gorm:"index"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -44,7 +47,8 @@ func RecordLog(userId int, logType int, content string) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
 | 
			
		||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
 | 
			
		||||
	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 | 
			
		||||
	if !common.LogConsumeEnabled {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -59,14 +63,15 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
 | 
			
		||||
		TokenName:        tokenName,
 | 
			
		||||
		ModelName:        modelName,
 | 
			
		||||
		Quota:            quota,
 | 
			
		||||
		ChannelId:        channelId,
 | 
			
		||||
	}
 | 
			
		||||
	err := DB.Create(log).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to record log: " + err.Error())
 | 
			
		||||
		common.LogError(ctx, "failed to record log: "+err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
 | 
			
		||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
 | 
			
		||||
	var tx *gorm.DB
 | 
			
		||||
	if logType == LogTypeUnknown {
 | 
			
		||||
		tx = DB
 | 
			
		||||
@@ -88,6 +93,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 | 
			
		||||
	if endTimestamp != 0 {
 | 
			
		||||
		tx = tx.Where("created_at <= ?", endTimestamp)
 | 
			
		||||
	}
 | 
			
		||||
	if channel != 0 {
 | 
			
		||||
		tx = tx.Where("channel_id = ?", channel)
 | 
			
		||||
	}
 | 
			
		||||
	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
 | 
			
		||||
	return logs, err
 | 
			
		||||
}
 | 
			
		||||
@@ -125,8 +133,8 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
 | 
			
		||||
	return logs, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
 | 
			
		||||
	tx := DB.Table("logs").Select("sum(quota)")
 | 
			
		||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
 | 
			
		||||
	tx := DB.Table("logs").Select("ifnull(sum(quota),0)")
 | 
			
		||||
	if username != "" {
 | 
			
		||||
		tx = tx.Where("username = ?", username)
 | 
			
		||||
	}
 | 
			
		||||
@@ -142,12 +150,15 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 | 
			
		||||
	if modelName != "" {
 | 
			
		||||
		tx = tx.Where("model_name = ?", modelName)
 | 
			
		||||
	}
 | 
			
		||||
	if channel != 0 {
 | 
			
		||||
		tx = tx.Where("channel_id = ?", channel)
 | 
			
		||||
	}
 | 
			
		||||
	tx.Where("type = ?", LogTypeConsume).Scan("a)
 | 
			
		||||
	return quota
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
 | 
			
		||||
	tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)")
 | 
			
		||||
	tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
 | 
			
		||||
	if username != "" {
 | 
			
		||||
		tx = tx.Where("username = ?", username)
 | 
			
		||||
	}
 | 
			
		||||
@@ -166,3 +177,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
 | 
			
		||||
	tx.Where("type = ?", LogTypeConsume).Scan(&token)
 | 
			
		||||
	return token
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteOldLog(targetTimestamp int64) (int64, error) {
 | 
			
		||||
	result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
 | 
			
		||||
	return result.RowsAffected, result.Error
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) {
 | 
			
		||||
		if strings.HasPrefix(dsn, "postgres://") {
 | 
			
		||||
			// Use PostgreSQL
 | 
			
		||||
			common.SysLog("using PostgreSQL as database")
 | 
			
		||||
			common.UsingPostgreSQL = true
 | 
			
		||||
			return gorm.Open(postgres.New(postgres.Config{
 | 
			
		||||
				DSN:                  dsn,
 | 
			
		||||
				PreferSimpleProtocol: true, // disables implicit prepared statement usage
 | 
			
		||||
@@ -81,6 +82,7 @@ func InitDB() (err error) {
 | 
			
		||||
		if !common.IsMasterNode {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		common.SysLog("database migration started")
 | 
			
		||||
		err = db.AutoMigrate(&Channel{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
 
 | 
			
		||||
@@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
 | 
			
		||||
	}
 | 
			
		||||
	redemption := &Redemption{}
 | 
			
		||||
 | 
			
		||||
	keyCol := "`key`"
 | 
			
		||||
	if common.UsingPostgreSQL {
 | 
			
		||||
		keyCol = `"key"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = DB.Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
 | 
			
		||||
		err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errors.New("无效的兑换码")
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserGroup(id int) (group string, err error) {
 | 
			
		||||
	err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
 | 
			
		||||
	groupCol := "`group`"
 | 
			
		||||
	if common.UsingPostgreSQL {
 | 
			
		||||
		groupCol = `"group"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
 | 
			
		||||
	return group, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -309,7 +314,8 @@ func GetRootUserEmail() (email string) {
 | 
			
		||||
 | 
			
		||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
 | 
			
		||||
		addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
 | 
			
		||||
		addNewRecord(BatchUpdateTypeRequestCount, id, 1)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	updateUserUsedQuotaAndRequestCount(id, quota, 1)
 | 
			
		||||
@@ -327,6 +333,24 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateUserUsedQuota(id int, quota int) {
 | 
			
		||||
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
 | 
			
		||||
		map[string]interface{}{
 | 
			
		||||
			"used_quota": gorm.Expr("used_quota + ?", quota),
 | 
			
		||||
		},
 | 
			
		||||
	).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to update user used quota: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateUserRequestCount(id int, count int) {
 | 
			
		||||
	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to update user request count: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUsernameById(id int) (username string) {
 | 
			
		||||
	DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
 | 
			
		||||
	return username
 | 
			
		||||
 
 | 
			
		||||
@@ -6,13 +6,13 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	BatchUpdateTypeUserQuota = iota
 | 
			
		||||
	BatchUpdateTypeTokenQuota
 | 
			
		||||
	BatchUpdateTypeUsedQuotaAndRequestCount
 | 
			
		||||
	BatchUpdateTypeUsedQuota
 | 
			
		||||
	BatchUpdateTypeChannelUsedQuota
 | 
			
		||||
	BatchUpdateTypeRequestCount
 | 
			
		||||
	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var batchUpdateStores []map[int]int
 | 
			
		||||
@@ -51,7 +51,7 @@ func batchUpdate() {
 | 
			
		||||
		store := batchUpdateStores[i]
 | 
			
		||||
		batchUpdateStores[i] = make(map[int]int)
 | 
			
		||||
		batchUpdateLocks[i].Unlock()
 | 
			
		||||
 | 
			
		||||
		// TODO: maybe we can combine updates with same key?
 | 
			
		||||
		for key, value := range store {
 | 
			
		||||
			switch i {
 | 
			
		||||
			case BatchUpdateTypeUserQuota:
 | 
			
		||||
@@ -64,8 +64,10 @@ func batchUpdate() {
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("failed to batch update token quota: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			case BatchUpdateTypeUsedQuotaAndRequestCount:
 | 
			
		||||
				updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
 | 
			
		||||
			case BatchUpdateTypeUsedQuota:
 | 
			
		||||
				updateUserUsedQuota(key, value)
 | 
			
		||||
			case BatchUpdateTypeRequestCount:
 | 
			
		||||
				updateUserRequestCount(key, value)
 | 
			
		||||
			case BatchUpdateTypeChannelUsedQuota:
 | 
			
		||||
				updateChannelUsedQuota(key, value)
 | 
			
		||||
			}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										3
									
								
								pull_request_template.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								pull_request_template.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
close #issue_number
 | 
			
		||||
 | 
			
		||||
我已确认该 PR 已自测通过,相关截图如下:
 | 
			
		||||
@@ -74,6 +74,7 @@ func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
 | 
			
		||||
			channelRoute.POST("/", controller.AddChannel)
 | 
			
		||||
			channelRoute.PUT("/", controller.UpdateChannel)
 | 
			
		||||
			channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
 | 
			
		||||
			channelRoute.DELETE("/:id", controller.DeleteChannel)
 | 
			
		||||
		}
 | 
			
		||||
		tokenRoute := apiRouter.Group("/token")
 | 
			
		||||
@@ -98,6 +99,7 @@ func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
		}
 | 
			
		||||
		logRoute := apiRouter.Group("/log")
 | 
			
		||||
		logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
 | 
			
		||||
		logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
 | 
			
		||||
		logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
 | 
			
		||||
		logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
 | 
			
		||||
		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
	router.Use(middleware.CORS())
 | 
			
		||||
	// https://platform.openai.com/docs/api-reference/introduction
 | 
			
		||||
	modelsRouter := router.Group("/v1/models")
 | 
			
		||||
	modelsRouter.Use(middleware.TokenAuth())
 | 
			
		||||
@@ -28,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/audio/transcriptions", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/audio/translations", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/audio/speech", controller.Relay)
 | 
			
		||||
		relayV1Router.GET("/files", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/files", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
 | 
			
		||||
 
 | 
			
		||||
@@ -283,7 +283,9 @@ function App() {
 | 
			
		||||
          </Suspense>
 | 
			
		||||
        }
 | 
			
		||||
      />
 | 
			
		||||
      <Route path='*' element={NotFound} />
 | 
			
		||||
      <Route path='*' element={
 | 
			
		||||
          <NotFound />
 | 
			
		||||
      } />
 | 
			
		||||
    </Routes>
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
 | 
			
		||||
import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
 | 
			
		||||
import { Link } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
 | 
			
		||||
import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
 | 
			
		||||
 | 
			
		||||
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
 | 
			
		||||
import { renderGroup, renderNumber } from '../helpers/render';
 | 
			
		||||
@@ -24,7 +24,7 @@ function renderType(type) {
 | 
			
		||||
    }
 | 
			
		||||
    type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
 | 
			
		||||
  }
 | 
			
		||||
  return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
 | 
			
		||||
  return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
function renderBalance(type, balance) {
 | 
			
		||||
@@ -55,6 +55,7 @@ const ChannelsTable = () => {
 | 
			
		||||
  const [searchKeyword, setSearchKeyword] = useState('');
 | 
			
		||||
  const [searching, setSearching] = useState(false);
 | 
			
		||||
  const [updatingBalance, setUpdatingBalance] = useState(false);
 | 
			
		||||
  const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test"));
 | 
			
		||||
 | 
			
		||||
  const loadChannels = async (startIdx) => {
 | 
			
		||||
    const res = await API.get(`/api/channel/?p=${startIdx}`);
 | 
			
		||||
@@ -96,7 +97,7 @@ const ChannelsTable = () => {
 | 
			
		||||
      });
 | 
			
		||||
  }, []);
 | 
			
		||||
 | 
			
		||||
  const manageChannel = async (id, action, idx) => {
 | 
			
		||||
  const manageChannel = async (id, action, idx, value) => {
 | 
			
		||||
    let data = { id };
 | 
			
		||||
    let res;
 | 
			
		||||
    switch (action) {
 | 
			
		||||
@@ -111,6 +112,23 @@ const ChannelsTable = () => {
 | 
			
		||||
        data.status = 2;
 | 
			
		||||
        res = await API.put('/api/channel/', data);
 | 
			
		||||
        break;
 | 
			
		||||
      case 'priority':
 | 
			
		||||
        if (value === '') {
 | 
			
		||||
          return;
 | 
			
		||||
        }
 | 
			
		||||
        data.priority = parseInt(value);
 | 
			
		||||
        res = await API.put('/api/channel/', data);
 | 
			
		||||
        break;
 | 
			
		||||
      case 'weight':
 | 
			
		||||
        if (value === '') {
 | 
			
		||||
          return;
 | 
			
		||||
        }
 | 
			
		||||
        data.weight = parseInt(value);
 | 
			
		||||
        if (data.weight < 0) {
 | 
			
		||||
          data.weight = 0;
 | 
			
		||||
        }
 | 
			
		||||
        res = await API.put('/api/channel/', data);
 | 
			
		||||
        break;
 | 
			
		||||
    }
 | 
			
		||||
    const { success, message } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
@@ -135,9 +153,23 @@ const ChannelsTable = () => {
 | 
			
		||||
        return <Label basic color='green'>已启用</Label>;
 | 
			
		||||
      case 2:
 | 
			
		||||
        return (
 | 
			
		||||
          <Label basic color='red'>
 | 
			
		||||
          <Popup
 | 
			
		||||
            trigger={<Label basic color='red'>
 | 
			
		||||
              已禁用
 | 
			
		||||
          </Label>
 | 
			
		||||
            </Label>}
 | 
			
		||||
            content='本渠道被手动禁用'
 | 
			
		||||
            basic
 | 
			
		||||
          />
 | 
			
		||||
        );
 | 
			
		||||
      case 3:
 | 
			
		||||
        return (
 | 
			
		||||
          <Popup
 | 
			
		||||
            trigger={<Label basic color='yellow'>
 | 
			
		||||
              已禁用
 | 
			
		||||
            </Label>}
 | 
			
		||||
            content='本渠道被程序自动禁用'
 | 
			
		||||
            basic
 | 
			
		||||
          />
 | 
			
		||||
        );
 | 
			
		||||
      default:
 | 
			
		||||
        return (
 | 
			
		||||
@@ -208,6 +240,17 @@ const ChannelsTable = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const deleteAllDisabledChannels = async () => {
 | 
			
		||||
    const res = await API.delete(`/api/channel/disabled`);
 | 
			
		||||
    const { success, message, data } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      showSuccess(`已删除所有禁用渠道,共计 ${data} 个`);
 | 
			
		||||
      await refresh();
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const updateChannelBalance = async (id, name, idx) => {
 | 
			
		||||
    const res = await API.get(`/api/channel/update_balance/${id}/`);
 | 
			
		||||
    const { success, message, balance } = res.data;
 | 
			
		||||
@@ -243,17 +286,15 @@ const ChannelsTable = () => {
 | 
			
		||||
    if (channels.length === 0) return;
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
    let sortedChannels = [...channels];
 | 
			
		||||
    if (typeof sortedChannels[0][key] === 'string') {
 | 
			
		||||
    sortedChannels.sort((a, b) => {
 | 
			
		||||
        return ('' + a[key]).localeCompare(b[key]);
 | 
			
		||||
      });
 | 
			
		||||
      if (!isNaN(a[key])) {
 | 
			
		||||
        // If the value is numeric, subtract to sort
 | 
			
		||||
        return a[key] - b[key];
 | 
			
		||||
      } else {
 | 
			
		||||
      sortedChannels.sort((a, b) => {
 | 
			
		||||
        if (a[key] === b[key]) return 0;
 | 
			
		||||
        if (a[key] > b[key]) return -1;
 | 
			
		||||
        if (a[key] < b[key]) return 1;
 | 
			
		||||
      });
 | 
			
		||||
        // If the value is not numeric, sort as strings
 | 
			
		||||
        return ('' + a[key]).localeCompare(b[key]);
 | 
			
		||||
      }
 | 
			
		||||
    });
 | 
			
		||||
    if (sortedChannels[0].id === channels[0].id) {
 | 
			
		||||
      sortedChannels.reverse();
 | 
			
		||||
    }
 | 
			
		||||
@@ -261,6 +302,7 @@ const ChannelsTable = () => {
 | 
			
		||||
    setLoading(false);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <>
 | 
			
		||||
      <Form onSubmit={searchChannels}>
 | 
			
		||||
@@ -274,7 +316,19 @@ const ChannelsTable = () => {
 | 
			
		||||
          onChange={handleKeywordChange}
 | 
			
		||||
        />
 | 
			
		||||
      </Form>
 | 
			
		||||
      {
 | 
			
		||||
        showPrompt && (
 | 
			
		||||
          <Message onDismiss={() => {
 | 
			
		||||
            setShowPrompt(false);
 | 
			
		||||
            setPromptShown("channel-test");
 | 
			
		||||
          }}>
 | 
			
		||||
            当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo
 | 
			
		||||
            模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。
 | 
			
		||||
 | 
			
		||||
            另外,OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
 | 
			
		||||
          </Message>
 | 
			
		||||
        )
 | 
			
		||||
      }
 | 
			
		||||
      <Table basic compact size='small'>
 | 
			
		||||
        <Table.Header>
 | 
			
		||||
          <Table.Row>
 | 
			
		||||
@@ -334,6 +388,14 @@ const ChannelsTable = () => {
 | 
			
		||||
            >
 | 
			
		||||
              余额
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
                sortChannel('priority');
 | 
			
		||||
              }}
 | 
			
		||||
            >
 | 
			
		||||
              优先级
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell>操作</Table.HeaderCell>
 | 
			
		||||
          </Table.Row>
 | 
			
		||||
        </Table.Header>
 | 
			
		||||
@@ -372,6 +434,22 @@ const ChannelsTable = () => {
 | 
			
		||||
                      basic
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <Popup
 | 
			
		||||
                      trigger={<Input type='number' defaultValue={channel.priority} onBlur={(event) => {
 | 
			
		||||
                        manageChannel(
 | 
			
		||||
                          channel.id,
 | 
			
		||||
                          'priority',
 | 
			
		||||
                          idx,
 | 
			
		||||
                          event.target.value
 | 
			
		||||
                        );
 | 
			
		||||
                      }}>
 | 
			
		||||
                        <input style={{ maxWidth: '60px' }} />
 | 
			
		||||
                      </Input>}
 | 
			
		||||
                      content='渠道选择优先级,越高越优先'
 | 
			
		||||
                      basic
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <div>
 | 
			
		||||
                      <Button
 | 
			
		||||
@@ -440,7 +518,7 @@ const ChannelsTable = () => {
 | 
			
		||||
 | 
			
		||||
        <Table.Footer>
 | 
			
		||||
          <Table.Row>
 | 
			
		||||
            <Table.HeaderCell colSpan='8'>
 | 
			
		||||
            <Table.HeaderCell colSpan='9'>
 | 
			
		||||
              <Button size='small' as={Link} to='/channel/add' loading={loading}>
 | 
			
		||||
                添加新的渠道
 | 
			
		||||
              </Button>
 | 
			
		||||
@@ -449,6 +527,20 @@ const ChannelsTable = () => {
 | 
			
		||||
              </Button>
 | 
			
		||||
              <Button size='small' onClick={updateAllChannelsBalance}
 | 
			
		||||
                      loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
 | 
			
		||||
              <Popup
 | 
			
		||||
                trigger={
 | 
			
		||||
                  <Button size='small' loading={loading}>
 | 
			
		||||
                    删除禁用渠道
 | 
			
		||||
                  </Button>
 | 
			
		||||
                }
 | 
			
		||||
                on='click'
 | 
			
		||||
                flowing
 | 
			
		||||
                hoverable
 | 
			
		||||
              >
 | 
			
		||||
                <Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}>
 | 
			
		||||
                  确认删除
 | 
			
		||||
                </Button>
 | 
			
		||||
              </Popup>
 | 
			
		||||
              <Pagination
 | 
			
		||||
                floated='right'
 | 
			
		||||
                activePage={activePage}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,8 +2,8 @@ import React, { useContext, useEffect, useState } from 'react';
 | 
			
		||||
import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react';
 | 
			
		||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
 | 
			
		||||
import { UserContext } from '../context/User';
 | 
			
		||||
import { API, getLogo, showError, showSuccess } from '../helpers';
 | 
			
		||||
import { getOAuthState, onGitHubOAuthClicked } from './utils';
 | 
			
		||||
import { API, getLogo, showError, showSuccess, showWarning } from '../helpers';
 | 
			
		||||
import { onGitHubOAuthClicked } from './utils';
 | 
			
		||||
 | 
			
		||||
const LoginForm = () => {
 | 
			
		||||
  const [inputs, setInputs] = useState({
 | 
			
		||||
@@ -68,8 +68,14 @@ const LoginForm = () => {
 | 
			
		||||
      if (success) {
 | 
			
		||||
        userDispatch({ type: 'login', payload: data });
 | 
			
		||||
        localStorage.setItem('user', JSON.stringify(data));
 | 
			
		||||
        navigate('/');
 | 
			
		||||
        if (username === 'root' && password === '123456') {
 | 
			
		||||
          navigate('/user/edit');
 | 
			
		||||
          showSuccess('登录成功!');
 | 
			
		||||
          showWarning('请立刻修改默认密码!');
 | 
			
		||||
        } else {
 | 
			
		||||
          navigate('/token');
 | 
			
		||||
          showSuccess('登录成功!');
 | 
			
		||||
        }
 | 
			
		||||
      } else {
 | 
			
		||||
        showError(message);
 | 
			
		||||
      }
 | 
			
		||||
 
 | 
			
		||||
@@ -56,9 +56,10 @@ const LogsTable = () => {
 | 
			
		||||
    token_name: '',
 | 
			
		||||
    model_name: '',
 | 
			
		||||
    start_timestamp: timestamp2string(0),
 | 
			
		||||
    end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
 | 
			
		||||
    end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
 | 
			
		||||
    channel: ''
 | 
			
		||||
  });
 | 
			
		||||
  const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
 | 
			
		||||
  const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
 | 
			
		||||
 | 
			
		||||
  const [stat, setStat] = useState({
 | 
			
		||||
    quota: 0,
 | 
			
		||||
@@ -84,7 +85,7 @@ const LogsTable = () => {
 | 
			
		||||
  const getLogStat = async () => {
 | 
			
		||||
    let localStartTimestamp = Date.parse(start_timestamp) / 1000;
 | 
			
		||||
    let localEndTimestamp = Date.parse(end_timestamp) / 1000;
 | 
			
		||||
    let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
 | 
			
		||||
    let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
 | 
			
		||||
    const { success, message, data } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      setStat(data);
 | 
			
		||||
@@ -109,7 +110,7 @@ const LogsTable = () => {
 | 
			
		||||
    let localStartTimestamp = Date.parse(start_timestamp) / 1000;
 | 
			
		||||
    let localEndTimestamp = Date.parse(end_timestamp) / 1000;
 | 
			
		||||
    if (isAdminUser) {
 | 
			
		||||
      url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
 | 
			
		||||
      url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
 | 
			
		||||
    } else {
 | 
			
		||||
      url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
 | 
			
		||||
    }
 | 
			
		||||
@@ -205,16 +206,9 @@ const LogsTable = () => {
 | 
			
		||||
        </Header>
 | 
			
		||||
        <Form>
 | 
			
		||||
          <Form.Group>
 | 
			
		||||
            {
 | 
			
		||||
              isAdminUser && (
 | 
			
		||||
                <Form.Input fluid label={'用户名称'} width={2} value={username}
 | 
			
		||||
                            placeholder={'可选值'} name='username'
 | 
			
		||||
                            onChange={handleInputChange} />
 | 
			
		||||
              )
 | 
			
		||||
            }
 | 
			
		||||
            <Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name}
 | 
			
		||||
            <Form.Input fluid label={'令牌名称'} width={3} value={token_name}
 | 
			
		||||
                        placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
 | 
			
		||||
            <Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值'
 | 
			
		||||
            <Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
 | 
			
		||||
                        name='model_name'
 | 
			
		||||
                        onChange={handleInputChange} />
 | 
			
		||||
            <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
 | 
			
		||||
@@ -225,6 +219,19 @@ const LogsTable = () => {
 | 
			
		||||
                        onChange={handleInputChange} />
 | 
			
		||||
            <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          {
 | 
			
		||||
            isAdminUser && <>
 | 
			
		||||
              <Form.Group>
 | 
			
		||||
                <Form.Input fluid label={'渠道 ID'} width={3} value={channel}
 | 
			
		||||
                            placeholder='可选值' name='channel'
 | 
			
		||||
                            onChange={handleInputChange} />
 | 
			
		||||
                <Form.Input fluid label={'用户名称'} width={3} value={username}
 | 
			
		||||
                            placeholder={'可选值'} name='username'
 | 
			
		||||
                            onChange={handleInputChange} />
 | 
			
		||||
 | 
			
		||||
              </Form.Group>
 | 
			
		||||
            </>
 | 
			
		||||
          }
 | 
			
		||||
        </Form>
 | 
			
		||||
        <Table basic compact size='small'>
 | 
			
		||||
          <Table.Header>
 | 
			
		||||
@@ -238,6 +245,17 @@ const LogsTable = () => {
 | 
			
		||||
              >
 | 
			
		||||
                时间
 | 
			
		||||
              </Table.HeaderCell>
 | 
			
		||||
              {
 | 
			
		||||
                isAdminUser && <Table.HeaderCell
 | 
			
		||||
                  style={{ cursor: 'pointer' }}
 | 
			
		||||
                  onClick={() => {
 | 
			
		||||
                    sortLog('channel');
 | 
			
		||||
                  }}
 | 
			
		||||
                  width={1}
 | 
			
		||||
                >
 | 
			
		||||
                  渠道
 | 
			
		||||
                </Table.HeaderCell>
 | 
			
		||||
              }
 | 
			
		||||
              {
 | 
			
		||||
                isAdminUser && <Table.HeaderCell
 | 
			
		||||
                  style={{ cursor: 'pointer' }}
 | 
			
		||||
@@ -299,16 +317,16 @@ const LogsTable = () => {
 | 
			
		||||
                onClick={() => {
 | 
			
		||||
                  sortLog('quota');
 | 
			
		||||
                }}
 | 
			
		||||
                width={2}
 | 
			
		||||
                width={1}
 | 
			
		||||
              >
 | 
			
		||||
                消耗额度
 | 
			
		||||
                额度
 | 
			
		||||
              </Table.HeaderCell>
 | 
			
		||||
              <Table.HeaderCell
 | 
			
		||||
                style={{ cursor: 'pointer' }}
 | 
			
		||||
                onClick={() => {
 | 
			
		||||
                  sortLog('content');
 | 
			
		||||
                }}
 | 
			
		||||
                width={isAdminUser ? 4 : 5}
 | 
			
		||||
                width={isAdminUser ? 4 : 6}
 | 
			
		||||
              >
 | 
			
		||||
                详情
 | 
			
		||||
              </Table.HeaderCell>
 | 
			
		||||
@@ -326,6 +344,11 @@ const LogsTable = () => {
 | 
			
		||||
                return (
 | 
			
		||||
                  <Table.Row key={log.id}>
 | 
			
		||||
                    <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
 | 
			
		||||
                    {
 | 
			
		||||
                      isAdminUser && (
 | 
			
		||||
                        <Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
 | 
			
		||||
                      )
 | 
			
		||||
                    }
 | 
			
		||||
                    {
 | 
			
		||||
                      isAdminUser && (
 | 
			
		||||
                        <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
 | 
			
		||||
@@ -345,7 +368,7 @@ const LogsTable = () => {
 | 
			
		||||
 | 
			
		||||
          <Table.Footer>
 | 
			
		||||
            <Table.Row>
 | 
			
		||||
              <Table.HeaderCell colSpan={'9'}>
 | 
			
		||||
              <Table.HeaderCell colSpan={'10'}>
 | 
			
		||||
                <Select
 | 
			
		||||
                  placeholder='选择明细分类'
 | 
			
		||||
                  options={LOG_OPTIONS}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,9 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
 | 
			
		||||
import { API, showError, verifyJSON } from '../helpers';
 | 
			
		||||
import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers';
 | 
			
		||||
 | 
			
		||||
const OperationSetting = () => {
 | 
			
		||||
  let now = new Date();
 | 
			
		||||
  let [inputs, setInputs] = useState({
 | 
			
		||||
    QuotaForNewUser: 0,
 | 
			
		||||
    QuotaForInviter: 0,
 | 
			
		||||
@@ -20,10 +21,11 @@ const OperationSetting = () => {
 | 
			
		||||
    DisplayInCurrencyEnabled: '',
 | 
			
		||||
    DisplayTokenStatEnabled: '',
 | 
			
		||||
    ApproximateTokenEnabled: '',
 | 
			
		||||
    RetryTimes: 0,
 | 
			
		||||
    RetryTimes: 0
 | 
			
		||||
  });
 | 
			
		||||
  const [originInputs, setOriginInputs] = useState({});
 | 
			
		||||
  let [loading, setLoading] = useState(false);
 | 
			
		||||
  let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
 | 
			
		||||
 | 
			
		||||
  const getOptions = async () => {
 | 
			
		||||
    const res = await API.get('/api/option/');
 | 
			
		||||
@@ -130,6 +132,17 @@ const OperationSetting = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const deleteHistoryLogs = async () => {
 | 
			
		||||
    console.log(inputs);
 | 
			
		||||
    const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
 | 
			
		||||
    const { success, message, data } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      showSuccess(`${data} 条日志已清理!`);
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    showError('日志清理失败:' + message);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <Grid columns={1}>
 | 
			
		||||
      <Grid.Column>
 | 
			
		||||
@@ -179,12 +192,6 @@ const OperationSetting = () => {
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Group inline>
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.LogConsumeEnabled === 'true'}
 | 
			
		||||
              label='启用额度消费日志记录'
 | 
			
		||||
              name='LogConsumeEnabled'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.DisplayInCurrencyEnabled === 'true'}
 | 
			
		||||
              label='以货币形式显示额度'
 | 
			
		||||
@@ -208,6 +215,28 @@ const OperationSetting = () => {
 | 
			
		||||
            submitConfig('general').then();
 | 
			
		||||
          }}>保存通用设置</Form.Button>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as='h3'>
 | 
			
		||||
            日志设置
 | 
			
		||||
          </Header>
 | 
			
		||||
          <Form.Group inline>
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.LogConsumeEnabled === 'true'}
 | 
			
		||||
              label='启用额度消费日志记录'
 | 
			
		||||
              name='LogConsumeEnabled'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Group widths={4}>
 | 
			
		||||
            <Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
 | 
			
		||||
                        name='history_timestamp'
 | 
			
		||||
                        onChange={(e, { name, value }) => {
 | 
			
		||||
                          setHistoryTimestamp(value);
 | 
			
		||||
                        }} />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Button onClick={() => {
 | 
			
		||||
            deleteHistoryLogs().then();
 | 
			
		||||
          }}>清理历史日志</Form.Button>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as='h3'>
 | 
			
		||||
            监控设置
 | 
			
		||||
          </Header>
 | 
			
		||||
 
 | 
			
		||||
@@ -130,7 +130,13 @@ const RedemptionsTable = () => {
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
    let sortedRedemptions = [...redemptions];
 | 
			
		||||
    sortedRedemptions.sort((a, b) => {
 | 
			
		||||
      if (!isNaN(a[key])) {
 | 
			
		||||
        // If the value is numeric, subtract to sort
 | 
			
		||||
        return a[key] - b[key];
 | 
			
		||||
      } else {
 | 
			
		||||
        // If the value is not numeric, sort as strings
 | 
			
		||||
        return ('' + a[key]).localeCompare(b[key]);
 | 
			
		||||
      }
 | 
			
		||||
    });
 | 
			
		||||
    if (sortedRedemptions[0].id === redemptions[0].id) {
 | 
			
		||||
      sortedRedemptions.reverse();
 | 
			
		||||
 
 | 
			
		||||
@@ -96,7 +96,7 @@ const TokensTable = () => {
 | 
			
		||||
    let nextUrl;
 | 
			
		||||
  
 | 
			
		||||
    if (nextLink) {
 | 
			
		||||
      nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
 | 
			
		||||
      nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
 | 
			
		||||
    } else {
 | 
			
		||||
      nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
 | 
			
		||||
    }
 | 
			
		||||
@@ -138,7 +138,7 @@ const TokensTable = () => {
 | 
			
		||||
    let defaultUrl;
 | 
			
		||||
  
 | 
			
		||||
    if (chatLink) {
 | 
			
		||||
      defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`;
 | 
			
		||||
      defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
 | 
			
		||||
    } else {
 | 
			
		||||
      defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
 | 
			
		||||
    }
 | 
			
		||||
@@ -228,7 +228,13 @@ const TokensTable = () => {
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
    let sortedTokens = [...tokens];
 | 
			
		||||
    sortedTokens.sort((a, b) => {
 | 
			
		||||
      if (!isNaN(a[key])) {
 | 
			
		||||
        // If the value is numeric, subtract to sort
 | 
			
		||||
        return a[key] - b[key];
 | 
			
		||||
      } else {
 | 
			
		||||
        // If the value is not numeric, sort as strings
 | 
			
		||||
        return ('' + a[key]).localeCompare(b[key]);
 | 
			
		||||
      }
 | 
			
		||||
    });
 | 
			
		||||
    if (sortedTokens[0].id === tokens[0].id) {
 | 
			
		||||
      sortedTokens.reverse();
 | 
			
		||||
 
 | 
			
		||||
@@ -133,7 +133,13 @@ const UsersTable = () => {
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
    let sortedUsers = [...users];
 | 
			
		||||
    sortedUsers.sort((a, b) => {
 | 
			
		||||
      if (!isNaN(a[key])) {
 | 
			
		||||
        // If the value is numeric, subtract to sort
 | 
			
		||||
        return a[key] - b[key];
 | 
			
		||||
      } else {
 | 
			
		||||
        // If the value is not numeric, sort as strings
 | 
			
		||||
        return ('' + a[key]).localeCompare(b[key]);
 | 
			
		||||
      }
 | 
			
		||||
    });
 | 
			
		||||
    if (sortedUsers[0].id === users[0].id) {
 | 
			
		||||
      sortedUsers.reverse();
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ export const CHANNEL_OPTIONS = [
 | 
			
		||||
  { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
 | 
			
		||||
  { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
 | 
			
		||||
  { key: 19, text: '360 智脑', value: 19, color: 'blue' },
 | 
			
		||||
  { key: 23, text: '腾讯混元', value: 23, color: 'teal' },
 | 
			
		||||
  { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
 | 
			
		||||
  { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
 | 
			
		||||
  { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
 | 
			
		||||
 
 | 
			
		||||
@@ -187,3 +187,13 @@ export const verifyJSON = (str) => {
 | 
			
		||||
  }
 | 
			
		||||
  return true;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
export function shouldShowPrompt(id) {
 | 
			
		||||
  let prompt = localStorage.getItem(`prompt-${id}`);
 | 
			
		||||
  return !prompt;
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function setPromptShown(id) {
 | 
			
		||||
  localStorage.setItem(`prompt-${id}`, 'true');
 | 
			
		||||
}
 | 
			
		||||
@@ -19,6 +19,8 @@ function type2secretPrompt(type) {
 | 
			
		||||
      return '按照如下格式输入:APPID|APISecret|APIKey';
 | 
			
		||||
    case 22:
 | 
			
		||||
      return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
 | 
			
		||||
    case 23:
 | 
			
		||||
      return '按照如下格式输入:AppId|SecretId|SecretKey';
 | 
			
		||||
    default:
 | 
			
		||||
      return '请输入渠道对应的鉴权密钥';
 | 
			
		||||
  }
 | 
			
		||||
@@ -58,25 +60,28 @@ const EditChannel = () => {
 | 
			
		||||
      let localModels = [];
 | 
			
		||||
      switch (value) {
 | 
			
		||||
        case 14:
 | 
			
		||||
          localModels = ['claude-instant-1', 'claude-2'];
 | 
			
		||||
          localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 11:
 | 
			
		||||
          localModels = ['PaLM-2'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 15:
 | 
			
		||||
          localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
 | 
			
		||||
          localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 17:
 | 
			
		||||
          localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
 | 
			
		||||
          localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 16:
 | 
			
		||||
          localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
 | 
			
		||||
          localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 18:
 | 
			
		||||
          localModels = ['SparkDesk'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 19:
 | 
			
		||||
          localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'];
 | 
			
		||||
          localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 23:
 | 
			
		||||
          localModels = ['hunyuan'];
 | 
			
		||||
          break;
 | 
			
		||||
      }
 | 
			
		||||
      setInputs((inputs) => ({ ...inputs, models: localModels }));
 | 
			
		||||
@@ -174,7 +179,7 @@ const EditChannel = () => {
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    let localInputs = inputs;
 | 
			
		||||
    if (localInputs.base_url.endsWith('/')) {
 | 
			
		||||
    if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
 | 
			
		||||
      localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
 | 
			
		||||
    }
 | 
			
		||||
    if (localInputs.type === 3 && localInputs.other === '') {
 | 
			
		||||
@@ -183,9 +188,6 @@ const EditChannel = () => {
 | 
			
		||||
    if (localInputs.type === 18 && localInputs.other === '') {
 | 
			
		||||
      localInputs.other = 'v2.1';
 | 
			
		||||
    }
 | 
			
		||||
    if (localInputs.model_mapping === '') {
 | 
			
		||||
      localInputs.model_mapping = '{}';
 | 
			
		||||
    }
 | 
			
		||||
    let res;
 | 
			
		||||
    localInputs.models = localInputs.models.join(',');
 | 
			
		||||
    localInputs.group = localInputs.groups.join(',');
 | 
			
		||||
 
 | 
			
		||||
@@ -1,19 +1,12 @@
 | 
			
		||||
import React from 'react';
 | 
			
		||||
import { Segment, Header } from 'semantic-ui-react';
 | 
			
		||||
import { Message } from 'semantic-ui-react';
 | 
			
		||||
 | 
			
		||||
const NotFound = () => (
 | 
			
		||||
  <>
 | 
			
		||||
    <Header
 | 
			
		||||
      block
 | 
			
		||||
      as="h4"
 | 
			
		||||
      content="404"
 | 
			
		||||
      attached="top"
 | 
			
		||||
      icon="info"
 | 
			
		||||
      className="small-icon"
 | 
			
		||||
    />
 | 
			
		||||
    <Segment attached="bottom">
 | 
			
		||||
      未找到所请求的页面
 | 
			
		||||
    </Segment>
 | 
			
		||||
    <Message negative>
 | 
			
		||||
      <Message.Header>页面不存在</Message.Header>
 | 
			
		||||
      <p>请检查你的浏览器地址是否正确</p>
 | 
			
		||||
    </Message>
 | 
			
		||||
  </>
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -102,7 +102,7 @@ const EditUser = () => {
 | 
			
		||||
              label='密码'
 | 
			
		||||
              name='password'
 | 
			
		||||
              type={'password'}
 | 
			
		||||
              placeholder={'请输入新的密码'}
 | 
			
		||||
              placeholder={'请输入新的密码,最短 8 位'}
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              value={password}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user