mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-25 19:03:43 +08:00 
			
		
		
		
	Compare commits
	
		
			63 Commits
		
	
	
		
			v0.5.4-alp
			...
			v0.5.7-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 824444244b | ||
|  | fbe9985f57 | ||
|  | a27a5bcc06 | ||
|  | e28d4b1741 | ||
|  | f073592d39 | ||
|  | fa41ca9805 | ||
|  | e338de45b6 | ||
|  | 114587b46f | ||
|  | b4b4acc288 | ||
|  | d663de3e3a | ||
|  | a85ecace2e | ||
|  | fbdea91ea1 | ||
|  | 8d34b7a77e | ||
|  | cbd62011b8 | ||
|  | 4701897e2e | ||
|  | 0f6c132a80 | ||
|  | 3cac45dc85 | ||
|  | 47c08c72ce | ||
|  | 53b2cace0b | ||
|  | f0fc991b44 | ||
|  | 594f06e7b0 | ||
|  | 197d1d7a9d | ||
|  | f9b748c2ca | ||
|  | fd98463611 | ||
|  | f5a1cd3463 | ||
|  | 8651451e53 | ||
|  | 1c5bb97a42 | ||
|  | de868e4e4e | ||
|  | 1d258cc898 | ||
|  | 37e09d764c | ||
|  | 159b9e3369 | ||
|  | 92001986db | ||
|  | a5647b1ea7 | ||
|  | 215e54fc96 | ||
|  | ecf8a6d875 | ||
|  | 24df3e5f62 | ||
|  | 12ef9679a7 | ||
|  | 328aa68255 | ||
|  | 4335f005a6 | ||
|  | fe26a1448d | ||
|  | 42451d9d02 | ||
|  | 25c4c111ab | ||
|  | 0d50ad4b2b | ||
|  | 959bcdef88 | ||
|  | 39ae8075e4 | ||
|  | b57a0eca16 | ||
|  | 1b4cc78890 | ||
|  | 420c375140 | ||
|  | 01863d3e44 | ||
|  | d0a0e871e1 | ||
|  | bd6fe1e93c | ||
|  | c55bb67818 | ||
|  | 0f949c3782 | ||
|  | a721a5b6f9 | ||
|  | 276163affd | ||
|  | 621eb91b46 | ||
|  | 7e575abb95 | ||
|  | 9db93316c4 | ||
|  | c3dc315e75 | ||
|  | 04acdb1ccb | ||
|  | f0d5e102a3 | ||
|  | abbf2fded0 | ||
|  | ef2c5abb5b | 
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -4,4 +4,5 @@ upload | |||||||
| *.exe | *.exe | ||||||
| *.db | *.db | ||||||
| build | build | ||||||
| *.db-journal | *.db-journal | ||||||
|  | logs | ||||||
							
								
								
									
										94
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										94
									
								
								README.md
									
									
									
									
									
								
							| @@ -59,6 +59,9 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
| > **Warning** | > **Warning** | ||||||
| > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 | > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 | ||||||
|  |  | ||||||
|  | > **Warning** | ||||||
|  | > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! | ||||||
|  |  | ||||||
| ## 功能 | ## 功能 | ||||||
| 1. 支持多种大模型: | 1. 支持多种大模型: | ||||||
|    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) |    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) | ||||||
| @@ -69,12 +72,13 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
|    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) |    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) | ||||||
|    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) |    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) | ||||||
|    + [x] [360 智脑](https://ai.360.cn) |    + [x] [360 智脑](https://ai.360.cn) | ||||||
|  |    + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) | ||||||
| 2. 支持配置镜像以及众多第三方代理服务: | 2. 支持配置镜像以及众多第三方代理服务: | ||||||
|    + [x] [OpenAI-SB](https://openai-sb.com) |    + [x] [OpenAI-SB](https://openai-sb.com) | ||||||
|  |    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) | ||||||
|    + [x] [API2D](https://api2d.com/r/197971) |    + [x] [API2D](https://api2d.com/r/197971) | ||||||
|    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) |    + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) | ||||||
|    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) |    + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) | ||||||
|    + [x] [CloseAI](https://console.closeai-asia.com/r/2412) |  | ||||||
|    + [x] 自定义渠道:例如各种未收录的第三方代理服务 |    + [x] 自定义渠道:例如各种未收录的第三方代理服务 | ||||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||||
| @@ -91,23 +95,32 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | |||||||
| 15. 支持模型映射,重定向用户的请求模型。 | 15. 支持模型映射,重定向用户的请求模型。 | ||||||
| 16. 支持失败自动重试。 | 16. 支持失败自动重试。 | ||||||
| 17. 支持绘图接口。 | 17. 支持绘图接口。 | ||||||
| 18. 支持丰富的**自定义**设置, | 18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 | ||||||
|  | 19. 支持丰富的**自定义**设置, | ||||||
|     1. 支持自定义系统名称,logo 以及页脚。 |     1. 支持自定义系统名称,logo 以及页脚。 | ||||||
|     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 |     2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 | ||||||
| 19. 支持通过系统访问令牌访问管理 API。 | 20. 支持通过系统访问令牌访问管理 API。 | ||||||
| 20. 支持 Cloudflare Turnstile 用户校验。 | 21. 支持 Cloudflare Turnstile 用户校验。 | ||||||
| 21. 支持用户管理,支持**多种用户登录注册方式**: | 22. 支持用户管理,支持**多种用户登录注册方式**: | ||||||
|     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 |     + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 | ||||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 |     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 |     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||||
|  |  | ||||||
| ## 部署 | ## 部署 | ||||||
| ### 基于 Docker 进行部署 | ### 基于 Docker 进行部署 | ||||||
| 部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api` | ```shell | ||||||
|  | # 使用 SQLite 的部署命令: | ||||||
|  | docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api | ||||||
|  | # 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。 | ||||||
|  | # 例如: | ||||||
|  | docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api | ||||||
|  | ``` | ||||||
|  |  | ||||||
| 其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。 | 其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。 | ||||||
|  |  | ||||||
| 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 | 数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 | ||||||
|  |  | ||||||
|  | 如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 | ||||||
|  |  | ||||||
| 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 | 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 | ||||||
|  |  | ||||||
| @@ -209,6 +222,13 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | |||||||
|  |  | ||||||
| 注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 | 注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 | ||||||
|  |  | ||||||
|  | #### QChatGPT - QQ机器人 | ||||||
|  | 项目主页:https://github.com/RockChinQ/QChatGPT | ||||||
|  |  | ||||||
|  | 根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。 | ||||||
|  |  | ||||||
|  | 可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。 | ||||||
|  |  | ||||||
| ### 部署到第三方平台 | ### 部署到第三方平台 | ||||||
| <details> | <details> | ||||||
| <summary><strong>部署到 Sealos </strong></summary> | <summary><strong>部署到 Sealos </strong></summary> | ||||||
| @@ -227,7 +247,7 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | |||||||
| <summary><strong>部署到 Zeabur</strong></summary> | <summary><strong>部署到 Zeabur</strong></summary> | ||||||
| <div> | <div> | ||||||
|  |  | ||||||
| > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用。 | > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 | ||||||
|  |  | ||||||
| 1. 首先 fork 一份代码。 | 1. 首先 fork 一份代码。 | ||||||
| 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 | 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 | ||||||
| @@ -242,6 +262,17 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | |||||||
| </div> | </div> | ||||||
| </details> | </details> | ||||||
|  |  | ||||||
|  | <details> | ||||||
|  | <summary><strong>部署到 Render</strong></summary> | ||||||
|  | <div> | ||||||
|  |  | ||||||
|  | > Render 提供免费额度,绑卡后可以进一步提升额度 | ||||||
|  |  | ||||||
|  | Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashboard.render.com | ||||||
|  |  | ||||||
|  | </div> | ||||||
|  | </details> | ||||||
|  |  | ||||||
| ## 配置 | ## 配置 | ||||||
| 系统本身开箱即用。 | 系统本身开箱即用。 | ||||||
|  |  | ||||||
| @@ -260,13 +291,20 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | |||||||
|  |  | ||||||
| 注意,具体的 API Base 的格式取决于你所使用的客户端。 | 注意,具体的 API Base 的格式取决于你所使用的客户端。 | ||||||
|  |  | ||||||
|  | 例如对于 OpenAI 的官方库: | ||||||
|  | ```bash | ||||||
|  | OPENAI_API_KEY="sk-xxxxxx" | ||||||
|  | OPENAI_API_BASE="https://<HOST>:<PORT>/v1"  | ||||||
|  | ``` | ||||||
|  |  | ||||||
| ```mermaid | ```mermaid | ||||||
| graph LR | graph LR | ||||||
|     A(用户) |     A(用户) | ||||||
|     A --->|请求| B(One API) |     A --->|使用 One API 分发的 key 进行请求| B(One API) | ||||||
|     B -->|中继请求| C(OpenAI) |     B -->|中继请求| C(OpenAI) | ||||||
|     B -->|中继请求| D(Azure) |     B -->|中继请求| D(Azure) | ||||||
|     B -->|中继请求| E(其他下游渠道) |     B -->|中继请求| E(其他 OpenAI API 格式下游渠道) | ||||||
|  |     B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| 可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。 | 可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。 | ||||||
| @@ -275,8 +313,9 @@ graph LR | |||||||
| 不加的话将会使用负载均衡的方式使用多个渠道。 | 不加的话将会使用负载均衡的方式使用多个渠道。 | ||||||
|  |  | ||||||
| ### 环境变量 | ### 环境变量 | ||||||
| 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储。 | 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 | ||||||
|    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` |    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` | ||||||
|  |    + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 | ||||||
| 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 | ||||||
|    + 例子:`SESSION_SECRET=random_string` |    + 例子:`SESSION_SECRET=random_string` | ||||||
| 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 | 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 | ||||||
| @@ -293,21 +332,31 @@ graph LR | |||||||
|      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 |      + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 | ||||||
| 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 | ||||||
|    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` |    + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` | ||||||
| 5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。 | 5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|  |    + 例子:`MEMORY_CACHE_ENABLED=true` | ||||||
|  | 6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 | ||||||
|    + 例子:`SYNC_FREQUENCY=60` |    + 例子:`SYNC_FREQUENCY=60` | ||||||
| 6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | 7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 | ||||||
|    + 例子:`NODE_TYPE=slave` |    + 例子:`NODE_TYPE=slave` | ||||||
| 7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | 8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 | ||||||
|    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` |    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` | ||||||
| 8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | 9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 | ||||||
|    + 例子:`CHANNEL_TEST_FREQUENCY=1440` |    + 例子:`CHANNEL_TEST_FREQUENCY=1440` | ||||||
| 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | 10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 | ||||||
|    + 例子:`POLLING_INTERVAL=5` |     + 例子:`POLLING_INTERVAL=5` | ||||||
|  | 11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 | ||||||
|  |     + 例子:`BATCH_UPDATE_ENABLED=true` | ||||||
|  |     + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 | ||||||
|  | 12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 | ||||||
|  |     + 例子:`BATCH_UPDATE_INTERVAL=5` | ||||||
|  | 13. 请求频率限制: | ||||||
|  |     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 | ||||||
|  |     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 | ||||||
|  |  | ||||||
| ### 命令行参数 | ### 命令行参数 | ||||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||||
|    + 例子:`--port 3000` |    + 例子:`--port 3000` | ||||||
| 2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存。 | 2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。 | ||||||
|    + 例子:`--log-dir ./logs` |    + 例子:`--log-dir ./logs` | ||||||
| 3. `--version`: 打印系统版本号并退出。 | 3. `--version`: 打印系统版本号并退出。 | ||||||
| 4. `--help`: 查看命令的使用帮助和参数说明。 | 4. `--help`: 查看命令的使用帮助和参数说明。 | ||||||
| @@ -339,8 +388,15 @@ https://openai.justsong.cn | |||||||
| 5. ChatGPT Next Web 报错:`Failed to fetch` | 5. ChatGPT Next Web 报错:`Failed to fetch` | ||||||
|    + 部署的时候不要设置 `BASE_URL`。 |    + 部署的时候不要设置 `BASE_URL`。 | ||||||
|    + 检查你的接口地址和 API Key 有没有填对。 |    + 检查你的接口地址和 API Key 有没有填对。 | ||||||
|  |    + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 | ||||||
| 6. 报错:`当前分组负载已饱和,请稍后再试` | 6. 报错:`当前分组负载已饱和,请稍后再试` | ||||||
|    + 上游通道 429 了。 |    + 上游通道 429 了。 | ||||||
|  | 7. 升级之后我的数据会丢失吗? | ||||||
|  |    + 如果使用 MySQL,不会。 | ||||||
|  |    + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 | ||||||
|  | 8. 升级之前数据库需要做变更吗? | ||||||
|  |    + 一般情况下不需要,系统将在初始化的时候自动调整。 | ||||||
|  |    + 如果需要的话,我会在更新日志中说明,并给出脚本。 | ||||||
|  |  | ||||||
| ## 相关项目 | ## 相关项目 | ||||||
| * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 | ||||||
| @@ -352,4 +408,4 @@ https://openai.justsong.cn | |||||||
|  |  | ||||||
| 同样适用于基于本项目的二开项目。 | 同样适用于基于本项目的二开项目。 | ||||||
|  |  | ||||||
| 依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。 | 依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。 | ||||||
|   | |||||||
| @@ -56,6 +56,7 @@ var EmailDomainWhitelist = []string{ | |||||||
| } | } | ||||||
|  |  | ||||||
| var DebugEnabled = os.Getenv("DEBUG") == "true" | var DebugEnabled = os.Getenv("DEBUG") == "true" | ||||||
|  | var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" | ||||||
|  |  | ||||||
| var LogConsumeEnabled = true | var LogConsumeEnabled = true | ||||||
|  |  | ||||||
| @@ -92,7 +93,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | |||||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | var RequestInterval = time.Duration(requestInterval) * time.Second | ||||||
|  |  | ||||||
| var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY | var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second | ||||||
|  |  | ||||||
|  | var BatchUpdateEnabled = false | ||||||
|  | var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	RequestIdKey = "X-Oneapi-Request-Id" | ||||||
|  | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	RoleGuestUser  = 0 | 	RoleGuestUser  = 0 | ||||||
| @@ -111,10 +119,10 @@ var ( | |||||||
| // All duration's unit is seconds | // All duration's unit is seconds | ||||||
| // Shouldn't larger then RateLimitKeyExpirationDuration | // Shouldn't larger then RateLimitKeyExpirationDuration | ||||||
| var ( | var ( | ||||||
| 	GlobalApiRateLimitNum            = 180 | 	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) | ||||||
| 	GlobalApiRateLimitDuration int64 = 3 * 60 | 	GlobalApiRateLimitDuration int64 = 3 * 60 | ||||||
|  |  | ||||||
| 	GlobalWebRateLimitNum            = 60 | 	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) | ||||||
| 	GlobalWebRateLimitDuration int64 = 3 * 60 | 	GlobalWebRateLimitDuration int64 = 3 * 60 | ||||||
|  |  | ||||||
| 	UploadRateLimitNum            = 10 | 	UploadRateLimitNum            = 10 | ||||||
| @@ -148,55 +156,62 @@ const ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	ChannelStatusUnknown  = 0 | 	ChannelStatusUnknown          = 0 | ||||||
| 	ChannelStatusEnabled  = 1 // don't use 0, 0 is the default value! | 	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value! | ||||||
| 	ChannelStatusDisabled = 2 // also don't use 0 | 	ChannelStatusManuallyDisabled = 2 // also don't use 0 | ||||||
|  | 	ChannelStatusAutoDisabled     = 3 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	ChannelTypeUnknown    = 0 | 	ChannelTypeUnknown        = 0 | ||||||
| 	ChannelTypeOpenAI     = 1 | 	ChannelTypeOpenAI         = 1 | ||||||
| 	ChannelTypeAPI2D      = 2 | 	ChannelTypeAPI2D          = 2 | ||||||
| 	ChannelTypeAzure      = 3 | 	ChannelTypeAzure          = 3 | ||||||
| 	ChannelTypeCloseAI    = 4 | 	ChannelTypeCloseAI        = 4 | ||||||
| 	ChannelTypeOpenAISB   = 5 | 	ChannelTypeOpenAISB       = 5 | ||||||
| 	ChannelTypeOpenAIMax  = 6 | 	ChannelTypeOpenAIMax      = 6 | ||||||
| 	ChannelTypeOhMyGPT    = 7 | 	ChannelTypeOhMyGPT        = 7 | ||||||
| 	ChannelTypeCustom     = 8 | 	ChannelTypeCustom         = 8 | ||||||
| 	ChannelTypeAILS       = 9 | 	ChannelTypeAILS           = 9 | ||||||
| 	ChannelTypeAIProxy    = 10 | 	ChannelTypeAIProxy        = 10 | ||||||
| 	ChannelTypePaLM       = 11 | 	ChannelTypePaLM           = 11 | ||||||
| 	ChannelTypeAPI2GPT    = 12 | 	ChannelTypeAPI2GPT        = 12 | ||||||
| 	ChannelTypeAIGC2D     = 13 | 	ChannelTypeAIGC2D         = 13 | ||||||
| 	ChannelTypeAnthropic  = 14 | 	ChannelTypeAnthropic      = 14 | ||||||
| 	ChannelTypeBaidu      = 15 | 	ChannelTypeBaidu          = 15 | ||||||
| 	ChannelTypeZhipu      = 16 | 	ChannelTypeZhipu          = 16 | ||||||
| 	ChannelTypeAli        = 17 | 	ChannelTypeAli            = 17 | ||||||
| 	ChannelTypeXunfei     = 18 | 	ChannelTypeXunfei         = 18 | ||||||
| 	ChannelType360        = 19 | 	ChannelType360            = 19 | ||||||
| 	ChannelTypeOpenRouter = 20 | 	ChannelTypeOpenRouter     = 20 | ||||||
|  | 	ChannelTypeAIProxyLibrary = 21 | ||||||
|  | 	ChannelTypeFastGPT        = 22 | ||||||
|  | 	ChannelTypeTencent        = 23 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ChannelBaseURLs = []string{ | var ChannelBaseURLs = []string{ | ||||||
| 	"",                               // 0 | 	"",                                  // 0 | ||||||
| 	"https://api.openai.com",         // 1 | 	"https://api.openai.com",            // 1 | ||||||
| 	"https://oa.api2d.net",           // 2 | 	"https://oa.api2d.net",              // 2 | ||||||
| 	"",                               // 3 | 	"",                                  // 3 | ||||||
| 	"https://api.closeai-proxy.xyz",  // 4 | 	"https://api.closeai-proxy.xyz",     // 4 | ||||||
| 	"https://api.openai-sb.com",      // 5 | 	"https://api.openai-sb.com",         // 5 | ||||||
| 	"https://api.openaimax.com",      // 6 | 	"https://api.openaimax.com",         // 6 | ||||||
| 	"https://api.ohmygpt.com",        // 7 | 	"https://api.ohmygpt.com",           // 7 | ||||||
| 	"",                               // 8 | 	"",                                  // 8 | ||||||
| 	"https://api.caipacity.com",      // 9 | 	"https://api.caipacity.com",         // 9 | ||||||
| 	"https://api.aiproxy.io",         // 10 | 	"https://api.aiproxy.io",            // 10 | ||||||
| 	"",                               // 11 | 	"",                                  // 11 | ||||||
| 	"https://api.api2gpt.com",        // 12 | 	"https://api.api2gpt.com",           // 12 | ||||||
| 	"https://api.aigc2d.com",         // 13 | 	"https://api.aigc2d.com",            // 13 | ||||||
| 	"https://api.anthropic.com",      // 14 | 	"https://api.anthropic.com",         // 14 | ||||||
| 	"https://aip.baidubce.com",       // 15 | 	"https://aip.baidubce.com",          // 15 | ||||||
| 	"https://open.bigmodel.cn",       // 16 | 	"https://open.bigmodel.cn",          // 16 | ||||||
| 	"https://dashscope.aliyuncs.com", // 17 | 	"https://dashscope.aliyuncs.com",    // 17 | ||||||
| 	"",                               // 18 | 	"",                                  // 18 | ||||||
| 	"https://ai.360.cn",              // 19 | 	"https://ai.360.cn",                 // 19 | ||||||
| 	"https://openrouter.ai/api",      // 20 | 	"https://openrouter.ai/api",         // 20 | ||||||
|  | 	"https://api.aiproxy.io",            // 21 | ||||||
|  | 	"https://fastgpt.run/api/openapi",   // 22 | ||||||
|  | 	"https://hunyuan.cloud.tencent.com", //23 | ||||||
| } | } | ||||||
|   | |||||||
| @@ -12,7 +12,7 @@ var ( | |||||||
| 	Port         = flag.Int("port", 3000, "the listening port") | 	Port         = flag.Int("port", 3000, "the listening port") | ||||||
| 	PrintVersion = flag.Bool("version", false, "print version and exit") | 	PrintVersion = flag.Bool("version", false, "print version and exit") | ||||||
| 	PrintHelp    = flag.Bool("help", false, "print help and exit") | 	PrintHelp    = flag.Bool("help", false, "print help and exit") | ||||||
| 	LogDir       = flag.String("log-dir", "", "specify the log directory") | 	LogDir       = flag.String("log-dir", "./logs", "specify the log directory") | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func printHelp() { | func printHelp() { | ||||||
|   | |||||||
| @@ -1,29 +1,47 @@ | |||||||
| package common | package common | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"io" | 	"io" | ||||||
| 	"log" | 	"log" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
|  | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func SetupGinLog() { | const ( | ||||||
|  | 	loggerINFO  = "INFO" | ||||||
|  | 	loggerWarn  = "WARN" | ||||||
|  | 	loggerError = "ERR" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const maxLogCount = 1000000 | ||||||
|  |  | ||||||
|  | var logCount int | ||||||
|  | var setupLogLock sync.Mutex | ||||||
|  | var setupLogWorking bool | ||||||
|  |  | ||||||
|  | func SetupLogger() { | ||||||
| 	if *LogDir != "" { | 	if *LogDir != "" { | ||||||
| 		commonLogPath := filepath.Join(*LogDir, "common.log") | 		ok := setupLogLock.TryLock() | ||||||
| 		errorLogPath := filepath.Join(*LogDir, "error.log") | 		if !ok { | ||||||
| 		commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | 			log.Println("setup log is already working") | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		defer func() { | ||||||
|  | 			setupLogLock.Unlock() | ||||||
|  | 			setupLogWorking = false | ||||||
|  | 		}() | ||||||
|  | 		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||||
|  | 		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Fatal("failed to open log file") | 			log.Fatal("failed to open log file") | ||||||
| 		} | 		} | ||||||
| 		errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | 		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) | ||||||
| 		if err != nil { | 		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) | ||||||
| 			log.Fatal("failed to open log file") |  | ||||||
| 		} |  | ||||||
| 		gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd) |  | ||||||
| 		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd) |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -37,6 +55,36 @@ func SysError(s string) { | |||||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func LogInfo(ctx context.Context, msg string) { | ||||||
|  | 	logHelper(ctx, loggerINFO, msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func LogWarn(ctx context.Context, msg string) { | ||||||
|  | 	logHelper(ctx, loggerWarn, msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func LogError(ctx context.Context, msg string) { | ||||||
|  | 	logHelper(ctx, loggerError, msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func logHelper(ctx context.Context, level string, msg string) { | ||||||
|  | 	writer := gin.DefaultErrorWriter | ||||||
|  | 	if level == loggerINFO { | ||||||
|  | 		writer = gin.DefaultWriter | ||||||
|  | 	} | ||||||
|  | 	id := ctx.Value(RequestIdKey) | ||||||
|  | 	now := time.Now() | ||||||
|  | 	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) | ||||||
|  | 	logCount++ // we don't need accurate count, so no lock here | ||||||
|  | 	if logCount > maxLogCount && !setupLogWorking { | ||||||
|  | 		logCount = 0 | ||||||
|  | 		setupLogWorking = true | ||||||
|  | 		go func() { | ||||||
|  | 			SetupLogger() | ||||||
|  | 		}() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func FatalLog(v ...any) { | func FatalLog(v ...any) { | ||||||
| 	t := time.Now() | 	t := time.Now() | ||||||
| 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) | 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) | ||||||
|   | |||||||
| @@ -24,6 +24,7 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"gpt-3.5-turbo-0613":        0.75, | 	"gpt-3.5-turbo-0613":        0.75, | ||||||
| 	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens | 	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens | ||||||
| 	"gpt-3.5-turbo-16k-0613":    1.5, | 	"gpt-3.5-turbo-16k-0613":    1.5, | ||||||
|  | 	"gpt-3.5-turbo-instruct":    0.75, // $0.0015 / 1K tokens | ||||||
| 	"text-ada-001":              0.2, | 	"text-ada-001":              0.2, | ||||||
| 	"text-babbage-001":          0.25, | 	"text-babbage-001":          0.25, | ||||||
| 	"text-curie-001":            1, | 	"text-curie-001":            1, | ||||||
| @@ -50,14 +51,15 @@ var ModelRatio = map[string]float64{ | |||||||
| 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens | ||||||
| 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens | ||||||
| 	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens | 	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens | ||||||
| 	"qwen-v1":                   0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag | 	"qwen-turbo":                0.8572, // ¥0.012 / 1k tokens | ||||||
| 	"qwen-plus-v1":              0.5715, // Same as above | 	"qwen-plus":                 10,     // ¥0.14 / 1k tokens | ||||||
| 	"SparkDesk":                 0.8572, // TBD | 	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens | ||||||
|  | 	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens | ||||||
| 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens | 	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens | ||||||
| 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | 	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens | ||||||
| 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | 	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens | ||||||
| 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens | 	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens | ||||||
| 	"360GPT_S2_V9.4":            0.8572, // ¥0.012 / 1k tokens | 	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 | ||||||
| } | } | ||||||
|  |  | ||||||
| func ModelRatio2JSONString() string { | func ModelRatio2JSONString() string { | ||||||
|   | |||||||
| @@ -171,6 +171,11 @@ func GetTimestamp() int64 { | |||||||
| 	return time.Now().Unix() | 	return time.Now().Unix() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetTimeString() string { | ||||||
|  | 	now := time.Now() | ||||||
|  | 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||||
|  | } | ||||||
|  |  | ||||||
| func Max(a int, b int) int { | func Max(a int, b int) int { | ||||||
| 	if a >= b { | 	if a >= b { | ||||||
| 		return a | 		return a | ||||||
| @@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int { | |||||||
| 	} | 	} | ||||||
| 	return num | 	return num | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func MessageWithRequestId(message string, id string) string { | ||||||
|  | 	return fmt.Sprintf("%s (request id: %s)", message, id) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		openAIError := OpenAIError{ | 		openAIError := OpenAIError{ | ||||||
| 			Message: err.Error(), | 			Message: err.Error(), | ||||||
| 			Type:    "one_api_error", | 			Type:    "upstream_error", | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"error": openAIError, | 			"error": openAIError, | ||||||
|   | |||||||
| @@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | |||||||
| } | } | ||||||
|  |  | ||||||
| func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { | func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { | ||||||
| 	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL) | 	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) | ||||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||||
|  |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { | |||||||
|  |  | ||||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||||
| 	baseURL := common.ChannelBaseURLs[channel.Type] | 	baseURL := common.ChannelBaseURLs[channel.Type] | ||||||
| 	if channel.BaseURL == "" { | 	if channel.GetBaseURL() == "" { | ||||||
| 		channel.BaseURL = baseURL | 		channel.BaseURL = &baseURL | ||||||
| 	} | 	} | ||||||
| 	switch channel.Type { | 	switch channel.Type { | ||||||
| 	case common.ChannelTypeOpenAI: | 	case common.ChannelTypeOpenAI: | ||||||
| 		if channel.BaseURL != "" { | 		if channel.GetBaseURL() != "" { | ||||||
| 			baseURL = channel.BaseURL | 			baseURL = channel.GetBaseURL() | ||||||
| 		} | 		} | ||||||
| 	case common.ChannelTypeAzure: | 	case common.ChannelTypeAzure: | ||||||
| 		return 0, errors.New("尚未实现") | 		return 0, errors.New("尚未实现") | ||||||
| 	case common.ChannelTypeCustom: | 	case common.ChannelTypeCustom: | ||||||
| 		baseURL = channel.BaseURL | 		baseURL = channel.GetBaseURL() | ||||||
| 	case common.ChannelTypeCloseAI: | 	case common.ChannelTypeCloseAI: | ||||||
| 		return updateChannelCloseAIBalance(channel) | 		return updateChannelCloseAIBalance(channel) | ||||||
| 	case common.ChannelTypeOpenAISB: | 	case common.ChannelTypeOpenAISB: | ||||||
|   | |||||||
| @@ -14,7 +14,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { | func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { | ||||||
| 	switch channel.Type { | 	switch channel.Type { | ||||||
| 	case common.ChannelTypePaLM: | 	case common.ChannelTypePaLM: | ||||||
| 		fallthrough | 		fallthrough | ||||||
| @@ -32,15 +32,20 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr | |||||||
| 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil | 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil | ||||||
| 	case common.ChannelTypeAzure: | 	case common.ChannelTypeAzure: | ||||||
| 		request.Model = "gpt-35-turbo" | 		request.Model = "gpt-35-turbo" | ||||||
|  | 		defer func() { | ||||||
|  | 			if err != nil { | ||||||
|  | 				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!") | ||||||
|  | 			} | ||||||
|  | 		}() | ||||||
| 	default: | 	default: | ||||||
| 		request.Model = "gpt-3.5-turbo" | 		request.Model = "gpt-3.5-turbo" | ||||||
| 	} | 	} | ||||||
| 	requestURL := common.ChannelBaseURLs[channel.Type] | 	requestURL := common.ChannelBaseURLs[channel.Type] | ||||||
| 	if channel.Type == common.ChannelTypeAzure { | 	if channel.Type == common.ChannelTypeAzure { | ||||||
| 		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) | 		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) | ||||||
| 	} else { | 	} else { | ||||||
| 		if channel.BaseURL != "" { | 		if channel.GetBaseURL() != "" { | ||||||
| 			requestURL = channel.BaseURL | 			requestURL = channel.GetBaseURL() | ||||||
| 		} | 		} | ||||||
| 		requestURL += "/v1/chat/completions" | 		requestURL += "/v1/chat/completions" | ||||||
| 	} | 	} | ||||||
| @@ -136,7 +141,7 @@ func disableChannel(channelId int, channelName string, reason string) { | |||||||
| 	if common.RootUserEmail == "" { | 	if common.RootUserEmail == "" { | ||||||
| 		common.RootUserEmail = model.GetRootUserEmail() | 		common.RootUserEmail = model.GetRootUserEmail() | ||||||
| 	} | 	} | ||||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) | 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | ||||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||||
| 	err := common.SendEmail(subject, common.RootUserEmail, content) | 	err := common.SendEmail(subject, common.RootUserEmail, content) | ||||||
|   | |||||||
| @@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
| 	channel.CreatedTime = common.GetTimestamp() | 	channel.CreatedTime = common.GetTimestamp() | ||||||
| 	keys := strings.Split(channel.Key, "\n") | 	keys := strings.Split(channel.Key, "\n") | ||||||
| 	channels := make([]model.Channel, 0) | 	channels := make([]model.Channel, 0, len(keys)) | ||||||
| 	for _, key := range keys { | 	for _, key := range keys { | ||||||
| 		if key == "" { | 		if key == "" { | ||||||
| 			continue | 			continue | ||||||
| @@ -127,6 +127,23 @@ func DeleteChannel(c *gin.Context) { | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func DeleteDisabledChannel(c *gin.Context) { | ||||||
|  | 	rows, err := model.DeleteDisabledChannel() | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "", | ||||||
|  | 		"data":    rows, | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
| func UpdateChannel(c *gin.Context) { | func UpdateChannel(c *gin.Context) { | ||||||
| 	channel := model.Channel{} | 	channel := model.Channel{} | ||||||
| 	err := c.ShouldBindJSON(&channel) | 	err := c.ShouldBindJSON(&channel) | ||||||
|   | |||||||
| @@ -79,6 +79,14 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { | |||||||
|  |  | ||||||
| func GitHubOAuth(c *gin.Context) { | func GitHubOAuth(c *gin.Context) { | ||||||
| 	session := sessions.Default(c) | 	session := sessions.Default(c) | ||||||
|  | 	state := c.Query("state") | ||||||
|  | 	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { | ||||||
|  | 		c.JSON(http.StatusForbidden, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": "state is empty or not same", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
| 	username := session.Get("username") | 	username := session.Get("username") | ||||||
| 	if username != nil { | 	if username != nil { | ||||||
| 		GitHubBind(c) | 		GitHubBind(c) | ||||||
| @@ -205,3 +213,22 @@ func GitHubBind(c *gin.Context) { | |||||||
| 	}) | 	}) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GenerateOAuthCode(c *gin.Context) { | ||||||
|  | 	session := sessions.Default(c) | ||||||
|  | 	state := common.GetRandomString(12) | ||||||
|  | 	session.Set("oauth_state", state) | ||||||
|  | 	err := session.Save() | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "", | ||||||
|  | 		"data":    state, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"one-api/model" | 	"one-api/model" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| @@ -18,19 +19,21 @@ func GetAllLogs(c *gin.Context) { | |||||||
| 	username := c.Query("username") | 	username := c.Query("username") | ||||||
| 	tokenName := c.Query("token_name") | 	tokenName := c.Query("token_name") | ||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) | 	channel, _ := strconv.Atoi(c.Query("channel")) | ||||||
|  | 	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": err.Error(), | 			"message": err.Error(), | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetUserLogs(c *gin.Context) { | func GetUserLogs(c *gin.Context) { | ||||||
| @@ -46,34 +49,36 @@ func GetUserLogs(c *gin.Context) { | |||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) | 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": err.Error(), | 			"message": err.Error(), | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchAllLogs(c *gin.Context) { | func SearchAllLogs(c *gin.Context) { | ||||||
| 	keyword := c.Query("keyword") | 	keyword := c.Query("keyword") | ||||||
| 	logs, err := model.SearchAllLogs(keyword) | 	logs, err := model.SearchAllLogs(keyword) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": err.Error(), | 			"message": err.Error(), | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func SearchUserLogs(c *gin.Context) { | func SearchUserLogs(c *gin.Context) { | ||||||
| @@ -81,17 +86,18 @@ func SearchUserLogs(c *gin.Context) { | |||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt("id") | ||||||
| 	logs, err := model.SearchUserLogs(userId, keyword) | 	logs, err := model.SearchUserLogs(userId, keyword) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(http.StatusOK, gin.H{ | ||||||
| 			"success": false, | 			"success": false, | ||||||
| 			"message": err.Error(), | 			"message": err.Error(), | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data":    logs, | 		"data":    logs, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetLogsStat(c *gin.Context) { | func GetLogsStat(c *gin.Context) { | ||||||
| @@ -101,9 +107,10 @@ func GetLogsStat(c *gin.Context) { | |||||||
| 	tokenName := c.Query("token_name") | 	tokenName := c.Query("token_name") | ||||||
| 	username := c.Query("username") | 	username := c.Query("username") | ||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) | 	channel, _ := strconv.Atoi(c.Query("channel")) | ||||||
|  | 	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) | ||||||
| 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") | 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data": gin.H{ | 		"data": gin.H{ | ||||||
| @@ -111,6 +118,7 @@ func GetLogsStat(c *gin.Context) { | |||||||
| 			//"token": tokenNum, | 			//"token": tokenNum, | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetLogsSelfStat(c *gin.Context) { | func GetLogsSelfStat(c *gin.Context) { | ||||||
| @@ -120,9 +128,10 @@ func GetLogsSelfStat(c *gin.Context) { | |||||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||||
| 	tokenName := c.Query("token_name") | 	tokenName := c.Query("token_name") | ||||||
| 	modelName := c.Query("model_name") | 	modelName := c.Query("model_name") | ||||||
| 	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) | 	channel, _ := strconv.Atoi(c.Query("channel")) | ||||||
|  | 	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) | ||||||
| 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) | 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) | ||||||
| 	c.JSON(200, gin.H{ | 	c.JSON(http.StatusOK, gin.H{ | ||||||
| 		"success": true, | 		"success": true, | ||||||
| 		"message": "", | 		"message": "", | ||||||
| 		"data": gin.H{ | 		"data": gin.H{ | ||||||
| @@ -130,4 +139,30 @@ func GetLogsSelfStat(c *gin.Context) { | |||||||
| 			//"token": tokenNum, | 			//"token": tokenNum, | ||||||
| 		}, | 		}, | ||||||
| 	}) | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DeleteHistoryLogs(c *gin.Context) { | ||||||
|  | 	targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64) | ||||||
|  | 	if targetTimestamp == 0 { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": "target timestamp is required", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	count, err := model.DeleteOldLog(targetTimestamp) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 			"success": false, | ||||||
|  | 			"message": err.Error(), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	c.JSON(http.StatusOK, gin.H{ | ||||||
|  | 		"success": true, | ||||||
|  | 		"message": "", | ||||||
|  | 		"data":    count, | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
| } | } | ||||||
|   | |||||||
| @@ -117,6 +117,15 @@ func init() { | |||||||
| 			Root:       "gpt-3.5-turbo-16k-0613", | 			Root:       "gpt-3.5-turbo-16k-0613", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "gpt-3.5-turbo-instruct", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "openai", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "gpt-3.5-turbo-instruct", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Id:         "gpt-4", | 			Id:         "gpt-4", | ||||||
| 			Object:     "model", | 			Object:     "model", | ||||||
| @@ -343,21 +352,30 @@ func init() { | |||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Id:         "qwen-v1", | 			Id:         "qwen-turbo", | ||||||
| 			Object:     "model", | 			Object:     "model", | ||||||
| 			Created:    1677649963, | 			Created:    1677649963, | ||||||
| 			OwnedBy:    "ali", | 			OwnedBy:    "ali", | ||||||
| 			Permission: permission, | 			Permission: permission, | ||||||
| 			Root:       "qwen-v1", | 			Root:       "qwen-turbo", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Id:         "qwen-plus-v1", | 			Id:         "qwen-plus", | ||||||
| 			Object:     "model", | 			Object:     "model", | ||||||
| 			Created:    1677649963, | 			Created:    1677649963, | ||||||
| 			OwnedBy:    "ali", | 			OwnedBy:    "ali", | ||||||
| 			Permission: permission, | 			Permission: permission, | ||||||
| 			Root:       "qwen-plus-v1", | 			Root:       "qwen-plus", | ||||||
|  | 			Parent:     nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "text-embedding-v1", | ||||||
|  | 			Object:     "model", | ||||||
|  | 			Created:    1677649963, | ||||||
|  | 			OwnedBy:    "ali", | ||||||
|  | 			Permission: permission, | ||||||
|  | 			Root:       "text-embedding-v1", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| @@ -406,12 +424,12 @@ func init() { | |||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			Id:         "360GPT_S2_V9.4", | 			Id:         "hunyuan", | ||||||
| 			Object:     "model", | 			Object:     "model", | ||||||
| 			Created:    1677649963, | 			Created:    1677649963, | ||||||
| 			OwnedBy:    "360", | 			OwnedBy:    "tencent", | ||||||
| 			Permission: permission, | 			Permission: permission, | ||||||
| 			Root:       "360GPT_S2_V9.4", | 			Root:       "hunyuan", | ||||||
| 			Parent:     nil, | 			Parent:     nil, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -46,7 +46,7 @@ func UpdateOption(c *gin.Context) { | |||||||
| 		if option.Value == "true" && common.GitHubClientId == "" { | 		if option.Value == "true" && common.GitHubClientId == "" { | ||||||
| 			c.JSON(http.StatusOK, gin.H{ | 			c.JSON(http.StatusOK, gin.H{ | ||||||
| 				"success": false, | 				"success": false, | ||||||
| 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!", | 				"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", | ||||||
| 			}) | 			}) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|   | |||||||
							
								
								
									
										220
									
								
								controller/relay-aiproxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										220
									
								
								controller/relay-aiproxy.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,220 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 | ||||||
|  |  | ||||||
|  | type AIProxyLibraryRequest struct { | ||||||
|  | 	Model     string `json:"model"` | ||||||
|  | 	Query     string `json:"query"` | ||||||
|  | 	LibraryId string `json:"libraryId"` | ||||||
|  | 	Stream    bool   `json:"stream"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AIProxyLibraryError struct { | ||||||
|  | 	ErrCode int    `json:"errCode"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AIProxyLibraryDocument struct { | ||||||
|  | 	Title string `json:"title"` | ||||||
|  | 	URL   string `json:"url"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AIProxyLibraryResponse struct { | ||||||
|  | 	Success   bool                     `json:"success"` | ||||||
|  | 	Answer    string                   `json:"answer"` | ||||||
|  | 	Documents []AIProxyLibraryDocument `json:"documents"` | ||||||
|  | 	AIProxyLibraryError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AIProxyLibraryStreamResponse struct { | ||||||
|  | 	Content   string                   `json:"content"` | ||||||
|  | 	Finish    bool                     `json:"finish"` | ||||||
|  | 	Model     string                   `json:"model"` | ||||||
|  | 	Documents []AIProxyLibraryDocument `json:"documents"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { | ||||||
|  | 	query := "" | ||||||
|  | 	if len(request.Messages) != 0 { | ||||||
|  | 		query = request.Messages[len(request.Messages)-1].Content | ||||||
|  | 	} | ||||||
|  | 	return &AIProxyLibraryRequest{ | ||||||
|  | 		Model:  request.Model, | ||||||
|  | 		Stream: request.Stream, | ||||||
|  | 		Query:  query, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { | ||||||
|  | 	if len(documents) == 0 { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	content := "\n\n参考文档:\n" | ||||||
|  | 	for i, document := range documents { | ||||||
|  | 		content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL) | ||||||
|  | 	} | ||||||
|  | 	return content | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { | ||||||
|  | 	content := response.Answer + aiProxyDocuments2Markdown(response.Documents) | ||||||
|  | 	choice := OpenAITextResponseChoice{ | ||||||
|  | 		Index: 0, | ||||||
|  | 		Message: Message{ | ||||||
|  | 			Role:    "assistant", | ||||||
|  | 			Content: content, | ||||||
|  | 		}, | ||||||
|  | 		FinishReason: "stop", | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Id:      common.GetUUID(), | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Choices: []OpenAITextResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = aiProxyDocuments2Markdown(documents) | ||||||
|  | 	choice.FinishReason = &stopFinishReason | ||||||
|  | 	return &ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      common.GetUUID(), | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   "", | ||||||
|  | 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { | ||||||
|  | 	var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 	choice.Delta.Content = response.Content | ||||||
|  | 	return &ChatCompletionsStreamResponse{ | ||||||
|  | 		Id:      common.GetUUID(), | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   response.Model, | ||||||
|  | 		Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var usage Usage | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||||
|  | 			return i + 1, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := scanner.Text() | ||||||
|  | 			if len(data) < 5 { // ignore blank line or wrong format | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if data[:5] != "data:" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			data = data[5:] | ||||||
|  | 			dataChan <- data | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
|  | 	var documents []AIProxyLibraryDocument | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			var AIProxyLibraryResponse AIProxyLibraryStreamResponse | ||||||
|  | 			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			if len(AIProxyLibraryResponse.Documents) != 0 { | ||||||
|  | 				documents = AIProxyLibraryResponse.Documents | ||||||
|  | 			} | ||||||
|  | 			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			response := documentsAIProxyLibrary(documents) | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var AIProxyLibraryResponse AIProxyLibraryResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if AIProxyLibraryResponse.ErrCode != 0 { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: AIProxyLibraryResponse.Message, | ||||||
|  | 				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode), | ||||||
|  | 				Code:    AIProxyLibraryResponse.ErrCode, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
| @@ -35,6 +35,29 @@ type AliChatRequest struct { | |||||||
| 	Parameters AliParameters `json:"parameters,omitempty"` | 	Parameters AliParameters `json:"parameters,omitempty"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type AliEmbeddingRequest struct { | ||||||
|  | 	Model string `json:"model"` | ||||||
|  | 	Input struct { | ||||||
|  | 		Texts []string `json:"texts"` | ||||||
|  | 	} `json:"input"` | ||||||
|  | 	Parameters *struct { | ||||||
|  | 		TextType string `json:"text_type,omitempty"` | ||||||
|  | 	} `json:"parameters,omitempty"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliEmbedding struct { | ||||||
|  | 	Embedding []float64 `json:"embedding"` | ||||||
|  | 	TextIndex int       `json:"text_index"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type AliEmbeddingResponse struct { | ||||||
|  | 	Output struct { | ||||||
|  | 		Embeddings []AliEmbedding `json:"embeddings"` | ||||||
|  | 	} `json:"output"` | ||||||
|  | 	Usage AliUsage `json:"usage"` | ||||||
|  | 	AliError | ||||||
|  | } | ||||||
|  |  | ||||||
| type AliError struct { | type AliError struct { | ||||||
| 	Code      string `json:"code"` | 	Code      string `json:"code"` | ||||||
| 	Message   string `json:"message"` | 	Message   string `json:"message"` | ||||||
| @@ -44,6 +67,7 @@ type AliError struct { | |||||||
| type AliUsage struct { | type AliUsage struct { | ||||||
| 	InputTokens  int `json:"input_tokens"` | 	InputTokens  int `json:"input_tokens"` | ||||||
| 	OutputTokens int `json:"output_tokens"` | 	OutputTokens int `json:"output_tokens"` | ||||||
|  | 	TotalTokens  int `json:"total_tokens"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type AliOutput struct { | type AliOutput struct { | ||||||
| @@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { | ||||||
|  | 	return &AliEmbeddingRequest{ | ||||||
|  | 		Model: "text-embedding-v1", | ||||||
|  | 		Input: struct { | ||||||
|  | 			Texts []string `json:"texts"` | ||||||
|  | 		}{ | ||||||
|  | 			Texts: request.ParseInput(), | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var aliResponse AliEmbeddingResponse | ||||||
|  | 	err := json.NewDecoder(resp.Body).Decode(&aliResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if aliResponse.Code != "" { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: aliResponse.Message, | ||||||
|  | 				Type:    aliResponse.Code, | ||||||
|  | 				Param:   aliResponse.RequestId, | ||||||
|  | 				Code:    aliResponse.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||||
|  | 	openAIEmbeddingResponse := OpenAIEmbeddingResponse{ | ||||||
|  | 		Object: "list", | ||||||
|  | 		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), | ||||||
|  | 		Model:  "text-embedding-v1", | ||||||
|  | 		Usage:  Usage{TotalTokens: response.Usage.TotalTokens}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, item := range response.Output.Embeddings { | ||||||
|  | 		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ | ||||||
|  | 			Object:    `embedding`, | ||||||
|  | 			Index:     item.TextIndex, | ||||||
|  | 			Embedding: item.Embedding, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	return &openAIEmbeddingResponse | ||||||
|  | } | ||||||
|  |  | ||||||
| func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { | func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { | ||||||
| 	choice := OpenAITextResponseChoice{ | 	choice := OpenAITextResponseChoice{ | ||||||
| 		Index: 0, | 		Index: 0, | ||||||
|   | |||||||
| @@ -2,7 +2,9 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| @@ -17,6 +19,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
|  |  | ||||||
| 	tokenId := c.GetInt("token_id") | 	tokenId := c.GetInt("token_id") | ||||||
| 	channelType := c.GetInt("channel") | 	channelType := c.GetInt("channel") | ||||||
|  | 	channelId := c.GetInt("channel_id") | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt("id") | ||||||
| 	group := c.GetString("group") | 	group := c.GetString("group") | ||||||
|  |  | ||||||
| @@ -29,6 +32,9 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|  | 	if userQuota-preConsumedQuota < 0 { | ||||||
|  | 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
|  | 	} | ||||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||||
| @@ -91,7 +97,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	} | 	} | ||||||
| 	var audioResponse AudioResponse | 	var audioResponse AudioResponse | ||||||
|  |  | ||||||
| 	defer func() { | 	defer func(ctx context.Context) { | ||||||
| 		go func() { | 		go func() { | ||||||
| 			quota := countTokenText(audioResponse.Text, audioModel) | 			quota := countTokenText(audioResponse.Text, audioModel) | ||||||
| 			quotaDelta := quota - preConsumedQuota | 			quotaDelta := quota - preConsumedQuota | ||||||
| @@ -106,13 +112,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 			if quota != 0 { | 			if quota != 0 { | ||||||
| 				tokenName := c.GetString("token_name") | 				tokenName := c.GetString("token_name") | ||||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
| 				model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent) | 				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) | ||||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||||
| 				channelId := c.GetInt("channel_id") | 				channelId := c.GetInt("channel_id") | ||||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | 				model.UpdateChannelUsedQuota(channelId, quota) | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
| 	}() | 	}(c.Request.Context()) | ||||||
|  |  | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom | |||||||
| } | } | ||||||
|  |  | ||||||
| func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { | func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { | ||||||
| 	baiduEmbeddingRequest := BaiduEmbeddingRequest{ | 	return &BaiduEmbeddingRequest{ | ||||||
| 		Input: nil, | 		Input: request.ParseInput(), | ||||||
| 	} | 	} | ||||||
| 	switch request.Input.(type) { |  | ||||||
| 	case string: |  | ||||||
| 		baiduEmbeddingRequest.Input = []string{request.Input.(string)} |  | ||||||
| 	case []any: |  | ||||||
| 		for _, item := range request.Input.([]any) { |  | ||||||
| 			if str, ok := item.(string); ok { |  | ||||||
| 				baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return &baiduEmbeddingRequest |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { | func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @@ -18,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
|  |  | ||||||
| 	tokenId := c.GetInt("token_id") | 	tokenId := c.GetInt("token_id") | ||||||
| 	channelType := c.GetInt("channel") | 	channelType := c.GetInt("channel") | ||||||
|  | 	channelId := c.GetInt("channel_id") | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt("id") | ||||||
| 	consumeQuota := c.GetBool("consume_quota") | 	consumeQuota := c.GetBool("consume_quota") | ||||||
| 	group := c.GetString("group") | 	group := c.GetString("group") | ||||||
| @@ -97,7 +99,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	quota := int(ratio*sizeRatio*1000) * imageRequest.N | 	quota := int(ratio*sizeRatio*1000) * imageRequest.N | ||||||
|  |  | ||||||
| 	if consumeQuota && userQuota-quota < 0 { | 	if consumeQuota && userQuota-quota < 0 { | ||||||
| 		return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden) | 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||||||
| @@ -124,7 +126,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 	} | 	} | ||||||
| 	var textResponse ImageResponse | 	var textResponse ImageResponse | ||||||
|  |  | ||||||
| 	defer func() { | 	defer func(ctx context.Context) { | ||||||
| 		if consumeQuota { | 		if consumeQuota { | ||||||
| 			err := model.PostConsumeTokenQuota(tokenId, quota) | 			err := model.PostConsumeTokenQuota(tokenId, quota) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -137,13 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode | |||||||
| 			if quota != 0 { | 			if quota != 0 { | ||||||
| 				tokenName := c.GetString("token_name") | 				tokenName := c.GetString("token_name") | ||||||
| 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
| 				model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent) | 				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) | ||||||
| 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||||
| 				channelId := c.GetInt("channel_id") | 				channelId := c.GetInt("channel_id") | ||||||
| 				model.UpdateChannelUsedQuota(channelId, quota) | 				model.UpdateChannelUsedQuota(channelId, quota) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	}(c.Request.Context()) | ||||||
|  |  | ||||||
| 	if consumeQuota { | 	if consumeQuota { | ||||||
| 		responseBody, err := io.ReadAll(resp.Body) | 		responseBody, err := io.ReadAll(resp.Body) | ||||||
|   | |||||||
							
								
								
									
										287
									
								
								controller/relay-tencent.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										287
									
								
								controller/relay-tencent.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,287 @@ | |||||||
|  | package controller | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"crypto/hmac" | ||||||
|  | 	"crypto/sha1" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"sort" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // https://cloud.tencent.com/document/product/1729/97732 | ||||||
|  |  | ||||||
|  | type TencentMessage struct { | ||||||
|  | 	Role    string `json:"role"` | ||||||
|  | 	Content string `json:"content"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TencentChatRequest struct { | ||||||
|  | 	AppId    int64  `json:"app_id"`    // 腾讯云账号的 APPID | ||||||
|  | 	SecretId string `json:"secret_id"` // 官网 SecretId | ||||||
|  | 	// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 | ||||||
|  | 	// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 | ||||||
|  | 	Timestamp int64 `json:"timestamp"` | ||||||
|  | 	// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, | ||||||
|  | 	// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 | ||||||
|  | 	Expired int64  `json:"expired"` | ||||||
|  | 	QueryID string `json:"query_id"` //请求 Id,用于问题排查 | ||||||
|  | 	// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 | ||||||
|  | 	// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 | ||||||
|  | 	// 建议该参数和 top_p 只设置1个,不要同时更改 top_p | ||||||
|  | 	Temperature float64 `json:"temperature"` | ||||||
|  | 	// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 | ||||||
|  | 	// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 | ||||||
|  | 	// 建议该参数和 temperature 只设置1个,不要同时更改 | ||||||
|  | 	TopP float64 `json:"top_p"` | ||||||
|  | 	// Stream 0:同步,1:流式 (默认,协议:SSE) | ||||||
|  | 	// 同步请求超时:60s,如果内容较长建议使用流式 | ||||||
|  | 	Stream int `json:"stream"` | ||||||
|  | 	// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 | ||||||
|  | 	// 输入 content 总数最大支持 3000 token。 | ||||||
|  | 	Messages []TencentMessage `json:"messages"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TencentError struct { | ||||||
|  | 	Code    int    `json:"code"` | ||||||
|  | 	Message string `json:"message"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TencentUsage struct { | ||||||
|  | 	InputTokens  int `json:"input_tokens"` | ||||||
|  | 	OutputTokens int `json:"output_tokens"` | ||||||
|  | 	TotalTokens  int `json:"total_tokens"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TencentResponseChoices struct { | ||||||
|  | 	FinishReason string         `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 | ||||||
|  | 	Messages     TencentMessage `json:"messages,omitempty"`      // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 | ||||||
|  | 	Delta        TencentMessage `json:"delta,omitempty"`         // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type TencentChatResponse struct { | ||||||
|  | 	Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 | ||||||
|  | 	Created string                   `json:"created,omitempty"` // unix 时间戳的字符串 | ||||||
|  | 	Id      string                   `json:"id,omitempty"`      // 会话 id | ||||||
|  | 	Usage   Usage                    `json:"usage,omitempty"`   // token 数量 | ||||||
|  | 	Error   TencentError             `json:"error,omitempty"`   // 错误信息 注意:此字段可能返回 null,表示取不到有效值 | ||||||
|  | 	Note    string                   `json:"note,omitempty"`    // 注释 | ||||||
|  | 	ReqID   string                   `json:"req_id,omitempty"`  // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { | ||||||
|  | 	messages := make([]TencentMessage, 0, len(request.Messages)) | ||||||
|  | 	for i := 0; i < len(request.Messages); i++ { | ||||||
|  | 		message := request.Messages[i] | ||||||
|  | 		if message.Role == "system" { | ||||||
|  | 			messages = append(messages, TencentMessage{ | ||||||
|  | 				Role:    "user", | ||||||
|  | 				Content: message.Content, | ||||||
|  | 			}) | ||||||
|  | 			messages = append(messages, TencentMessage{ | ||||||
|  | 				Role:    "assistant", | ||||||
|  | 				Content: "Okay", | ||||||
|  | 			}) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		messages = append(messages, TencentMessage{ | ||||||
|  | 			Content: message.Content, | ||||||
|  | 			Role:    message.Role, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	stream := 0 | ||||||
|  | 	if request.Stream { | ||||||
|  | 		stream = 1 | ||||||
|  | 	} | ||||||
|  | 	return &TencentChatRequest{ | ||||||
|  | 		Timestamp:   common.GetTimestamp(), | ||||||
|  | 		Expired:     common.GetTimestamp() + 24*60*60, | ||||||
|  | 		QueryID:     common.GetUUID(), | ||||||
|  | 		Temperature: request.Temperature, | ||||||
|  | 		TopP:        request.TopP, | ||||||
|  | 		Stream:      stream, | ||||||
|  | 		Messages:    messages, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { | ||||||
|  | 	fullTextResponse := OpenAITextResponse{ | ||||||
|  | 		Object:  "chat.completion", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Usage:   response.Usage, | ||||||
|  | 	} | ||||||
|  | 	if len(response.Choices) > 0 { | ||||||
|  | 		choice := OpenAITextResponseChoice{ | ||||||
|  | 			Index: 0, | ||||||
|  | 			Message: Message{ | ||||||
|  | 				Role:    "assistant", | ||||||
|  | 				Content: response.Choices[0].Messages.Content, | ||||||
|  | 			}, | ||||||
|  | 			FinishReason: response.Choices[0].FinishReason, | ||||||
|  | 		} | ||||||
|  | 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) | ||||||
|  | 	} | ||||||
|  | 	return &fullTextResponse | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse { | ||||||
|  | 	response := ChatCompletionsStreamResponse{ | ||||||
|  | 		Object:  "chat.completion.chunk", | ||||||
|  | 		Created: common.GetTimestamp(), | ||||||
|  | 		Model:   "tencent-hunyuan", | ||||||
|  | 	} | ||||||
|  | 	if len(TencentResponse.Choices) > 0 { | ||||||
|  | 		var choice ChatCompletionsStreamResponseChoice | ||||||
|  | 		choice.Delta.Content = TencentResponse.Choices[0].Delta.Content | ||||||
|  | 		if TencentResponse.Choices[0].FinishReason == "stop" { | ||||||
|  | 			choice.FinishReason = &stopFinishReason | ||||||
|  | 		} | ||||||
|  | 		response.Choices = append(response.Choices, choice) | ||||||
|  | 	} | ||||||
|  | 	return &response | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | ||||||
|  | 	var responseText string | ||||||
|  | 	scanner := bufio.NewScanner(resp.Body) | ||||||
|  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||||
|  | 		if atEOF && len(data) == 0 { | ||||||
|  | 			return 0, nil, nil | ||||||
|  | 		} | ||||||
|  | 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||||
|  | 			return i + 1, data[0:i], nil | ||||||
|  | 		} | ||||||
|  | 		if atEOF { | ||||||
|  | 			return len(data), data, nil | ||||||
|  | 		} | ||||||
|  | 		return 0, nil, nil | ||||||
|  | 	}) | ||||||
|  | 	dataChan := make(chan string) | ||||||
|  | 	stopChan := make(chan bool) | ||||||
|  | 	go func() { | ||||||
|  | 		for scanner.Scan() { | ||||||
|  | 			data := scanner.Text() | ||||||
|  | 			if len(data) < 5 { // ignore blank line or wrong format | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if data[:5] != "data:" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			data = data[5:] | ||||||
|  | 			dataChan <- data | ||||||
|  | 		} | ||||||
|  | 		stopChan <- true | ||||||
|  | 	}() | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
|  | 	c.Stream(func(w io.Writer) bool { | ||||||
|  | 		select { | ||||||
|  | 		case data := <-dataChan: | ||||||
|  | 			var TencentResponse TencentChatResponse | ||||||
|  | 			err := json.Unmarshal([]byte(data), &TencentResponse) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error unmarshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			response := streamResponseTencent2OpenAI(&TencentResponse) | ||||||
|  | 			if len(response.Choices) != 0 { | ||||||
|  | 				responseText += response.Choices[0].Delta.Content | ||||||
|  | 			} | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	err := resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" | ||||||
|  | 	} | ||||||
|  | 	return nil, responseText | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	var TencentResponse TencentChatResponse | ||||||
|  | 	responseBody, err := io.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = resp.Body.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	err = json.Unmarshal(responseBody, &TencentResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	if TencentResponse.Error.Code != 0 { | ||||||
|  | 		return &OpenAIErrorWithStatusCode{ | ||||||
|  | 			OpenAIError: OpenAIError{ | ||||||
|  | 				Message: TencentResponse.Error.Message, | ||||||
|  | 				Code:    TencentResponse.Error.Code, | ||||||
|  | 			}, | ||||||
|  | 			StatusCode: resp.StatusCode, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | 	fullTextResponse := responseTencent2OpenAI(&TencentResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	c.Writer.WriteHeader(resp.StatusCode) | ||||||
|  | 	_, err = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &fullTextResponse.Usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { | ||||||
|  | 	parts := strings.Split(config, "|") | ||||||
|  | 	if len(parts) != 3 { | ||||||
|  | 		err = errors.New("invalid tencent config") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	appId, err = strconv.ParseInt(parts[0], 10, 64) | ||||||
|  | 	secretId = parts[1] | ||||||
|  | 	secretKey = parts[2] | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getTencentSign(req TencentChatRequest, secretKey string) string { | ||||||
|  | 	params := make([]string, 0) | ||||||
|  | 	params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) | ||||||
|  | 	params = append(params, "secret_id="+req.SecretId) | ||||||
|  | 	params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) | ||||||
|  | 	params = append(params, "query_id="+req.QueryID) | ||||||
|  | 	params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) | ||||||
|  | 	params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) | ||||||
|  | 	params = append(params, "stream="+strconv.Itoa(req.Stream)) | ||||||
|  | 	params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) | ||||||
|  |  | ||||||
|  | 	var messageStr string | ||||||
|  | 	for _, msg := range req.Messages { | ||||||
|  | 		messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) | ||||||
|  | 	} | ||||||
|  | 	messageStr = strings.TrimSuffix(messageStr, ",") | ||||||
|  | 	params = append(params, "messages=["+messageStr+"]") | ||||||
|  |  | ||||||
|  | 	sort.Sort(sort.StringSlice(params)) | ||||||
|  | 	url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") | ||||||
|  | 	mac := hmac.New(sha1.New, []byte(secretKey)) | ||||||
|  | 	signURL := url | ||||||
|  | 	mac.Write([]byte(signURL)) | ||||||
|  | 	sign := mac.Sum([]byte(nil)) | ||||||
|  | 	return base64.StdEncoding.EncodeToString(sign) | ||||||
|  | } | ||||||
| @@ -2,6 +2,7 @@ package controller | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @@ -22,6 +23,8 @@ const ( | |||||||
| 	APITypeZhipu | 	APITypeZhipu | ||||||
| 	APITypeAli | 	APITypeAli | ||||||
| 	APITypeXunfei | 	APITypeXunfei | ||||||
|  | 	APITypeAIProxyLibrary | ||||||
|  | 	APITypeTencent | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var httpClient *http.Client | var httpClient *http.Client | ||||||
| @@ -36,6 +39,7 @@ func init() { | |||||||
|  |  | ||||||
| func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||||||
| 	channelType := c.GetInt("channel") | 	channelType := c.GetInt("channel") | ||||||
|  | 	channelId := c.GetInt("channel_id") | ||||||
| 	tokenId := c.GetInt("token_id") | 	tokenId := c.GetInt("token_id") | ||||||
| 	userId := c.GetInt("id") | 	userId := c.GetInt("id") | ||||||
| 	consumeQuota := c.GetBool("consume_quota") | 	consumeQuota := c.GetBool("consume_quota") | ||||||
| @@ -104,6 +108,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		apiType = APITypeAli | 		apiType = APITypeAli | ||||||
| 	case common.ChannelTypeXunfei: | 	case common.ChannelTypeXunfei: | ||||||
| 		apiType = APITypeXunfei | 		apiType = APITypeXunfei | ||||||
|  | 	case common.ChannelTypeAIProxyLibrary: | ||||||
|  | 		apiType = APITypeAIProxyLibrary | ||||||
|  | 	case common.ChannelTypeTencent: | ||||||
|  | 		apiType = APITypeTencent | ||||||
| 	} | 	} | ||||||
| 	baseURL := common.ChannelBaseURLs[channelType] | 	baseURL := common.ChannelBaseURLs[channelType] | ||||||
| 	requestURL := c.Request.URL.String() | 	requestURL := c.Request.URL.String() | ||||||
| @@ -111,6 +119,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		baseURL = c.GetString("base_url") | 		baseURL = c.GetString("base_url") | ||||||
| 	} | 	} | ||||||
| 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||||||
|  | 	if channelType == common.ChannelTypeOpenAI { | ||||||
|  | 		if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { | ||||||
|  | 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	switch apiType { | 	switch apiType { | ||||||
| 	case APITypeOpenAI: | 	case APITypeOpenAI: | ||||||
| 		if channelType == common.ChannelTypeAzure { | 		if channelType == common.ChannelTypeAzure { | ||||||
| @@ -171,6 +184,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) | ||||||
| 	case APITypeAli: | 	case APITypeAli: | ||||||
| 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" | ||||||
|  | 		if relayMode == RelayModeEmbeddings { | ||||||
|  | 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" | ||||||
|  | 		} | ||||||
|  | 	case APITypeTencent: | ||||||
|  | 		fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" | ||||||
|  | 	case APITypeAIProxyLibrary: | ||||||
|  | 		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) | ||||||
| 	} | 	} | ||||||
| 	var promptTokens int | 	var promptTokens int | ||||||
| 	var completionTokens int | 	var completionTokens int | ||||||
| @@ -194,6 +214,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
|  | 	if userQuota-preConsumedQuota < 0 { | ||||||
|  | 		return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||||
|  | 	} | ||||||
| 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | 	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | 		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) | ||||||
| @@ -202,6 +225,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		// in this case, we do not pre-consume quota | 		// in this case, we do not pre-consume quota | ||||||
| 		// because the user has enough quota | 		// because the user has enough quota | ||||||
| 		preConsumedQuota = 0 | 		preConsumedQuota = 0 | ||||||
|  | 		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) | ||||||
| 	} | 	} | ||||||
| 	if consumeQuota && preConsumedQuota > 0 { | 	if consumeQuota && preConsumedQuota > 0 { | ||||||
| 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) | ||||||
| @@ -257,8 +281,41 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		} | 		} | ||||||
| 		requestBody = bytes.NewBuffer(jsonStr) | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
| 	case APITypeAli: | 	case APITypeAli: | ||||||
| 		aliRequest := requestOpenAI2Ali(textRequest) | 		var jsonStr []byte | ||||||
| 		jsonStr, err := json.Marshal(aliRequest) | 		var err error | ||||||
|  | 		switch relayMode { | ||||||
|  | 		case RelayModeEmbeddings: | ||||||
|  | 			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) | ||||||
|  | 			jsonStr, err = json.Marshal(aliEmbeddingRequest) | ||||||
|  | 		default: | ||||||
|  | 			aliRequest := requestOpenAI2Ali(textRequest) | ||||||
|  | 			jsonStr, err = json.Marshal(aliRequest) | ||||||
|  | 		} | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
|  | 	case APITypeTencent: | ||||||
|  | 		apiKey := c.Request.Header.Get("Authorization") | ||||||
|  | 		apiKey = strings.TrimPrefix(apiKey, "Bearer ") | ||||||
|  | 		appId, secretId, secretKey, err := parseTencentConfig(apiKey) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		tencentRequest := requestOpenAI2Tencent(textRequest) | ||||||
|  | 		tencentRequest.AppId = appId | ||||||
|  | 		tencentRequest.SecretId = secretId | ||||||
|  | 		jsonStr, err := json.Marshal(tencentRequest) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
|  | 		} | ||||||
|  | 		sign := getTencentSign(*tencentRequest, secretKey) | ||||||
|  | 		c.Request.Header.Set("Authorization", sign) | ||||||
|  | 		requestBody = bytes.NewBuffer(jsonStr) | ||||||
|  | 	case APITypeAIProxyLibrary: | ||||||
|  | 		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) | ||||||
|  | 		aiProxyLibraryRequest.LibraryId = c.GetString("library_id") | ||||||
|  | 		jsonStr, err := json.Marshal(aiProxyLibraryRequest) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||||||
| 		} | 		} | ||||||
| @@ -302,6 +359,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			if textRequest.Stream { | 			if textRequest.Stream { | ||||||
| 				req.Header.Set("X-DashScope-SSE", "enable") | 				req.Header.Set("X-DashScope-SSE", "enable") | ||||||
| 			} | 			} | ||||||
|  | 		case APITypeTencent: | ||||||
|  | 			req.Header.Set("Authorization", apiKey) | ||||||
|  | 		default: | ||||||
|  | 			req.Header.Set("Authorization", "Bearer "+apiKey) | ||||||
| 		} | 		} | ||||||
| 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||||||
| 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | 		req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||||||
| @@ -321,15 +382,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") | ||||||
|  |  | ||||||
| 		if resp.StatusCode != http.StatusOK { | 		if resp.StatusCode != http.StatusOK { | ||||||
|  | 			if preConsumedQuota != 0 { | ||||||
|  | 				go func(ctx context.Context) { | ||||||
|  | 					// return pre-consumed quota | ||||||
|  | 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) | ||||||
|  | 					if err != nil { | ||||||
|  | 						common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) | ||||||
|  | 					} | ||||||
|  | 				}(c.Request.Context()) | ||||||
|  | 			} | ||||||
| 			return relayErrorHandler(resp) | 			return relayErrorHandler(resp) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var textResponse TextResponse | 	var textResponse TextResponse | ||||||
| 	tokenName := c.GetString("token_name") | 	tokenName := c.GetString("token_name") | ||||||
| 	channelId := c.GetInt("channel_id") |  | ||||||
|  |  | ||||||
| 	defer func() { | 	defer func(ctx context.Context) { | ||||||
| 		// c.Writer.Flush() | 		// c.Writer.Flush() | ||||||
| 		go func() { | 		go func() { | ||||||
| 			if consumeQuota { | 			if consumeQuota { | ||||||
| @@ -352,22 +421,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 				quotaDelta := quota - preConsumedQuota | 				quotaDelta := quota - preConsumedQuota | ||||||
| 				err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | 				err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					common.SysError("error consuming token remain quota: " + err.Error()) | 					common.LogError(ctx, "error consuming token remain quota: "+err.Error()) | ||||||
| 				} | 				} | ||||||
| 				err = model.CacheUpdateUserQuota(userId) | 				err = model.CacheUpdateUserQuota(userId) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					common.SysError("error update user quota cache: " + err.Error()) | 					common.LogError(ctx, "error update user quota cache: "+err.Error()) | ||||||
| 				} | 				} | ||||||
| 				if quota != 0 { | 				if quota != 0 { | ||||||
| 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||||
| 					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | 					model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) | ||||||
| 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||||||
|  |  | ||||||
| 					model.UpdateChannelUsedQuota(channelId, quota) | 					model.UpdateChannelUsedQuota(channelId, quota) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
| 	}() | 	}(c.Request.Context()) | ||||||
| 	switch apiType { | 	switch apiType { | ||||||
| 	case APITypeOpenAI: | 	case APITypeOpenAI: | ||||||
| 		if isStream { | 		if isStream { | ||||||
| @@ -488,7 +556,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			} | 			} | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			err, usage := aliHandler(c, resp) | 			var err *OpenAIErrorWithStatusCode | ||||||
|  | 			var usage *Usage | ||||||
|  | 			switch relayMode { | ||||||
|  | 			case RelayModeEmbeddings: | ||||||
|  | 				err, usage = aliEmbeddingHandler(c, resp) | ||||||
|  | 			default: | ||||||
|  | 				err, usage = aliHandler(c, resp) | ||||||
|  | 			} | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -498,14 +573,29 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			return nil | 			return nil | ||||||
| 		} | 		} | ||||||
| 	case APITypeXunfei: | 	case APITypeXunfei: | ||||||
|  | 		auth := c.Request.Header.Get("Authorization") | ||||||
|  | 		auth = strings.TrimPrefix(auth, "Bearer ") | ||||||
|  | 		splits := strings.Split(auth, "|") | ||||||
|  | 		if len(splits) != 3 { | ||||||
|  | 			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | ||||||
|  | 		} | ||||||
|  | 		var err *OpenAIErrorWithStatusCode | ||||||
|  | 		var usage *Usage | ||||||
| 		if isStream { | 		if isStream { | ||||||
| 			auth := c.Request.Header.Get("Authorization") | 			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||||
| 			auth = strings.TrimPrefix(auth, "Bearer ") | 		} else { | ||||||
| 			splits := strings.Split(auth, "|") | 			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) | ||||||
| 			if len(splits) != 3 { | 		} | ||||||
| 				return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) | 		if err != nil { | ||||||
| 			} | 			return err | ||||||
| 			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) | 		} | ||||||
|  | 		if usage != nil { | ||||||
|  | 			textResponse.Usage = *usage | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	case APITypeAIProxyLibrary: | ||||||
|  | 		if isStream { | ||||||
|  | 			err, usage := aiProxyLibraryStreamHandler(c, resp) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| @@ -514,7 +604,33 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | |||||||
| 			} | 			} | ||||||
| 			return nil | 			return nil | ||||||
| 		} else { | 		} else { | ||||||
| 			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) | 			err, usage := aiProxyLibraryHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 	case APITypeTencent: | ||||||
|  | 		if isStream { | ||||||
|  | 			err, responseText := tencentStreamHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			textResponse.Usage.PromptTokens = promptTokens | ||||||
|  | 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) | ||||||
|  | 			return nil | ||||||
|  | 		} else { | ||||||
|  | 			err, usage := tencentHandler(c, resp) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if usage != nil { | ||||||
|  | 				textResponse.Usage = *usage | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
| 		} | 		} | ||||||
| 	default: | 	default: | ||||||
| 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | ||||||
|   | |||||||
| @@ -9,44 +9,53 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var stopFinishReason = "stop" | var stopFinishReason = "stop" | ||||||
|  |  | ||||||
|  | // tokenEncoderMap won't grow after initialization | ||||||
| var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} | ||||||
|  | var defaultTokenEncoder *tiktoken.Tiktoken | ||||||
|  |  | ||||||
| func InitTokenEncoders() { | func InitTokenEncoders() { | ||||||
| 	common.SysLog("initializing token encoders") | 	common.SysLog("initializing token encoders") | ||||||
| 	fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | 	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error())) | 		common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) | ||||||
|  | 	} | ||||||
|  | 	defaultTokenEncoder = gpt35TokenEncoder | ||||||
|  | 	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) | ||||||
| 	} | 	} | ||||||
| 	for model, _ := range common.ModelRatio { | 	for model, _ := range common.ModelRatio { | ||||||
| 		tokenEncoder, err := tiktoken.EncodingForModel(model) | 		if strings.HasPrefix(model, "gpt-3.5") { | ||||||
| 		if err != nil { | 			tokenEncoderMap[model] = gpt35TokenEncoder | ||||||
| 			common.SysError(fmt.Sprintf("using fallback encoder for model %s", model)) | 		} else if strings.HasPrefix(model, "gpt-4") { | ||||||
| 			tokenEncoderMap[model] = fallbackTokenEncoder | 			tokenEncoderMap[model] = gpt4TokenEncoder | ||||||
| 			continue | 		} else { | ||||||
|  | 			tokenEncoderMap[model] = nil | ||||||
| 		} | 		} | ||||||
| 		tokenEncoderMap[model] = tokenEncoder |  | ||||||
| 	} | 	} | ||||||
| 	common.SysLog("token encoders initialized") | 	common.SysLog("token encoders initialized") | ||||||
| } | } | ||||||
|  |  | ||||||
| func getTokenEncoder(model string) *tiktoken.Tiktoken { | func getTokenEncoder(model string) *tiktoken.Tiktoken { | ||||||
| 	if tokenEncoder, ok := tokenEncoderMap[model]; ok { | 	tokenEncoder, ok := tokenEncoderMap[model] | ||||||
|  | 	if ok && tokenEncoder != nil { | ||||||
| 		return tokenEncoder | 		return tokenEncoder | ||||||
| 	} | 	} | ||||||
| 	tokenEncoder, err := tiktoken.EncodingForModel(model) | 	if ok { | ||||||
| 	if err != nil { | 		tokenEncoder, err := tiktoken.EncodingForModel(model) | ||||||
| 		common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) |  | ||||||
| 		tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo") |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error())) | 			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) | ||||||
|  | 			tokenEncoder = defaultTokenEncoder | ||||||
| 		} | 		} | ||||||
|  | 		tokenEncoderMap[model] = tokenEncoder | ||||||
|  | 		return tokenEncoder | ||||||
| 	} | 	} | ||||||
| 	tokenEncoderMap[model] = tokenEncoder | 	return defaultTokenEncoder | ||||||
| 	return tokenEncoder |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { | ||||||
| @@ -146,7 +155,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr | |||||||
| 		StatusCode: resp.StatusCode, | 		StatusCode: resp.StatusCode, | ||||||
| 		OpenAIError: OpenAIError{ | 		OpenAIError: OpenAIError{ | ||||||
| 			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), | 			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), | ||||||
| 			Type:    "one_api_error", | 			Type:    "upstream_error", | ||||||
| 			Code:    "bad_response_status_code", | 			Code:    "bad_response_status_code", | ||||||
| 			Param:   strconv.Itoa(resp.StatusCode), | 			Param:   strconv.Itoa(resp.StatusCode), | ||||||
| 		}, | 		}, | ||||||
|   | |||||||
| @@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { | |||||||
| 			Role:    "assistant", | 			Role:    "assistant", | ||||||
| 			Content: response.Payload.Choices.Text[0].Content, | 			Content: response.Payload.Choices.Text[0].Content, | ||||||
| 		}, | 		}, | ||||||
|  | 		FinishReason: stopFinishReason, | ||||||
| 	} | 	} | ||||||
| 	fullTextResponse := OpenAITextResponse{ | 	fullTextResponse := OpenAITextResponse{ | ||||||
| 		Object:  "chat.completion", | 		Object:  "chat.completion", | ||||||
| @@ -177,33 +178,85 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { | |||||||
| } | } | ||||||
|  |  | ||||||
| func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | ||||||
|  | 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||||
|  | 	} | ||||||
|  | 	setEventStreamHeaders(c) | ||||||
| 	var usage Usage | 	var usage Usage | ||||||
| 	query := c.Request.URL.Query() | 	c.Stream(func(w io.Writer) bool { | ||||||
| 	apiVersion := query.Get("api-version") | 		select { | ||||||
| 	if apiVersion == "" { | 		case xunfeiResponse := <-dataChan: | ||||||
| 		apiVersion = c.GetString("api_version") | 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | ||||||
|  | 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | ||||||
|  | 			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens | ||||||
|  | 			response := streamResponseXunfei2OpenAI(&xunfeiResponse) | ||||||
|  | 			jsonResponse, err := json.Marshal(response) | ||||||
|  | 			if err != nil { | ||||||
|  | 				common.SysError("error marshalling stream response: " + err.Error()) | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||||
|  | 			return true | ||||||
|  | 		case <-stopChan: | ||||||
|  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { | ||||||
|  | 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) | ||||||
|  | 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	if apiVersion == "" { | 	var usage Usage | ||||||
| 		apiVersion = "v1.1" | 	var content string | ||||||
| 		common.SysLog("api_version not found, use default: " + apiVersion) | 	var xunfeiResponse XunfeiChatResponse | ||||||
|  | 	stop := false | ||||||
|  | 	for !stop { | ||||||
|  | 		select { | ||||||
|  | 		case xunfeiResponse = <-dataChan: | ||||||
|  | 			if len(xunfeiResponse.Payload.Choices.Text) == 0 { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			content += xunfeiResponse.Payload.Choices.Text[0].Content | ||||||
|  | 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens | ||||||
|  | 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens | ||||||
|  | 			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens | ||||||
|  | 		case stop = <-stopChan: | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 	domain := "general" |  | ||||||
| 	if apiVersion == "v2.1" { | 	xunfeiResponse.Payload.Choices.Text[0].Content = content | ||||||
| 		domain = "generalv2" |  | ||||||
|  | 	response := responseXunfei2OpenAI(&xunfeiResponse) | ||||||
|  | 	jsonResponse, err := json.Marshal(response) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||||
| 	} | 	} | ||||||
| 	hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion) | 	c.Writer.Header().Set("Content-Type", "application/json") | ||||||
|  | 	_, _ = c.Writer.Write(jsonResponse) | ||||||
|  | 	return nil, &usage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { | ||||||
| 	d := websocket.Dialer{ | 	d := websocket.Dialer{ | ||||||
| 		HandshakeTimeout: 5 * time.Second, | 		HandshakeTimeout: 5 * time.Second, | ||||||
| 	} | 	} | ||||||
| 	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) | 	conn, resp, err := d.Dial(authUrl, nil) | ||||||
| 	if err != nil || resp.StatusCode != 101 { | 	if err != nil || resp.StatusCode != 101 { | ||||||
| 		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
| 	data := requestOpenAI2Xunfei(textRequest, appId, domain) | 	data := requestOpenAI2Xunfei(textRequest, appId, domain) | ||||||
| 	err = conn.WriteJSON(data) | 	err = conn.WriteJSON(data) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	dataChan := make(chan XunfeiChatResponse) | 	dataChan := make(chan XunfeiChatResponse) | ||||||
| 	stopChan := make(chan bool) | 	stopChan := make(chan bool) | ||||||
| 	go func() { | 	go func() { | ||||||
| @@ -230,61 +283,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId | |||||||
| 		} | 		} | ||||||
| 		stopChan <- true | 		stopChan <- true | ||||||
| 	}() | 	}() | ||||||
| 	setEventStreamHeaders(c) |  | ||||||
| 	c.Stream(func(w io.Writer) bool { | 	return dataChan, stopChan, nil | ||||||
| 		select { |  | ||||||
| 		case xunfeiResponse := <-dataChan: |  | ||||||
| 			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens |  | ||||||
| 			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens |  | ||||||
| 			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens |  | ||||||
| 			response := streamResponseXunfei2OpenAI(&xunfeiResponse) |  | ||||||
| 			jsonResponse, err := json.Marshal(response) |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("error marshalling stream response: " + err.Error()) |  | ||||||
| 				return true |  | ||||||
| 			} |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) |  | ||||||
| 			return true |  | ||||||
| 		case <-stopChan: |  | ||||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 	return nil, &usage |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) { | ||||||
| 	var xunfeiResponse XunfeiChatResponse | 	query := c.Request.URL.Query() | ||||||
| 	responseBody, err := io.ReadAll(resp.Body) | 	apiVersion := query.Get("api-version") | ||||||
| 	if err != nil { | 	if apiVersion == "" { | ||||||
| 		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | 		apiVersion = c.GetString("api_version") | ||||||
| 	} | 	} | ||||||
| 	err = resp.Body.Close() | 	if apiVersion == "" { | ||||||
| 	if err != nil { | 		apiVersion = "v1.1" | ||||||
| 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | 		common.SysLog("api_version not found, use default: " + apiVersion) | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(responseBody, &xunfeiResponse) | 	domain := "general" | ||||||
| 	if err != nil { | 	if apiVersion == "v2.1" { | ||||||
| 		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | 		domain = "generalv2" | ||||||
| 	} | 	} | ||||||
| 	if xunfeiResponse.Header.Code != 0 { | 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | ||||||
| 		return &OpenAIErrorWithStatusCode{ | 	return domain, authUrl | ||||||
| 			OpenAIError: OpenAIError{ |  | ||||||
| 				Message: xunfeiResponse.Header.Message, |  | ||||||
| 				Type:    "xunfei_error", |  | ||||||
| 				Param:   "", |  | ||||||
| 				Code:    xunfeiResponse.Header.Code, |  | ||||||
| 			}, |  | ||||||
| 			StatusCode: resp.StatusCode, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse) |  | ||||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil |  | ||||||
| 	} |  | ||||||
| 	c.Writer.Header().Set("Content-Type", "application/json") |  | ||||||
| 	c.Writer.WriteHeader(resp.StatusCode) |  | ||||||
| 	_, err = c.Writer.Write(jsonResponse) |  | ||||||
| 	return nil, &fullTextResponse.Usage |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -44,6 +44,25 @@ type GeneralOpenAIRequest struct { | |||||||
| 	Functions   any       `json:"functions,omitempty"` | 	Functions   any       `json:"functions,omitempty"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (r GeneralOpenAIRequest) ParseInput() []string { | ||||||
|  | 	if r.Input == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	var input []string | ||||||
|  | 	switch r.Input.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		input = []string{r.Input.(string)} | ||||||
|  | 	case []any: | ||||||
|  | 		input = make([]string, 0, len(r.Input.([]any))) | ||||||
|  | 		for _, item := range r.Input.([]any) { | ||||||
|  | 			if str, ok := item.(string); ok { | ||||||
|  | 				input = append(input, str) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return input | ||||||
|  | } | ||||||
|  |  | ||||||
| type ChatRequest struct { | type ChatRequest struct { | ||||||
| 	Model     string    `json:"model"` | 	Model     string    `json:"model"` | ||||||
| 	Messages  []Message `json:"messages"` | 	Messages  []Message `json:"messages"` | ||||||
| @@ -177,6 +196,7 @@ func Relay(c *gin.Context) { | |||||||
| 		err = relayTextHelper(c, relayMode) | 		err = relayTextHelper(c, relayMode) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		requestId := c.GetString(common.RequestIdKey) | ||||||
| 		retryTimesStr := c.Query("retry") | 		retryTimesStr := c.Query("retry") | ||||||
| 		retryTimes, _ := strconv.Atoi(retryTimesStr) | 		retryTimes, _ := strconv.Atoi(retryTimesStr) | ||||||
| 		if retryTimesStr == "" { | 		if retryTimesStr == "" { | ||||||
| @@ -188,12 +208,13 @@ func Relay(c *gin.Context) { | |||||||
| 			if err.StatusCode == http.StatusTooManyRequests { | 			if err.StatusCode == http.StatusTooManyRequests { | ||||||
| 				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" | 				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" | ||||||
| 			} | 			} | ||||||
|  | 			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) | ||||||
| 			c.JSON(err.StatusCode, gin.H{ | 			c.JSON(err.StatusCode, gin.H{ | ||||||
| 				"error": err.OpenAIError, | 				"error": err.OpenAIError, | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 		channelId := c.GetInt("channel_id") | 		channelId := c.GetInt("channel_id") | ||||||
| 		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) | ||||||
| 		// https://platform.openai.com/docs/guides/error-codes/api-errors | 		// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||||
| 		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { | 		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { | ||||||
| 			channelId := c.GetInt("channel_id") | 			channelId := c.GetInt("channel_id") | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								go.mod
									
									
									
									
									
								
							| @@ -15,8 +15,9 @@ require ( | |||||||
| 	github.com/google/uuid v1.3.0 | 	github.com/google/uuid v1.3.0 | ||||||
| 	github.com/gorilla/websocket v1.5.0 | 	github.com/gorilla/websocket v1.5.0 | ||||||
| 	github.com/pkoukk/tiktoken-go v0.1.5 | 	github.com/pkoukk/tiktoken-go v0.1.5 | ||||||
| 	golang.org/x/crypto v0.9.0 | 	golang.org/x/crypto v0.14.0 | ||||||
| 	gorm.io/driver/mysql v1.4.3 | 	gorm.io/driver/mysql v1.4.3 | ||||||
|  | 	gorm.io/driver/postgres v1.5.2 | ||||||
| 	gorm.io/driver/sqlite v1.4.3 | 	gorm.io/driver/sqlite v1.4.3 | ||||||
| 	gorm.io/gorm v1.25.0 | 	gorm.io/gorm v1.25.0 | ||||||
| ) | ) | ||||||
| @@ -52,10 +53,9 @@ require ( | |||||||
| 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect | ||||||
| 	github.com/ugorji/go/codec v1.2.11 // indirect | 	github.com/ugorji/go/codec v1.2.11 // indirect | ||||||
| 	golang.org/x/arch v0.3.0 // indirect | 	golang.org/x/arch v0.3.0 // indirect | ||||||
| 	golang.org/x/net v0.10.0 // indirect | 	golang.org/x/net v0.17.0 // indirect | ||||||
| 	golang.org/x/sys v0.8.0 // indirect | 	golang.org/x/sys v0.13.0 // indirect | ||||||
| 	golang.org/x/text v0.9.0 // indirect | 	golang.org/x/text v0.13.0 // indirect | ||||||
| 	google.golang.org/protobuf v1.30.0 // indirect | 	google.golang.org/protobuf v1.30.0 // indirect | ||||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||||
| 	gorm.io/driver/postgres v1.5.2 // indirect |  | ||||||
| ) | ) | ||||||
|   | |||||||
							
								
								
									
										17
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								go.sum
									
									
									
									
									
								
							| @@ -150,11 +150,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu | |||||||
| golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= | golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= | ||||||
| golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= | ||||||
| golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | ||||||
| golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= | golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= | ||||||
| golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= | golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= | ||||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||||
| golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= | golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= | ||||||
| golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= | golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= | ||||||
| golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| @@ -162,14 +162,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc | |||||||
| golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= | golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= | ||||||
| golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | ||||||
| golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | ||||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||||
| golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||||
| golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= | golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= | ||||||
| golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= | golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= | ||||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= | ||||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||||
| @@ -198,7 +198,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp | |||||||
| gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= | gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= | ||||||
| gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= | gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= | ||||||
| gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= | gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= | ||||||
| gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74= |  | ||||||
| gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= | gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= | ||||||
| gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= | gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= | ||||||
| gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= | ||||||
|   | |||||||
| @@ -523,5 +523,6 @@ | |||||||
|   "按照如下格式输入:": "Enter in the following format:", |   "按照如下格式输入:": "Enter in the following format:", | ||||||
|   "模型版本": "Model version", |   "模型版本": "Model version", | ||||||
|   "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", |   "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", | ||||||
|   "点击查看": "click to view" |   "点击查看": "click to view", | ||||||
|  |   "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!" | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										34
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								main.go
									
									
									
									
									
								
							| @@ -2,6 +2,7 @@ package main | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"embed" | 	"embed" | ||||||
|  | 	"fmt" | ||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-contrib/sessions/cookie" | 	"github.com/gin-contrib/sessions/cookie" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| @@ -21,7 +22,7 @@ var buildFS embed.FS | |||||||
| var indexPage []byte | var indexPage []byte | ||||||
|  |  | ||||||
| func main() { | func main() { | ||||||
| 	common.SetupGinLog() | 	common.SetupLogger() | ||||||
| 	common.SysLog("One API " + common.Version + " started") | 	common.SysLog("One API " + common.Version + " started") | ||||||
| 	if os.Getenv("GIN_MODE") != "debug" { | 	if os.Getenv("GIN_MODE") != "debug" { | ||||||
| 		gin.SetMode(gin.ReleaseMode) | 		gin.SetMode(gin.ReleaseMode) | ||||||
| @@ -50,18 +51,17 @@ func main() { | |||||||
| 	// Initialize options | 	// Initialize options | ||||||
| 	model.InitOptionMap() | 	model.InitOptionMap() | ||||||
| 	if common.RedisEnabled { | 	if common.RedisEnabled { | ||||||
|  | 		// for compatibility with old versions | ||||||
|  | 		common.MemoryCacheEnabled = true | ||||||
|  | 	} | ||||||
|  | 	if common.MemoryCacheEnabled { | ||||||
|  | 		common.SysLog("memory cache enabled") | ||||||
|  | 		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) | ||||||
| 		model.InitChannelCache() | 		model.InitChannelCache() | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("SYNC_FREQUENCY") != "" { | 	if common.MemoryCacheEnabled { | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY")) | 		go model.SyncOptions(common.SyncFrequency) | ||||||
| 		if err != nil { | 		go model.SyncChannelCache(common.SyncFrequency) | ||||||
| 			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		common.SyncFrequency = frequency |  | ||||||
| 		go model.SyncOptions(frequency) |  | ||||||
| 		if common.RedisEnabled { |  | ||||||
| 			go model.SyncChannelCache(frequency) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | ||||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | ||||||
| @@ -77,14 +77,20 @@ func main() { | |||||||
| 		} | 		} | ||||||
| 		go controller.AutomaticallyTestChannels(frequency) | 		go controller.AutomaticallyTestChannels(frequency) | ||||||
| 	} | 	} | ||||||
|  | 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { | ||||||
|  | 		common.BatchUpdateEnabled = true | ||||||
|  | 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") | ||||||
|  | 		model.InitBatchUpdater() | ||||||
|  | 	} | ||||||
| 	controller.InitTokenEncoders() | 	controller.InitTokenEncoders() | ||||||
|  |  | ||||||
| 	// Initialize HTTP server | 	// Initialize HTTP server | ||||||
| 	server := gin.Default() | 	server := gin.New() | ||||||
|  | 	server.Use(gin.Recovery()) | ||||||
| 	// This will cause SSE not to work!!! | 	// This will cause SSE not to work!!! | ||||||
| 	//server.Use(gzip.Gzip(gzip.DefaultCompression)) | 	//server.Use(gzip.Gzip(gzip.DefaultCompression)) | ||||||
| 	server.Use(middleware.CORS()) | 	server.Use(middleware.RequestId()) | ||||||
|  | 	middleware.SetUpLogger(server) | ||||||
| 	// Initialize session store | 	// Initialize session store | ||||||
| 	store := cookie.NewStore([]byte(common.SessionSecret)) | 	store := cookie.NewStore([]byte(common.SessionSecret)) | ||||||
| 	server.Use(sessions.Sessions("session", store)) | 	server.Use(sessions.Sessions("session", store)) | ||||||
|   | |||||||
| @@ -91,23 +91,16 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 		key = parts[0] | 		key = parts[0] | ||||||
| 		token, err := model.ValidateUserToken(key) | 		token, err := model.ValidateUserToken(key) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			c.JSON(http.StatusUnauthorized, gin.H{ | 			abortWithMessage(c, http.StatusUnauthorized, err.Error()) | ||||||
| 				"error": gin.H{ |  | ||||||
| 					"message": err.Error(), |  | ||||||
| 					"type":    "one_api_error", |  | ||||||
| 				}, |  | ||||||
| 			}) |  | ||||||
| 			c.Abort() |  | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		if !model.CacheIsUserEnabled(token.UserId) { | 		userEnabled, err := model.CacheIsUserEnabled(token.UserId) | ||||||
| 			c.JSON(http.StatusForbidden, gin.H{ | 		if err != nil { | ||||||
| 				"error": gin.H{ | 			abortWithMessage(c, http.StatusInternalServerError, err.Error()) | ||||||
| 					"message": "用户已被封禁", | 			return | ||||||
| 					"type":    "one_api_error", | 		} | ||||||
| 				}, | 		if !userEnabled { | ||||||
| 			}) | 			abortWithMessage(c, http.StatusForbidden, "用户已被封禁") | ||||||
| 			c.Abort() |  | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		c.Set("id", token.UserId) | 		c.Set("id", token.UserId) | ||||||
| @@ -123,13 +116,7 @@ func TokenAuth() func(c *gin.Context) { | |||||||
| 			if model.IsAdmin(token.UserId) { | 			if model.IsAdmin(token.UserId) { | ||||||
| 				c.Set("channelId", parts[1]) | 				c.Set("channelId", parts[1]) | ||||||
| 			} else { | 			} else { | ||||||
| 				c.JSON(http.StatusForbidden, gin.H{ | 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "普通用户不支持指定渠道", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) { | |||||||
| 		if ok { | 		if ok { | ||||||
| 			id, err := strconv.Atoi(channelId.(string)) | 			id, err := strconv.Atoi(channelId.(string)) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				c.JSON(http.StatusBadRequest, gin.H{ | 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "无效的渠道 ID", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			channel, err = model.GetChannelById(id, true) | 			channel, err = model.GetChannelById(id, true) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				c.JSON(http.StatusBadRequest, gin.H{ | 				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "无效的渠道 ID", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			if channel.Status != common.ChannelStatusEnabled { | 			if channel.Status != common.ChannelStatusEnabled { | ||||||
| 				c.JSON(http.StatusForbidden, gin.H{ | 				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "该渠道已被禁用", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| @@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) { | |||||||
| 				err = common.UnmarshalBodyReusable(c, &modelRequest) | 				err = common.UnmarshalBodyReusable(c, &modelRequest) | ||||||
| 			} | 			} | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				c.JSON(http.StatusBadRequest, gin.H{ | 				abortWithMessage(c, http.StatusBadRequest, "无效的请求") | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": "无效的请求", |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { | ||||||
| @@ -99,24 +75,23 @@ func Distribute() func(c *gin.Context) { | |||||||
| 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) | ||||||
| 					message = "数据库一致性已被破坏,请联系管理员" | 					message = "数据库一致性已被破坏,请联系管理员" | ||||||
| 				} | 				} | ||||||
| 				c.JSON(http.StatusServiceUnavailable, gin.H{ | 				abortWithMessage(c, http.StatusServiceUnavailable, message) | ||||||
| 					"error": gin.H{ |  | ||||||
| 						"message": message, |  | ||||||
| 						"type":    "one_api_error", |  | ||||||
| 					}, |  | ||||||
| 				}) |  | ||||||
| 				c.Abort() |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		c.Set("channel", channel.Type) | 		c.Set("channel", channel.Type) | ||||||
| 		c.Set("channel_id", channel.Id) | 		c.Set("channel_id", channel.Id) | ||||||
| 		c.Set("channel_name", channel.Name) | 		c.Set("channel_name", channel.Name) | ||||||
| 		c.Set("model_mapping", channel.ModelMapping) | 		c.Set("model_mapping", channel.GetModelMapping()) | ||||||
| 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||||
| 		c.Set("base_url", channel.BaseURL) | 		c.Set("base_url", channel.GetBaseURL()) | ||||||
| 		if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei { | 		switch channel.Type { | ||||||
|  | 		case common.ChannelTypeAzure: | ||||||
| 			c.Set("api_version", channel.Other) | 			c.Set("api_version", channel.Other) | ||||||
|  | 		case common.ChannelTypeXunfei: | ||||||
|  | 			c.Set("api_version", channel.Other) | ||||||
|  | 		case common.ChannelTypeAIProxyLibrary: | ||||||
|  | 			c.Set("library_id", channel.Other) | ||||||
| 		} | 		} | ||||||
| 		c.Next() | 		c.Next() | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"one-api/common" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func SetUpLogger(server *gin.Engine) { | ||||||
|  | 	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { | ||||||
|  | 		var requestID string | ||||||
|  | 		if param.Keys != nil { | ||||||
|  | 			requestID = param.Keys[common.RequestIdKey].(string) | ||||||
|  | 		} | ||||||
|  | 		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", | ||||||
|  | 			param.TimeStamp.Format("2006/01/02 - 15:04:05"), | ||||||
|  | 			requestID, | ||||||
|  | 			param.StatusCode, | ||||||
|  | 			param.Latency, | ||||||
|  | 			param.ClientIP, | ||||||
|  | 			param.Method, | ||||||
|  | 			param.Path, | ||||||
|  | 		) | ||||||
|  | 	})) | ||||||
|  | } | ||||||
							
								
								
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"one-api/common" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RequestId() func(c *gin.Context) { | ||||||
|  | 	return func(c *gin.Context) { | ||||||
|  | 		id := common.GetTimeString() + common.GetRandomString(8) | ||||||
|  | 		c.Set(common.RequestIdKey, id) | ||||||
|  | 		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) | ||||||
|  | 		c.Request = c.Request.WithContext(ctx) | ||||||
|  | 		c.Header(common.RequestIdKey, id) | ||||||
|  | 		c.Next() | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"one-api/common" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||||
|  | 	c.JSON(statusCode, gin.H{ | ||||||
|  | 		"error": gin.H{ | ||||||
|  | 			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), | ||||||
|  | 			"type":    "one_api_error", | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	c.Abort() | ||||||
|  | 	common.LogError(c.Request.Context(), message) | ||||||
|  | } | ||||||
| @@ -10,15 +10,18 @@ type Ability struct { | |||||||
| 	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"` | 	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"` | ||||||
| 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` | 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` | ||||||
| 	Enabled   bool   `json:"enabled"` | 	Enabled   bool   `json:"enabled"` | ||||||
|  | 	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||||
| 	ability := Ability{} | 	ability := Ability{} | ||||||
| 	var err error = nil | 	var err error = nil | ||||||
|  | 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) | ||||||
|  | 	channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) | ||||||
| 	if common.UsingSQLite { | 	if common.UsingSQLite { | ||||||
| 		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error | 		err = channelQuery.Order("RANDOM()").First(&ability).Error | ||||||
| 	} else { | 	} else { | ||||||
| 		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error | 		err = channelQuery.Order("RAND()").First(&ability).Error | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -40,6 +43,7 @@ func (channel *Channel) AddAbilities() error { | |||||||
| 				Model:     model, | 				Model:     model, | ||||||
| 				ChannelId: channel.Id, | 				ChannelId: channel.Id, | ||||||
| 				Enabled:   channel.Status == common.ChannelStatusEnabled, | 				Enabled:   channel.Status == common.ChannelStatusEnabled, | ||||||
|  | 				Priority:  channel.Priority, | ||||||
| 			} | 			} | ||||||
| 			abilities = append(abilities, ability) | 			abilities = append(abilities, ability) | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
|  | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| @@ -103,23 +104,28 @@ func CacheDecreaseUserQuota(id int, quota int) error { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func CacheIsUserEnabled(userId int) bool { | func CacheIsUserEnabled(userId int) (bool, error) { | ||||||
| 	if !common.RedisEnabled { | 	if !common.RedisEnabled { | ||||||
| 		return IsUserEnabled(userId) | 		return IsUserEnabled(userId) | ||||||
| 	} | 	} | ||||||
| 	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) | 	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) | ||||||
| 	if err != nil { | 	if err == nil { | ||||||
| 		status := common.UserStatusDisabled | 		return enabled == "1", nil | ||||||
| 		if IsUserEnabled(userId) { |  | ||||||
| 			status = common.UserStatusEnabled |  | ||||||
| 		} |  | ||||||
| 		enabled = fmt.Sprintf("%d", status) |  | ||||||
| 		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) |  | ||||||
| 		if err != nil { |  | ||||||
| 			common.SysError("Redis set user enabled error: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	return enabled == "1" |  | ||||||
|  | 	userEnabled, err := IsUserEnabled(userId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 	enabled = "0" | ||||||
|  | 	if userEnabled { | ||||||
|  | 		enabled = "1" | ||||||
|  | 	} | ||||||
|  | 	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.SysError("Redis set user enabled error: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | 	return userEnabled, err | ||||||
| } | } | ||||||
|  |  | ||||||
| var group2model2channels map[string]map[string][]*Channel | var group2model2channels map[string]map[string][]*Channel | ||||||
| @@ -154,6 +160,17 @@ func InitChannelCache() { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// sort by priority | ||||||
|  | 	for group, model2channels := range newGroup2model2channels { | ||||||
|  | 		for model, channels := range model2channels { | ||||||
|  | 			sort.Slice(channels, func(i, j int) bool { | ||||||
|  | 				return channels[i].GetPriority() > channels[j].GetPriority() | ||||||
|  | 			}) | ||||||
|  | 			newGroup2model2channels[group][model] = channels | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	channelSyncLock.Lock() | 	channelSyncLock.Lock() | ||||||
| 	group2model2channels = newGroup2model2channels | 	group2model2channels = newGroup2model2channels | ||||||
| 	channelSyncLock.Unlock() | 	channelSyncLock.Unlock() | ||||||
| @@ -169,7 +186,7 @@ func SyncChannelCache(frequency int) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||||
| 	if !common.RedisEnabled { | 	if !common.MemoryCacheEnabled { | ||||||
| 		return GetRandomSatisfiedChannel(group, model) | 		return GetRandomSatisfiedChannel(group, model) | ||||||
| 	} | 	} | ||||||
| 	channelSyncLock.RLock() | 	channelSyncLock.RLock() | ||||||
| @@ -178,6 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error | |||||||
| 	if len(channels) == 0 { | 	if len(channels) == 0 { | ||||||
| 		return nil, errors.New("channel not found") | 		return nil, errors.New("channel not found") | ||||||
| 	} | 	} | ||||||
| 	idx := rand.Intn(len(channels)) | 	endIdx := len(channels) | ||||||
|  | 	// choose by priority | ||||||
|  | 	firstChannel := channels[0] | ||||||
|  | 	if firstChannel.GetPriority() > 0 { | ||||||
|  | 		for i := range channels { | ||||||
|  | 			if channels[i].GetPriority() != firstChannel.GetPriority() { | ||||||
|  | 				endIdx = i | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	idx := rand.Intn(endIdx) | ||||||
| 	return channels[idx], nil | 	return channels[idx], nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -11,18 +11,19 @@ type Channel struct { | |||||||
| 	Key                string  `json:"key" gorm:"not null;index"` | 	Key                string  `json:"key" gorm:"not null;index"` | ||||||
| 	Status             int     `json:"status" gorm:"default:1"` | 	Status             int     `json:"status" gorm:"default:1"` | ||||||
| 	Name               string  `json:"name" gorm:"index"` | 	Name               string  `json:"name" gorm:"index"` | ||||||
| 	Weight             int     `json:"weight"` | 	Weight             *uint   `json:"weight" gorm:"default:0"` | ||||||
| 	CreatedTime        int64   `json:"created_time" gorm:"bigint"` | 	CreatedTime        int64   `json:"created_time" gorm:"bigint"` | ||||||
| 	TestTime           int64   `json:"test_time" gorm:"bigint"` | 	TestTime           int64   `json:"test_time" gorm:"bigint"` | ||||||
| 	ResponseTime       int     `json:"response_time"` // in milliseconds | 	ResponseTime       int     `json:"response_time"` // in milliseconds | ||||||
| 	BaseURL            string  `json:"base_url" gorm:"column:base_url"` | 	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"` | ||||||
| 	Other              string  `json:"other"` | 	Other              string  `json:"other"` | ||||||
| 	Balance            float64 `json:"balance"` // in USD | 	Balance            float64 `json:"balance"` // in USD | ||||||
| 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"` | 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"` | ||||||
| 	Models             string  `json:"models"` | 	Models             string  `json:"models"` | ||||||
| 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"` | 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"` | ||||||
| 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | ||||||
| 	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||||
|  | 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||||
| @@ -78,6 +79,27 @@ func BatchInsertChannels(channels []Channel) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (channel *Channel) GetPriority() int64 { | ||||||
|  | 	if channel.Priority == nil { | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  | 	return *channel.Priority | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (channel *Channel) GetBaseURL() string { | ||||||
|  | 	if channel.BaseURL == nil { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	return *channel.BaseURL | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (channel *Channel) GetModelMapping() string { | ||||||
|  | 	if channel.ModelMapping == nil { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	return *channel.ModelMapping | ||||||
|  | } | ||||||
|  |  | ||||||
| func (channel *Channel) Insert() error { | func (channel *Channel) Insert() error { | ||||||
| 	var err error | 	var err error | ||||||
| 	err = DB.Create(channel).Error | 	err = DB.Create(channel).Error | ||||||
| @@ -141,8 +163,26 @@ func UpdateChannelStatusById(id int, status int) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateChannelUsedQuota(id int, quota int) { | func UpdateChannelUsedQuota(id int, quota int) { | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	updateChannelUsedQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func updateChannelUsedQuota(id int, quota int) { | ||||||
| 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to update channel used quota: " + err.Error()) | 		common.SysError("failed to update channel used quota: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func DeleteChannelByStatus(status int64) (int64, error) { | ||||||
|  | 	result := DB.Where("status = ?", status).Delete(&Channel{}) | ||||||
|  | 	return result.RowsAffected, result.Error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DeleteDisabledChannel() (int64, error) { | ||||||
|  | 	result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) | ||||||
|  | 	return result.RowsAffected, result.Error | ||||||
|  | } | ||||||
|   | |||||||
							
								
								
									
										40
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										40
									
								
								model/log.go
									
									
									
									
									
								
							| @@ -1,22 +1,25 @@ | |||||||
| package model | package model | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"one-api/common" | 	"one-api/common" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Log struct { | type Log struct { | ||||||
| 	Id               int    `json:"id"` | 	Id               int    `json:"id;index:idx_created_at_id,priority:1"` | ||||||
| 	UserId           int    `json:"user_id"` | 	UserId           int    `json:"user_id" gorm:"index"` | ||||||
| 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index"` | 	CreatedAt        int64  `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` | ||||||
| 	Type             int    `json:"type" gorm:"index"` | 	Type             int    `json:"type" gorm:"index:idx_created_at_type"` | ||||||
| 	Content          string `json:"content"` | 	Content          string `json:"content"` | ||||||
| 	Username         string `json:"username" gorm:"index;default:''"` | 	Username         string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` | ||||||
| 	TokenName        string `json:"token_name" gorm:"index;default:''"` | 	TokenName        string `json:"token_name" gorm:"index;default:''"` | ||||||
| 	ModelName        string `json:"model_name" gorm:"index;default:''"` | 	ModelName        string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` | ||||||
| 	Quota            int    `json:"quota" gorm:"default:0"` | 	Quota            int    `json:"quota" gorm:"default:0"` | ||||||
| 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"` | 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"` | ||||||
| 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"` | 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"` | ||||||
|  | 	ChannelId        int    `json:"channel" gorm:"index"` | ||||||
| } | } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -44,7 +47,8 @@ func RecordLog(userId int, logType int, content string) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | ||||||
|  | 	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||||
| 	if !common.LogConsumeEnabled { | 	if !common.LogConsumeEnabled { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -59,14 +63,15 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN | |||||||
| 		TokenName:        tokenName, | 		TokenName:        tokenName, | ||||||
| 		ModelName:        modelName, | 		ModelName:        modelName, | ||||||
| 		Quota:            quota, | 		Quota:            quota, | ||||||
|  | 		ChannelId:        channelId, | ||||||
| 	} | 	} | ||||||
| 	err := DB.Create(log).Error | 	err := DB.Create(log).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("failed to record log: " + err.Error()) | 		common.LogError(ctx, "failed to record log: "+err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) { | func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { | ||||||
| 	var tx *gorm.DB | 	var tx *gorm.DB | ||||||
| 	if logType == LogTypeUnknown { | 	if logType == LogTypeUnknown { | ||||||
| 		tx = DB | 		tx = DB | ||||||
| @@ -88,6 +93,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName | |||||||
| 	if endTimestamp != 0 { | 	if endTimestamp != 0 { | ||||||
| 		tx = tx.Where("created_at <= ?", endTimestamp) | 		tx = tx.Where("created_at <= ?", endTimestamp) | ||||||
| 	} | 	} | ||||||
|  | 	if channel != 0 { | ||||||
|  | 		tx = tx.Where("channel = ?", channel) | ||||||
|  | 	} | ||||||
| 	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error | 	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error | ||||||
| 	return logs, err | 	return logs, err | ||||||
| } | } | ||||||
| @@ -125,8 +133,8 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | |||||||
| 	return logs, err | 	return logs, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) { | func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { | ||||||
| 	tx := DB.Table("logs").Select("sum(quota)") | 	tx := DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||||
| 	if username != "" { | 	if username != "" { | ||||||
| 		tx = tx.Where("username = ?", username) | 		tx = tx.Where("username = ?", username) | ||||||
| 	} | 	} | ||||||
| @@ -142,12 +150,15 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa | |||||||
| 	if modelName != "" { | 	if modelName != "" { | ||||||
| 		tx = tx.Where("model_name = ?", modelName) | 		tx = tx.Where("model_name = ?", modelName) | ||||||
| 	} | 	} | ||||||
|  | 	if channel != 0 { | ||||||
|  | 		tx = tx.Where("channel = ?", channel) | ||||||
|  | 	} | ||||||
| 	tx.Where("type = ?", LogTypeConsume).Scan("a) | 	tx.Where("type = ?", LogTypeConsume).Scan("a) | ||||||
| 	return quota | 	return quota | ||||||
| } | } | ||||||
|  |  | ||||||
| func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { | ||||||
| 	tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)") | 	tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") | ||||||
| 	if username != "" { | 	if username != "" { | ||||||
| 		tx = tx.Where("username = ?", username) | 		tx = tx.Where("username = ?", username) | ||||||
| 	} | 	} | ||||||
| @@ -166,3 +177,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa | |||||||
| 	tx.Where("type = ?", LogTypeConsume).Scan(&token) | 	tx.Where("type = ?", LogTypeConsume).Scan(&token) | ||||||
| 	return token | 	return token | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func DeleteOldLog(targetTimestamp int64) (int64, error) { | ||||||
|  | 	result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) | ||||||
|  | 	return result.RowsAffected, result.Error | ||||||
|  | } | ||||||
|   | |||||||
| @@ -81,6 +81,7 @@ func InitDB() (err error) { | |||||||
| 		if !common.IsMasterNode { | 		if !common.IsMasterNode { | ||||||
| 			return nil | 			return nil | ||||||
| 		} | 		} | ||||||
|  | 		common.SysLog("database migration started") | ||||||
| 		err = db.AutoMigrate(&Channel{}) | 		err = db.AutoMigrate(&Channel{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
|   | |||||||
| @@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) { | |||||||
| 	} | 	} | ||||||
| 	token, err = CacheGetTokenByKey(key) | 	token, err = CacheGetTokenByKey(key) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
|  | 		if token.Status == common.TokenStatusExhausted { | ||||||
|  | 			return nil, errors.New("该令牌额度已用尽") | ||||||
|  | 		} else if token.Status == common.TokenStatusExpired { | ||||||
|  | 			return nil, errors.New("该令牌已过期") | ||||||
|  | 		} | ||||||
| 		if token.Status != common.TokenStatusEnabled { | 		if token.Status != common.TokenStatusEnabled { | ||||||
| 			return nil, errors.New("该令牌状态不可用") | 			return nil, errors.New("该令牌状态不可用") | ||||||
| 		} | 		} | ||||||
| 		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { | 		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { | ||||||
| 			token.Status = common.TokenStatusExpired | 			if !common.RedisEnabled { | ||||||
| 			err := token.SelectUpdate() | 				token.Status = common.TokenStatusExpired | ||||||
| 			if err != nil { | 				err := token.SelectUpdate() | ||||||
| 				common.SysError("failed to update token status" + err.Error()) | 				if err != nil { | ||||||
|  | 					common.SysError("failed to update token status" + err.Error()) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 			return nil, errors.New("该令牌已过期") | 			return nil, errors.New("该令牌已过期") | ||||||
| 		} | 		} | ||||||
| 		if !token.UnlimitedQuota && token.RemainQuota <= 0 { | 		if !token.UnlimitedQuota && token.RemainQuota <= 0 { | ||||||
| 			token.Status = common.TokenStatusExhausted | 			if !common.RedisEnabled { | ||||||
| 			err := token.SelectUpdate() | 				// in this case, we can make sure the token is exhausted | ||||||
| 			if err != nil { | 				token.Status = common.TokenStatusExhausted | ||||||
| 				common.SysError("failed to update token status" + err.Error()) | 				err := token.SelectUpdate() | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.SysError("failed to update token status" + err.Error()) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 			return nil, errors.New("该令牌额度已用尽") | 			return nil, errors.New("该令牌额度已用尽") | ||||||
| 		} | 		} | ||||||
| 		go func() { |  | ||||||
| 			token.AccessedTime = common.GetTimestamp() |  | ||||||
| 			err := token.SelectUpdate() |  | ||||||
| 			if err != nil { |  | ||||||
| 				common.SysError("failed to update token" + err.Error()) |  | ||||||
| 			} |  | ||||||
| 		}() |  | ||||||
| 		return token, nil | 		return token, nil | ||||||
| 	} | 	} | ||||||
| 	return nil, errors.New("无效的令牌") | 	return nil, errors.New("无效的令牌") | ||||||
| @@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeTokenQuota, id, quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return increaseTokenQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func increaseTokenQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota": gorm.Expr("remain_quota + ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | ||||||
| 			"used_quota":   gorm.Expr("used_quota - ?", quota), | 			"used_quota":    gorm.Expr("used_quota - ?", quota), | ||||||
|  | 			"accessed_time": common.GetTimestamp(), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	return err | 	return err | ||||||
| @@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return decreaseTokenQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func decreaseTokenQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"remain_quota": gorm.Expr("remain_quota - ?", quota), | 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | ||||||
| 			"used_quota":   gorm.Expr("used_quota + ?", quota), | 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||||
|  | 			"accessed_time": common.GetTimestamp(), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	return err | 	return err | ||||||
|   | |||||||
| @@ -226,17 +226,16 @@ func IsAdmin(userId int) bool { | |||||||
| 	return user.Role >= common.RoleAdminUser | 	return user.Role >= common.RoleAdminUser | ||||||
| } | } | ||||||
|  |  | ||||||
| func IsUserEnabled(userId int) bool { | func IsUserEnabled(userId int) (bool, error) { | ||||||
| 	if userId == 0 { | 	if userId == 0 { | ||||||
| 		return false | 		return false, errors.New("user id is empty") | ||||||
| 	} | 	} | ||||||
| 	var user User | 	var user User | ||||||
| 	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error | 	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		common.SysError("no such user " + err.Error()) | 		return false, err | ||||||
| 		return false |  | ||||||
| 	} | 	} | ||||||
| 	return user.Status == common.UserStatusEnabled | 	return user.Status == common.UserStatusEnabled, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func ValidateAccessToken(token string) (user *User) { | func ValidateAccessToken(token string) (user *User) { | ||||||
| @@ -275,6 +274,14 @@ func IncreaseUserQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeUserQuota, id, quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return increaseUserQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func increaseUserQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -283,6 +290,14 @@ func DecreaseUserQuota(id int, quota int) (err error) { | |||||||
| 	if quota < 0 { | 	if quota < 0 { | ||||||
| 		return errors.New("quota 不能为负数!") | 		return errors.New("quota 不能为负数!") | ||||||
| 	} | 	} | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeUserQuota, id, -quota) | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return decreaseUserQuota(id, quota) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func decreaseUserQuota(id int, quota int) (err error) { | ||||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| @@ -293,10 +308,19 @@ func GetRootUserEmail() (email string) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||||
|  | 	if common.BatchUpdateEnabled { | ||||||
|  | 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | ||||||
|  | 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | ||||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||||
| 		map[string]interface{}{ | 		map[string]interface{}{ | ||||||
| 			"used_quota":    gorm.Expr("used_quota + ?", quota), | 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||||
| 			"request_count": gorm.Expr("request_count + ?", 1), | 			"request_count": gorm.Expr("request_count + ?", count), | ||||||
| 		}, | 		}, | ||||||
| 	).Error | 	).Error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -304,6 +328,24 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func updateUserUsedQuota(id int, quota int) { | ||||||
|  | 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||||
|  | 		map[string]interface{}{ | ||||||
|  | 			"used_quota": gorm.Expr("used_quota + ?", quota), | ||||||
|  | 		}, | ||||||
|  | 	).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.SysError("failed to update user used quota: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func updateUserRequestCount(id int, count int) { | ||||||
|  | 	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error | ||||||
|  | 	if err != nil { | ||||||
|  | 		common.SysError("failed to update user request count: " + err.Error()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func GetUsernameById(id int) (username string) { | func GetUsernameById(id int) (username string) { | ||||||
| 	DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) | 	DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) | ||||||
| 	return username | 	return username | ||||||
|   | |||||||
							
								
								
									
										77
									
								
								model/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								model/utils.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,77 @@ | |||||||
|  | package model | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"one-api/common" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	BatchUpdateTypeUserQuota = iota | ||||||
|  | 	BatchUpdateTypeTokenQuota | ||||||
|  | 	BatchUpdateTypeUsedQuota | ||||||
|  | 	BatchUpdateTypeChannelUsedQuota | ||||||
|  | 	BatchUpdateTypeRequestCount | ||||||
|  | 	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var batchUpdateStores []map[int]int | ||||||
|  | var batchUpdateLocks []sync.Mutex | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||||
|  | 		batchUpdateStores = append(batchUpdateStores, make(map[int]int)) | ||||||
|  | 		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func InitBatchUpdater() { | ||||||
|  | 	go func() { | ||||||
|  | 		for { | ||||||
|  | 			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) | ||||||
|  | 			batchUpdate() | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func addNewRecord(type_ int, id int, value int) { | ||||||
|  | 	batchUpdateLocks[type_].Lock() | ||||||
|  | 	defer batchUpdateLocks[type_].Unlock() | ||||||
|  | 	if _, ok := batchUpdateStores[type_][id]; !ok { | ||||||
|  | 		batchUpdateStores[type_][id] = value | ||||||
|  | 	} else { | ||||||
|  | 		batchUpdateStores[type_][id] += value | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func batchUpdate() { | ||||||
|  | 	common.SysLog("batch update started") | ||||||
|  | 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||||
|  | 		batchUpdateLocks[i].Lock() | ||||||
|  | 		store := batchUpdateStores[i] | ||||||
|  | 		batchUpdateStores[i] = make(map[int]int) | ||||||
|  | 		batchUpdateLocks[i].Unlock() | ||||||
|  | 		// TODO: maybe we can combine updates with same key? | ||||||
|  | 		for key, value := range store { | ||||||
|  | 			switch i { | ||||||
|  | 			case BatchUpdateTypeUserQuota: | ||||||
|  | 				err := increaseUserQuota(key, value) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.SysError("failed to batch update user quota: " + err.Error()) | ||||||
|  | 				} | ||||||
|  | 			case BatchUpdateTypeTokenQuota: | ||||||
|  | 				err := increaseTokenQuota(key, value) | ||||||
|  | 				if err != nil { | ||||||
|  | 					common.SysError("failed to batch update token quota: " + err.Error()) | ||||||
|  | 				} | ||||||
|  | 			case BatchUpdateTypeUsedQuota: | ||||||
|  | 				updateUserUsedQuota(key, value) | ||||||
|  | 			case BatchUpdateTypeRequestCount: | ||||||
|  | 				updateUserRequestCount(key, value) | ||||||
|  | 			case BatchUpdateTypeChannelUsedQuota: | ||||||
|  | 				updateChannelUsedQuota(key, value) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	common.SysLog("batch update finished") | ||||||
|  | } | ||||||
| @@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) | 		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) | ||||||
| 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) | 		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) | ||||||
| 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) | 		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) | ||||||
|  | 		apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) | ||||||
| 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) | 		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) | ||||||
| 		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) | 		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) | ||||||
| 		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) | 		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) | ||||||
| @@ -73,6 +74,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) | 			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) | ||||||
| 			channelRoute.POST("/", controller.AddChannel) | 			channelRoute.POST("/", controller.AddChannel) | ||||||
| 			channelRoute.PUT("/", controller.UpdateChannel) | 			channelRoute.PUT("/", controller.UpdateChannel) | ||||||
|  | 			channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel) | ||||||
| 			channelRoute.DELETE("/:id", controller.DeleteChannel) | 			channelRoute.DELETE("/:id", controller.DeleteChannel) | ||||||
| 		} | 		} | ||||||
| 		tokenRoute := apiRouter.Group("/token") | 		tokenRoute := apiRouter.Group("/token") | ||||||
| @@ -97,6 +99,7 @@ func SetApiRouter(router *gin.Engine) { | |||||||
| 		} | 		} | ||||||
| 		logRoute := apiRouter.Group("/log") | 		logRoute := apiRouter.Group("/log") | ||||||
| 		logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs) | 		logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs) | ||||||
|  | 		logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs) | ||||||
| 		logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) | 		logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) | ||||||
| 		logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) | 		logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) | ||||||
| 		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) | 		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func SetRelayRouter(router *gin.Engine) { | func SetRelayRouter(router *gin.Engine) { | ||||||
|  | 	router.Use(middleware.CORS()) | ||||||
| 	// https://platform.openai.com/docs/api-reference/introduction | 	// https://platform.openai.com/docs/api-reference/introduction | ||||||
| 	modelsRouter := router.Group("/v1/models") | 	modelsRouter := router.Group("/v1/models") | ||||||
| 	modelsRouter.Use(middleware.TokenAuth()) | 	modelsRouter.Use(middleware.TokenAuth()) | ||||||
|   | |||||||
| @@ -283,7 +283,9 @@ function App() { | |||||||
|           </Suspense> |           </Suspense> | ||||||
|         } |         } | ||||||
|       /> |       /> | ||||||
|       <Route path='*' element={NotFound} /> |       <Route path='*' element={ | ||||||
|  |           <NotFound /> | ||||||
|  |       } /> | ||||||
|     </Routes> |     </Routes> | ||||||
|   ); |   ); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; | import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; | ||||||
| import { Link } from 'react-router-dom'; | import { Link } from 'react-router-dom'; | ||||||
| import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; | import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; | ||||||
|  |  | ||||||
| import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; | ||||||
| import { renderGroup, renderNumber } from '../helpers/render'; | import { renderGroup, renderNumber } from '../helpers/render'; | ||||||
| @@ -24,7 +24,7 @@ function renderType(type) { | |||||||
|     } |     } | ||||||
|     type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; |     type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; | ||||||
|   } |   } | ||||||
|   return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>; |   return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>; | ||||||
| } | } | ||||||
|  |  | ||||||
| function renderBalance(type, balance) { | function renderBalance(type, balance) { | ||||||
| @@ -55,6 +55,7 @@ const ChannelsTable = () => { | |||||||
|   const [searchKeyword, setSearchKeyword] = useState(''); |   const [searchKeyword, setSearchKeyword] = useState(''); | ||||||
|   const [searching, setSearching] = useState(false); |   const [searching, setSearching] = useState(false); | ||||||
|   const [updatingBalance, setUpdatingBalance] = useState(false); |   const [updatingBalance, setUpdatingBalance] = useState(false); | ||||||
|  |   const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test")); | ||||||
|  |  | ||||||
|   const loadChannels = async (startIdx) => { |   const loadChannels = async (startIdx) => { | ||||||
|     const res = await API.get(`/api/channel/?p=${startIdx}`); |     const res = await API.get(`/api/channel/?p=${startIdx}`); | ||||||
| @@ -96,7 +97,7 @@ const ChannelsTable = () => { | |||||||
|       }); |       }); | ||||||
|   }, []); |   }, []); | ||||||
|  |  | ||||||
|   const manageChannel = async (id, action, idx) => { |   const manageChannel = async (id, action, idx, value) => { | ||||||
|     let data = { id }; |     let data = { id }; | ||||||
|     let res; |     let res; | ||||||
|     switch (action) { |     switch (action) { | ||||||
| @@ -111,6 +112,23 @@ const ChannelsTable = () => { | |||||||
|         data.status = 2; |         data.status = 2; | ||||||
|         res = await API.put('/api/channel/', data); |         res = await API.put('/api/channel/', data); | ||||||
|         break; |         break; | ||||||
|  |       case 'priority': | ||||||
|  |         if (value === '') { | ||||||
|  |           return; | ||||||
|  |         } | ||||||
|  |         data.priority = parseInt(value); | ||||||
|  |         res = await API.put('/api/channel/', data); | ||||||
|  |         break; | ||||||
|  |       case 'weight': | ||||||
|  |         if (value === '') { | ||||||
|  |           return; | ||||||
|  |         } | ||||||
|  |         data.weight = parseInt(value); | ||||||
|  |         if (data.weight < 0) { | ||||||
|  |           data.weight = 0; | ||||||
|  |         } | ||||||
|  |         res = await API.put('/api/channel/', data); | ||||||
|  |         break; | ||||||
|     } |     } | ||||||
|     const { success, message } = res.data; |     const { success, message } = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
| @@ -135,9 +153,23 @@ const ChannelsTable = () => { | |||||||
|         return <Label basic color='green'>已启用</Label>; |         return <Label basic color='green'>已启用</Label>; | ||||||
|       case 2: |       case 2: | ||||||
|         return ( |         return ( | ||||||
|           <Label basic color='red'> |           <Popup | ||||||
|             已禁用 |             trigger={<Label basic color='red'> | ||||||
|           </Label> |               已禁用 | ||||||
|  |             </Label>} | ||||||
|  |             content='本渠道被手动禁用' | ||||||
|  |             basic | ||||||
|  |           /> | ||||||
|  |         ); | ||||||
|  |       case 3: | ||||||
|  |         return ( | ||||||
|  |           <Popup | ||||||
|  |             trigger={<Label basic color='yellow'> | ||||||
|  |               已禁用 | ||||||
|  |             </Label>} | ||||||
|  |             content='本渠道被程序自动禁用' | ||||||
|  |             basic | ||||||
|  |           /> | ||||||
|         ); |         ); | ||||||
|       default: |       default: | ||||||
|         return ( |         return ( | ||||||
| @@ -208,6 +240,17 @@ const ChannelsTable = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   const deleteAllDisabledChannels = async () => { | ||||||
|  |     const res = await API.delete(`/api/channel/disabled`); | ||||||
|  |     const { success, message, data } = res.data; | ||||||
|  |     if (success) { | ||||||
|  |       showSuccess(`已删除所有禁用渠道,共计 ${data} 个`); | ||||||
|  |       await refresh(); | ||||||
|  |     } else { | ||||||
|  |       showError(message); | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   const updateChannelBalance = async (id, name, idx) => { |   const updateChannelBalance = async (id, name, idx) => { | ||||||
|     const res = await API.get(`/api/channel/update_balance/${id}/`); |     const res = await API.get(`/api/channel/update_balance/${id}/`); | ||||||
|     const { success, message, balance } = res.data; |     const { success, message, balance } = res.data; | ||||||
| @@ -274,7 +317,19 @@ const ChannelsTable = () => { | |||||||
|           onChange={handleKeywordChange} |           onChange={handleKeywordChange} | ||||||
|         /> |         /> | ||||||
|       </Form> |       </Form> | ||||||
|  |       { | ||||||
|  |         showPrompt && ( | ||||||
|  |           <Message onDismiss={() => { | ||||||
|  |             setShowPrompt(false); | ||||||
|  |             setPromptShown("channel-test"); | ||||||
|  |           }}> | ||||||
|  |             当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo | ||||||
|  |             模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。 | ||||||
|  |  | ||||||
|  |             另外,OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。 | ||||||
|  |           </Message> | ||||||
|  |         ) | ||||||
|  |       } | ||||||
|       <Table basic compact size='small'> |       <Table basic compact size='small'> | ||||||
|         <Table.Header> |         <Table.Header> | ||||||
|           <Table.Row> |           <Table.Row> | ||||||
| @@ -334,6 +389,14 @@ const ChannelsTable = () => { | |||||||
|             > |             > | ||||||
|               余额 |               余额 | ||||||
|             </Table.HeaderCell> |             </Table.HeaderCell> | ||||||
|  |             <Table.HeaderCell | ||||||
|  |               style={{ cursor: 'pointer' }} | ||||||
|  |               onClick={() => { | ||||||
|  |                 sortChannel('priority'); | ||||||
|  |               }} | ||||||
|  |             > | ||||||
|  |               优先级 | ||||||
|  |             </Table.HeaderCell> | ||||||
|             <Table.HeaderCell>操作</Table.HeaderCell> |             <Table.HeaderCell>操作</Table.HeaderCell> | ||||||
|           </Table.Row> |           </Table.Row> | ||||||
|         </Table.Header> |         </Table.Header> | ||||||
| @@ -372,6 +435,22 @@ const ChannelsTable = () => { | |||||||
|                       basic |                       basic | ||||||
|                     /> |                     /> | ||||||
|                   </Table.Cell> |                   </Table.Cell> | ||||||
|  |                   <Table.Cell> | ||||||
|  |                     <Popup | ||||||
|  |                       trigger={<Input type='number' defaultValue={channel.priority} onBlur={(event) => { | ||||||
|  |                         manageChannel( | ||||||
|  |                           channel.id, | ||||||
|  |                           'priority', | ||||||
|  |                           idx, | ||||||
|  |                           event.target.value | ||||||
|  |                         ); | ||||||
|  |                       }}> | ||||||
|  |                         <input style={{ maxWidth: '60px' }} /> | ||||||
|  |                       </Input>} | ||||||
|  |                       content='渠道选择优先级,越高越优先' | ||||||
|  |                       basic | ||||||
|  |                     /> | ||||||
|  |                   </Table.Cell> | ||||||
|                   <Table.Cell> |                   <Table.Cell> | ||||||
|                     <div> |                     <div> | ||||||
|                       <Button |                       <Button | ||||||
| @@ -440,7 +519,7 @@ const ChannelsTable = () => { | |||||||
|  |  | ||||||
|         <Table.Footer> |         <Table.Footer> | ||||||
|           <Table.Row> |           <Table.Row> | ||||||
|             <Table.HeaderCell colSpan='8'> |             <Table.HeaderCell colSpan='9'> | ||||||
|               <Button size='small' as={Link} to='/channel/add' loading={loading}> |               <Button size='small' as={Link} to='/channel/add' loading={loading}> | ||||||
|                 添加新的渠道 |                 添加新的渠道 | ||||||
|               </Button> |               </Button> | ||||||
| @@ -449,6 +528,20 @@ const ChannelsTable = () => { | |||||||
|               </Button> |               </Button> | ||||||
|               <Button size='small' onClick={updateAllChannelsBalance} |               <Button size='small' onClick={updateAllChannelsBalance} | ||||||
|                       loading={loading || updatingBalance}>更新所有已启用通道余额</Button> |                       loading={loading || updatingBalance}>更新所有已启用通道余额</Button> | ||||||
|  |               <Popup | ||||||
|  |                 trigger={ | ||||||
|  |                   <Button size='small' loading={loading}> | ||||||
|  |                     删除禁用渠道 | ||||||
|  |                   </Button> | ||||||
|  |                 } | ||||||
|  |                 on='click' | ||||||
|  |                 flowing | ||||||
|  |                 hoverable | ||||||
|  |               > | ||||||
|  |                 <Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}> | ||||||
|  |                   确认删除 | ||||||
|  |                 </Button> | ||||||
|  |               </Popup> | ||||||
|               <Pagination |               <Pagination | ||||||
|                 floated='right' |                 floated='right' | ||||||
|                 activePage={activePage} |                 activePage={activePage} | ||||||
|   | |||||||
| @@ -13,8 +13,8 @@ const GitHubOAuth = () => { | |||||||
|  |  | ||||||
|   let navigate = useNavigate(); |   let navigate = useNavigate(); | ||||||
|  |  | ||||||
|   const sendCode = async (code, count) => { |   const sendCode = async (code, state, count) => { | ||||||
|     const res = await API.get(`/api/oauth/github?code=${code}`); |     const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`); | ||||||
|     const { success, message, data } = res.data; |     const { success, message, data } = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
|       if (message === 'bind') { |       if (message === 'bind') { | ||||||
| @@ -36,13 +36,14 @@ const GitHubOAuth = () => { | |||||||
|       count++; |       count++; | ||||||
|       setPrompt(`出现错误,第 ${count} 次重试中...`); |       setPrompt(`出现错误,第 ${count} 次重试中...`); | ||||||
|       await new Promise((resolve) => setTimeout(resolve, count * 2000)); |       await new Promise((resolve) => setTimeout(resolve, count * 2000)); | ||||||
|       await sendCode(code, count); |       await sendCode(code, state, count); | ||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     let code = searchParams.get('code'); |     let code = searchParams.get('code'); | ||||||
|     sendCode(code, 0).then(); |     let state = searchParams.get('state'); | ||||||
|  |     sendCode(code, state, 0).then(); | ||||||
|   }, []); |   }, []); | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|   | |||||||
| @@ -2,7 +2,8 @@ import React, { useContext, useEffect, useState } from 'react'; | |||||||
| import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; | import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; | ||||||
| import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | import { Link, useNavigate, useSearchParams } from 'react-router-dom'; | ||||||
| import { UserContext } from '../context/User'; | import { UserContext } from '../context/User'; | ||||||
| import { API, getLogo, showError, showSuccess } from '../helpers'; | import { API, getLogo, showError, showSuccess, showWarning } from '../helpers'; | ||||||
|  | import { onGitHubOAuthClicked } from './utils'; | ||||||
|  |  | ||||||
| const LoginForm = () => { | const LoginForm = () => { | ||||||
|   const [inputs, setInputs] = useState({ |   const [inputs, setInputs] = useState({ | ||||||
| @@ -31,12 +32,6 @@ const LoginForm = () => { | |||||||
|  |  | ||||||
|   const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); |   const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); | ||||||
|  |  | ||||||
|   const onGitHubOAuthClicked = () => { |  | ||||||
|     window.open( |  | ||||||
|       `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` |  | ||||||
|     ); |  | ||||||
|   }; |  | ||||||
|  |  | ||||||
|   const onWeChatLoginClicked = () => { |   const onWeChatLoginClicked = () => { | ||||||
|     setShowWeChatLoginModal(true); |     setShowWeChatLoginModal(true); | ||||||
|   }; |   }; | ||||||
| @@ -73,8 +68,14 @@ const LoginForm = () => { | |||||||
|       if (success) { |       if (success) { | ||||||
|         userDispatch({ type: 'login', payload: data }); |         userDispatch({ type: 'login', payload: data }); | ||||||
|         localStorage.setItem('user', JSON.stringify(data)); |         localStorage.setItem('user', JSON.stringify(data)); | ||||||
|         navigate('/'); |         if (username === 'root' && password === '123456') { | ||||||
|         showSuccess('登录成功!'); |           navigate('/user/edit'); | ||||||
|  |           showSuccess('登录成功!'); | ||||||
|  |           showWarning('请立刻修改默认密码!'); | ||||||
|  |         } else { | ||||||
|  |           navigate('/token'); | ||||||
|  |           showSuccess('登录成功!'); | ||||||
|  |         } | ||||||
|       } else { |       } else { | ||||||
|         showError(message); |         showError(message); | ||||||
|       } |       } | ||||||
| @@ -131,7 +132,7 @@ const LoginForm = () => { | |||||||
|                 circular |                 circular | ||||||
|                 color='black' |                 color='black' | ||||||
|                 icon='github' |                 icon='github' | ||||||
|                 onClick={onGitHubOAuthClicked} |                 onClick={() => onGitHubOAuthClicked(status.github_client_id)} | ||||||
|               /> |               /> | ||||||
|             ) : ( |             ) : ( | ||||||
|               <></> |               <></> | ||||||
|   | |||||||
| @@ -56,9 +56,10 @@ const LogsTable = () => { | |||||||
|     token_name: '', |     token_name: '', | ||||||
|     model_name: '', |     model_name: '', | ||||||
|     start_timestamp: timestamp2string(0), |     start_timestamp: timestamp2string(0), | ||||||
|     end_timestamp: timestamp2string(now.getTime() / 1000 + 3600) |     end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), | ||||||
|  |     channel: '' | ||||||
|   }); |   }); | ||||||
|   const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs; |   const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs; | ||||||
|  |  | ||||||
|   const [stat, setStat] = useState({ |   const [stat, setStat] = useState({ | ||||||
|     quota: 0, |     quota: 0, | ||||||
| @@ -84,7 +85,7 @@ const LogsTable = () => { | |||||||
|   const getLogStat = async () => { |   const getLogStat = async () => { | ||||||
|     let localStartTimestamp = Date.parse(start_timestamp) / 1000; |     let localStartTimestamp = Date.parse(start_timestamp) / 1000; | ||||||
|     let localEndTimestamp = Date.parse(end_timestamp) / 1000; |     let localEndTimestamp = Date.parse(end_timestamp) / 1000; | ||||||
|     let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); |     let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`); | ||||||
|     const { success, message, data } = res.data; |     const { success, message, data } = res.data; | ||||||
|     if (success) { |     if (success) { | ||||||
|       setStat(data); |       setStat(data); | ||||||
| @@ -109,7 +110,7 @@ const LogsTable = () => { | |||||||
|     let localStartTimestamp = Date.parse(start_timestamp) / 1000; |     let localStartTimestamp = Date.parse(start_timestamp) / 1000; | ||||||
|     let localEndTimestamp = Date.parse(end_timestamp) / 1000; |     let localEndTimestamp = Date.parse(end_timestamp) / 1000; | ||||||
|     if (isAdminUser) { |     if (isAdminUser) { | ||||||
|       url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; |       url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`; | ||||||
|     } else { |     } else { | ||||||
|       url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; |       url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; | ||||||
|     } |     } | ||||||
| @@ -205,16 +206,9 @@ const LogsTable = () => { | |||||||
|         </Header> |         </Header> | ||||||
|         <Form> |         <Form> | ||||||
|           <Form.Group> |           <Form.Group> | ||||||
|             { |             <Form.Input fluid label={'令牌名称'} width={3} value={token_name} | ||||||
|               isAdminUser && ( |  | ||||||
|                 <Form.Input fluid label={'用户名称'} width={2} value={username} |  | ||||||
|                             placeholder={'可选值'} name='username' |  | ||||||
|                             onChange={handleInputChange} /> |  | ||||||
|               ) |  | ||||||
|             } |  | ||||||
|             <Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name} |  | ||||||
|                         placeholder={'可选值'} name='token_name' onChange={handleInputChange} /> |                         placeholder={'可选值'} name='token_name' onChange={handleInputChange} /> | ||||||
|             <Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值' |             <Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值' | ||||||
|                         name='model_name' |                         name='model_name' | ||||||
|                         onChange={handleInputChange} /> |                         onChange={handleInputChange} /> | ||||||
|             <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local' |             <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local' | ||||||
| @@ -225,6 +219,19 @@ const LogsTable = () => { | |||||||
|                         onChange={handleInputChange} /> |                         onChange={handleInputChange} /> | ||||||
|             <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button> |             <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button> | ||||||
|           </Form.Group> |           </Form.Group> | ||||||
|  |           { | ||||||
|  |             isAdminUser && <> | ||||||
|  |               <Form.Group> | ||||||
|  |                 <Form.Input fluid label={'渠道 ID'} width={3} value={channel} | ||||||
|  |                             placeholder='可选值' name='channel' | ||||||
|  |                             onChange={handleInputChange} /> | ||||||
|  |                 <Form.Input fluid label={'用户名称'} width={3} value={username} | ||||||
|  |                             placeholder={'可选值'} name='username' | ||||||
|  |                             onChange={handleInputChange} /> | ||||||
|  |  | ||||||
|  |               </Form.Group> | ||||||
|  |             </> | ||||||
|  |           } | ||||||
|         </Form> |         </Form> | ||||||
|         <Table basic compact size='small'> |         <Table basic compact size='small'> | ||||||
|           <Table.Header> |           <Table.Header> | ||||||
| @@ -238,6 +245,17 @@ const LogsTable = () => { | |||||||
|               > |               > | ||||||
|                 时间 |                 时间 | ||||||
|               </Table.HeaderCell> |               </Table.HeaderCell> | ||||||
|  |               { | ||||||
|  |                 isAdminUser && <Table.HeaderCell | ||||||
|  |                   style={{ cursor: 'pointer' }} | ||||||
|  |                   onClick={() => { | ||||||
|  |                     sortLog('channel'); | ||||||
|  |                   }} | ||||||
|  |                   width={1} | ||||||
|  |                 > | ||||||
|  |                   渠道 | ||||||
|  |                 </Table.HeaderCell> | ||||||
|  |               } | ||||||
|               { |               { | ||||||
|                 isAdminUser && <Table.HeaderCell |                 isAdminUser && <Table.HeaderCell | ||||||
|                   style={{ cursor: 'pointer' }} |                   style={{ cursor: 'pointer' }} | ||||||
| @@ -299,16 +317,16 @@ const LogsTable = () => { | |||||||
|                 onClick={() => { |                 onClick={() => { | ||||||
|                   sortLog('quota'); |                   sortLog('quota'); | ||||||
|                 }} |                 }} | ||||||
|                 width={2} |                 width={1} | ||||||
|               > |               > | ||||||
|                 消耗额度 |                 额度 | ||||||
|               </Table.HeaderCell> |               </Table.HeaderCell> | ||||||
|               <Table.HeaderCell |               <Table.HeaderCell | ||||||
|                 style={{ cursor: 'pointer' }} |                 style={{ cursor: 'pointer' }} | ||||||
|                 onClick={() => { |                 onClick={() => { | ||||||
|                   sortLog('content'); |                   sortLog('content'); | ||||||
|                 }} |                 }} | ||||||
|                 width={isAdminUser ? 4 : 5} |                 width={isAdminUser ? 4 : 6} | ||||||
|               > |               > | ||||||
|                 详情 |                 详情 | ||||||
|               </Table.HeaderCell> |               </Table.HeaderCell> | ||||||
| @@ -324,8 +342,13 @@ const LogsTable = () => { | |||||||
|               .map((log, idx) => { |               .map((log, idx) => { | ||||||
|                 if (log.deleted) return <></>; |                 if (log.deleted) return <></>; | ||||||
|                 return ( |                 return ( | ||||||
|                   <Table.Row key={log.created_at}> |                   <Table.Row key={log.id}> | ||||||
|                     <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell> |                     <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell> | ||||||
|  |                     { | ||||||
|  |                       isAdminUser && ( | ||||||
|  |                         <Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell> | ||||||
|  |                       ) | ||||||
|  |                     } | ||||||
|                     { |                     { | ||||||
|                       isAdminUser && ( |                       isAdminUser && ( | ||||||
|                         <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell> |                         <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell> | ||||||
| @@ -345,7 +368,7 @@ const LogsTable = () => { | |||||||
|  |  | ||||||
|           <Table.Footer> |           <Table.Footer> | ||||||
|             <Table.Row> |             <Table.Row> | ||||||
|               <Table.HeaderCell colSpan={'9'}> |               <Table.HeaderCell colSpan={'10'}> | ||||||
|                 <Select |                 <Select | ||||||
|                   placeholder='选择明细分类' |                   placeholder='选择明细分类' | ||||||
|                   options={LOG_OPTIONS} |                   options={LOG_OPTIONS} | ||||||
|   | |||||||
| @@ -1,8 +1,9 @@ | |||||||
| import React, { useEffect, useState } from 'react'; | import React, { useEffect, useState } from 'react'; | ||||||
| import { Divider, Form, Grid, Header } from 'semantic-ui-react'; | import { Divider, Form, Grid, Header } from 'semantic-ui-react'; | ||||||
| import { API, showError, verifyJSON } from '../helpers'; | import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers'; | ||||||
|  |  | ||||||
| const OperationSetting = () => { | const OperationSetting = () => { | ||||||
|  |   let now = new Date(); | ||||||
|   let [inputs, setInputs] = useState({ |   let [inputs, setInputs] = useState({ | ||||||
|     QuotaForNewUser: 0, |     QuotaForNewUser: 0, | ||||||
|     QuotaForInviter: 0, |     QuotaForInviter: 0, | ||||||
| @@ -20,10 +21,11 @@ const OperationSetting = () => { | |||||||
|     DisplayInCurrencyEnabled: '', |     DisplayInCurrencyEnabled: '', | ||||||
|     DisplayTokenStatEnabled: '', |     DisplayTokenStatEnabled: '', | ||||||
|     ApproximateTokenEnabled: '', |     ApproximateTokenEnabled: '', | ||||||
|     RetryTimes: 0, |     RetryTimes: 0 | ||||||
|   }); |   }); | ||||||
|   const [originInputs, setOriginInputs] = useState({}); |   const [originInputs, setOriginInputs] = useState({}); | ||||||
|   let [loading, setLoading] = useState(false); |   let [loading, setLoading] = useState(false); | ||||||
|  |   let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago | ||||||
|  |  | ||||||
|   const getOptions = async () => { |   const getOptions = async () => { | ||||||
|     const res = await API.get('/api/option/'); |     const res = await API.get('/api/option/'); | ||||||
| @@ -130,6 +132,17 @@ const OperationSetting = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   const deleteHistoryLogs = async () => { | ||||||
|  |     console.log(inputs); | ||||||
|  |     const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`); | ||||||
|  |     const { success, message, data } = res.data; | ||||||
|  |     if (success) { | ||||||
|  |       showSuccess(`${data} 条日志已清理!`); | ||||||
|  |       return; | ||||||
|  |     } | ||||||
|  |     showError('日志清理失败:' + message); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|     <Grid columns={1}> |     <Grid columns={1}> | ||||||
|       <Grid.Column> |       <Grid.Column> | ||||||
| @@ -179,12 +192,6 @@ const OperationSetting = () => { | |||||||
|             /> |             /> | ||||||
|           </Form.Group> |           </Form.Group> | ||||||
|           <Form.Group inline> |           <Form.Group inline> | ||||||
|             <Form.Checkbox |  | ||||||
|               checked={inputs.LogConsumeEnabled === 'true'} |  | ||||||
|               label='启用额度消费日志记录' |  | ||||||
|               name='LogConsumeEnabled' |  | ||||||
|               onChange={handleInputChange} |  | ||||||
|             /> |  | ||||||
|             <Form.Checkbox |             <Form.Checkbox | ||||||
|               checked={inputs.DisplayInCurrencyEnabled === 'true'} |               checked={inputs.DisplayInCurrencyEnabled === 'true'} | ||||||
|               label='以货币形式显示额度' |               label='以货币形式显示额度' | ||||||
| @@ -208,6 +215,28 @@ const OperationSetting = () => { | |||||||
|             submitConfig('general').then(); |             submitConfig('general').then(); | ||||||
|           }}>保存通用设置</Form.Button> |           }}>保存通用设置</Form.Button> | ||||||
|           <Divider /> |           <Divider /> | ||||||
|  |           <Header as='h3'> | ||||||
|  |             日志设置 | ||||||
|  |           </Header> | ||||||
|  |           <Form.Group inline> | ||||||
|  |             <Form.Checkbox | ||||||
|  |               checked={inputs.LogConsumeEnabled === 'true'} | ||||||
|  |               label='启用额度消费日志记录' | ||||||
|  |               name='LogConsumeEnabled' | ||||||
|  |               onChange={handleInputChange} | ||||||
|  |             /> | ||||||
|  |           </Form.Group> | ||||||
|  |           <Form.Group widths={4}> | ||||||
|  |             <Form.Input label='目标时间' value={historyTimestamp} type='datetime-local' | ||||||
|  |                         name='history_timestamp' | ||||||
|  |                         onChange={(e, { name, value }) => { | ||||||
|  |                           setHistoryTimestamp(value); | ||||||
|  |                         }} /> | ||||||
|  |           </Form.Group> | ||||||
|  |           <Form.Button onClick={() => { | ||||||
|  |             deleteHistoryLogs().then(); | ||||||
|  |           }}>清理历史日志</Form.Button> | ||||||
|  |           <Divider /> | ||||||
|           <Header as='h3'> |           <Header as='h3'> | ||||||
|             监控设置 |             监控设置 | ||||||
|           </Header> |           </Header> | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import { Link, useNavigate } from 'react-router-dom'; | |||||||
| import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; | import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; | ||||||
| import Turnstile from 'react-turnstile'; | import Turnstile from 'react-turnstile'; | ||||||
| import { UserContext } from '../context/User'; | import { UserContext } from '../context/User'; | ||||||
|  | import { onGitHubOAuthClicked } from './utils'; | ||||||
|  |  | ||||||
| const PersonalSetting = () => { | const PersonalSetting = () => { | ||||||
|   const [userState, userDispatch] = useContext(UserContext); |   const [userState, userDispatch] = useContext(UserContext); | ||||||
| @@ -130,12 +131,6 @@ const PersonalSetting = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   const openGitHubOAuth = () => { |  | ||||||
|     window.open( |  | ||||||
|       `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` |  | ||||||
|     ); |  | ||||||
|   }; |  | ||||||
|  |  | ||||||
|   const sendVerificationCode = async () => { |   const sendVerificationCode = async () => { | ||||||
|     setDisableButton(true); |     setDisableButton(true); | ||||||
|     if (inputs.email === '') return; |     if (inputs.email === '') return; | ||||||
| @@ -249,7 +244,7 @@ const PersonalSetting = () => { | |||||||
|       </Modal> |       </Modal> | ||||||
|       { |       { | ||||||
|         status.github_oauth && ( |         status.github_oauth && ( | ||||||
|           <Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button> |           <Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button> | ||||||
|         ) |         ) | ||||||
|       } |       } | ||||||
|       <Button |       <Button | ||||||
|   | |||||||
| @@ -96,7 +96,7 @@ const TokensTable = () => { | |||||||
|     let nextUrl; |     let nextUrl; | ||||||
|    |    | ||||||
|     if (nextLink) { |     if (nextLink) { | ||||||
|       nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`; |       nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||||
|     } else { |     } else { | ||||||
|       nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; |       nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||||
|     } |     } | ||||||
| @@ -138,7 +138,7 @@ const TokensTable = () => { | |||||||
|     let defaultUrl; |     let defaultUrl; | ||||||
|    |    | ||||||
|     if (chatLink) { |     if (chatLink) { | ||||||
|       defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`; |       defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||||
|     } else { |     } else { | ||||||
|       defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; |       defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; | ||||||
|     } |     } | ||||||
|   | |||||||
							
								
								
									
										20
									
								
								web/src/components/utils.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								web/src/components/utils.js
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | import { API, showError } from '../helpers'; | ||||||
|  |  | ||||||
|  | export async function getOAuthState() { | ||||||
|  |   const res = await API.get('/api/oauth/state'); | ||||||
|  |   const { success, message, data } = res.data; | ||||||
|  |   if (success) { | ||||||
|  |     return data; | ||||||
|  |   } else { | ||||||
|  |     showError(message); | ||||||
|  |     return ''; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | export async function onGitHubOAuthClicked(github_client_id) { | ||||||
|  |   const state = await getOAuthState(); | ||||||
|  |   if (!state) return; | ||||||
|  |   window.open( | ||||||
|  |     `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email` | ||||||
|  |   ); | ||||||
|  | } | ||||||
| @@ -8,7 +8,10 @@ export const CHANNEL_OPTIONS = [ | |||||||
|   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, |   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, | ||||||
|   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, |   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, | ||||||
|   { key: 19, text: '360 智脑', value: 19, color: 'blue' }, |   { key: 19, text: '360 智脑', value: 19, color: 'blue' }, | ||||||
|  |   { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, | ||||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, |   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||||
|  |   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||||
|  |   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||||
|   { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, |   { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, | ||||||
|   { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, |   { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, | ||||||
|   { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, |   { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, | ||||||
|   | |||||||
| @@ -186,4 +186,14 @@ export const verifyJSON = (str) => { | |||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
|   return true; |   return true; | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | export function shouldShowPrompt(id) { | ||||||
|  |   let prompt = localStorage.getItem(`prompt-${id}`); | ||||||
|  |   return !prompt; | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | export function setPromptShown(id) { | ||||||
|  |   localStorage.setItem(`prompt-${id}`, 'true'); | ||||||
|  | } | ||||||
| @@ -10,6 +10,22 @@ const MODEL_MAPPING_EXAMPLE = { | |||||||
|   'gpt-4-32k-0314': 'gpt-4-32k' |   'gpt-4-32k-0314': 'gpt-4-32k' | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | function type2secretPrompt(type) { | ||||||
|  |   // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') | ||||||
|  |   switch (type) { | ||||||
|  |     case 15: | ||||||
|  |       return '按照如下格式输入:APIKey|SecretKey'; | ||||||
|  |     case 18: | ||||||
|  |       return '按照如下格式输入:APPID|APISecret|APIKey'; | ||||||
|  |     case 22: | ||||||
|  |       return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'; | ||||||
|  |     case 23: | ||||||
|  |       return '按照如下格式输入:AppId|SecretId|SecretKey'; | ||||||
|  |     default: | ||||||
|  |       return '请输入渠道对应的鉴权密钥'; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
| const EditChannel = () => { | const EditChannel = () => { | ||||||
|   const params = useParams(); |   const params = useParams(); | ||||||
|   const navigate = useNavigate(); |   const navigate = useNavigate(); | ||||||
| @@ -53,7 +69,7 @@ const EditChannel = () => { | |||||||
|           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; |           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; | ||||||
|           break; |           break; | ||||||
|         case 17: |         case 17: | ||||||
|           localModels = ['qwen-v1', 'qwen-plus-v1']; |           localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; | ||||||
|           break; |           break; | ||||||
|         case 16: |         case 16: | ||||||
|           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; |           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; | ||||||
| @@ -62,7 +78,10 @@ const EditChannel = () => { | |||||||
|           localModels = ['SparkDesk']; |           localModels = ['SparkDesk']; | ||||||
|           break; |           break; | ||||||
|         case 19: |         case 19: | ||||||
|           localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4']; |           localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; | ||||||
|  |           break; | ||||||
|  |         case 23: | ||||||
|  |           localModels = ['hunyuan']; | ||||||
|           break; |           break; | ||||||
|       } |       } | ||||||
|       setInputs((inputs) => ({ ...inputs, models: localModels })); |       setInputs((inputs) => ({ ...inputs, models: localModels })); | ||||||
| @@ -160,7 +179,7 @@ const EditChannel = () => { | |||||||
|       return; |       return; | ||||||
|     } |     } | ||||||
|     let localInputs = inputs; |     let localInputs = inputs; | ||||||
|     if (localInputs.base_url.endsWith('/')) { |     if (localInputs.base_url && localInputs.base_url.endsWith('/')) { | ||||||
|       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); |       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); | ||||||
|     } |     } | ||||||
|     if (localInputs.type === 3 && localInputs.other === '') { |     if (localInputs.type === 3 && localInputs.other === '') { | ||||||
| @@ -169,9 +188,6 @@ const EditChannel = () => { | |||||||
|     if (localInputs.type === 18 && localInputs.other === '') { |     if (localInputs.type === 18 && localInputs.other === '') { | ||||||
|       localInputs.other = 'v2.1'; |       localInputs.other = 'v2.1'; | ||||||
|     } |     } | ||||||
|     if (localInputs.model_mapping === '') { |  | ||||||
|       localInputs.model_mapping = '{}'; |  | ||||||
|     } |  | ||||||
|     let res; |     let res; | ||||||
|     localInputs.models = localInputs.models.join(','); |     localInputs.models = localInputs.models.join(','); | ||||||
|     localInputs.group = localInputs.groups.join(','); |     localInputs.group = localInputs.groups.join(','); | ||||||
| @@ -193,6 +209,24 @@ const EditChannel = () => { | |||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   const addCustomModel = () => { | ||||||
|  |     if (customModel.trim() === '') return; | ||||||
|  |     if (inputs.models.includes(customModel)) return; | ||||||
|  |     let localModels = [...inputs.models]; | ||||||
|  |     localModels.push(customModel); | ||||||
|  |     let localModelOptions = []; | ||||||
|  |     localModelOptions.push({ | ||||||
|  |       key: customModel, | ||||||
|  |       text: customModel, | ||||||
|  |       value: customModel | ||||||
|  |     }); | ||||||
|  |     setModelOptions(modelOptions => { | ||||||
|  |       return [...modelOptions, ...localModelOptions]; | ||||||
|  |     }); | ||||||
|  |     setCustomModel(''); | ||||||
|  |     handleInputChange(null, { name: 'models', value: localModels }); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|     <> |     <> | ||||||
|       <Segment loading={loading}> |       <Segment loading={loading}> | ||||||
| @@ -295,6 +329,20 @@ const EditChannel = () => { | |||||||
|               </Form.Field> |               </Form.Field> | ||||||
|             ) |             ) | ||||||
|           } |           } | ||||||
|  |           { | ||||||
|  |             inputs.type === 21 && ( | ||||||
|  |               <Form.Field> | ||||||
|  |                 <Form.Input | ||||||
|  |                   label='知识库 ID' | ||||||
|  |                   name='other' | ||||||
|  |                   placeholder={'请输入知识库 ID,例如:123456'} | ||||||
|  |                   onChange={handleInputChange} | ||||||
|  |                   value={inputs.other} | ||||||
|  |                   autoComplete='new-password' | ||||||
|  |                 /> | ||||||
|  |               </Form.Field> | ||||||
|  |             ) | ||||||
|  |           } | ||||||
|           <Form.Field> |           <Form.Field> | ||||||
|             <Form.Dropdown |             <Form.Dropdown | ||||||
|               label='模型' |               label='模型' | ||||||
| @@ -322,29 +370,19 @@ const EditChannel = () => { | |||||||
|             }}>清除所有模型</Button> |             }}>清除所有模型</Button> | ||||||
|             <Input |             <Input | ||||||
|               action={ |               action={ | ||||||
|                 <Button type={'button'} onClick={() => { |                 <Button type={'button'} onClick={addCustomModel}>填入</Button> | ||||||
|                   if (customModel.trim() === '') return; |  | ||||||
|                   if (inputs.models.includes(customModel)) return; |  | ||||||
|                   let localModels = [...inputs.models]; |  | ||||||
|                   localModels.push(customModel); |  | ||||||
|                   let localModelOptions = []; |  | ||||||
|                   localModelOptions.push({ |  | ||||||
|                     key: customModel, |  | ||||||
|                     text: customModel, |  | ||||||
|                     value: customModel |  | ||||||
|                   }); |  | ||||||
|                   setModelOptions(modelOptions => { |  | ||||||
|                     return [...modelOptions, ...localModelOptions]; |  | ||||||
|                   }); |  | ||||||
|                   setCustomModel(''); |  | ||||||
|                   handleInputChange(null, { name: 'models', value: localModels }); |  | ||||||
|                 }}>填入</Button> |  | ||||||
|               } |               } | ||||||
|               placeholder='输入自定义模型名称' |               placeholder='输入自定义模型名称' | ||||||
|               value={customModel} |               value={customModel} | ||||||
|               onChange={(e, { value }) => { |               onChange={(e, { value }) => { | ||||||
|                 setCustomModel(value); |                 setCustomModel(value); | ||||||
|               }} |               }} | ||||||
|  |               onKeyDown={(e) => { | ||||||
|  |                 if (e.key === 'Enter') { | ||||||
|  |                   addCustomModel(); | ||||||
|  |                   e.preventDefault(); | ||||||
|  |                 } | ||||||
|  |               }} | ||||||
|             /> |             /> | ||||||
|           </div> |           </div> | ||||||
|           <Form.Field> |           <Form.Field> | ||||||
| @@ -375,7 +413,7 @@ const EditChannel = () => { | |||||||
|                 label='密钥' |                 label='密钥' | ||||||
|                 name='key' |                 name='key' | ||||||
|                 required |                 required | ||||||
|                 placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')} |                 placeholder={type2secretPrompt(inputs.type)} | ||||||
|                 onChange={handleInputChange} |                 onChange={handleInputChange} | ||||||
|                 value={inputs.key} |                 value={inputs.key} | ||||||
|                 autoComplete='new-password' |                 autoComplete='new-password' | ||||||
| @@ -393,7 +431,7 @@ const EditChannel = () => { | |||||||
|             ) |             ) | ||||||
|           } |           } | ||||||
|           { |           { | ||||||
|             inputs.type !== 3 && inputs.type !== 8 && ( |             inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( | ||||||
|               <Form.Field> |               <Form.Field> | ||||||
|                 <Form.Input |                 <Form.Input | ||||||
|                   label='代理' |                   label='代理' | ||||||
| @@ -406,6 +444,20 @@ const EditChannel = () => { | |||||||
|               </Form.Field> |               </Form.Field> | ||||||
|             ) |             ) | ||||||
|           } |           } | ||||||
|  |           { | ||||||
|  |             inputs.type === 22 && ( | ||||||
|  |               <Form.Field> | ||||||
|  |                 <Form.Input | ||||||
|  |                   label='私有部署地址' | ||||||
|  |                   name='base_url' | ||||||
|  |                   placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'} | ||||||
|  |                   onChange={handleInputChange} | ||||||
|  |                   value={inputs.base_url} | ||||||
|  |                   autoComplete='new-password' | ||||||
|  |                 /> | ||||||
|  |               </Form.Field> | ||||||
|  |             ) | ||||||
|  |           } | ||||||
|           <Button onClick={handleCancel}>取消</Button> |           <Button onClick={handleCancel}>取消</Button> | ||||||
|           <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button> |           <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button> | ||||||
|         </Form> |         </Form> | ||||||
|   | |||||||
| @@ -1,19 +1,12 @@ | |||||||
| import React from 'react'; | import React from 'react'; | ||||||
| import { Segment, Header } from 'semantic-ui-react'; | import { Message } from 'semantic-ui-react'; | ||||||
|  |  | ||||||
| const NotFound = () => ( | const NotFound = () => ( | ||||||
|   <> |   <> | ||||||
|     <Header |     <Message negative> | ||||||
|       block |       <Message.Header>页面不存在</Message.Header> | ||||||
|       as="h4" |       <p>请检查你的浏览器地址是否正确</p> | ||||||
|       content="404" |     </Message> | ||||||
|       attached="top" |  | ||||||
|       icon="info" |  | ||||||
|       className="small-icon" |  | ||||||
|     /> |  | ||||||
|     <Segment attached="bottom"> |  | ||||||
|       未找到所请求的页面 |  | ||||||
|     </Segment> |  | ||||||
|   </> |   </> | ||||||
| ); | ); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -102,7 +102,7 @@ const EditUser = () => { | |||||||
|               label='密码' |               label='密码' | ||||||
|               name='password' |               name='password' | ||||||
|               type={'password'} |               type={'password'} | ||||||
|               placeholder={'请输入新的密码'} |               placeholder={'请输入新的密码,最短 8 位'} | ||||||
|               onChange={handleInputChange} |               onChange={handleInputChange} | ||||||
|               value={password} |               value={password} | ||||||
|               autoComplete='new-password' |               autoComplete='new-password' | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user