mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	Compare commits
	
		
			47 Commits
		
	
	
		
			v0.5.4-alp
			...
			v0.5.6-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					f9b748c2ca | ||
| 
						 | 
					fd98463611 | ||
| 
						 | 
					f5a1cd3463 | ||
| 
						 | 
					8651451e53 | ||
| 
						 | 
					1c5bb97a42 | ||
| 
						 | 
					de868e4e4e | ||
| 
						 | 
					1d258cc898 | ||
| 
						 | 
					37e09d764c | ||
| 
						 | 
					159b9e3369 | ||
| 
						 | 
					92001986db | ||
| 
						 | 
					a5647b1ea7 | ||
| 
						 | 
					215e54fc96 | ||
| 
						 | 
					ecf8a6d875 | ||
| 
						 | 
					24df3e5f62 | ||
| 
						 | 
					12ef9679a7 | ||
| 
						 | 
					328aa68255 | ||
| 
						 | 
					4335f005a6 | ||
| 
						 | 
					fe26a1448d | ||
| 
						 | 
					42451d9d02 | ||
| 
						 | 
					25c4c111ab | ||
| 
						 | 
					0d50ad4b2b | ||
| 
						 | 
					959bcdef88 | ||
| 
						 | 
					39ae8075e4 | ||
| 
						 | 
					b57a0eca16 | ||
| 
						 | 
					1b4cc78890 | ||
| 
						 | 
					420c375140 | ||
| 
						 | 
					01863d3e44 | ||
| 
						 | 
					d0a0e871e1 | ||
| 
						 | 
					bd6fe1e93c | ||
| 
						 | 
					c55bb67818 | ||
| 
						 | 
					0f949c3782 | ||
| 
						 | 
					a721a5b6f9 | ||
| 
						 | 
					276163affd | ||
| 
						 | 
					621eb91b46 | ||
| 
						 | 
					7e575abb95 | ||
| 
						 | 
					9db93316c4 | ||
| 
						 | 
					c3dc315e75 | ||
| 
						 | 
					04acdb1ccb | ||
| 
						 | 
					f0d5e102a3 | ||
| 
						 | 
					abbf2fded0 | ||
| 
						 | 
					ef2c5abb5b | ||
| 
						 | 
					56b5007379 | ||
| 
						 | 
					d09d317459 | ||
| 
						 | 
					1c4409ae80 | ||
| 
						 | 
					5ee24e8acf | ||
| 
						 | 
					4f2f911e4d | ||
| 
						 | 
					fdb2cccf65 | 
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -4,4 +4,5 @@ upload
 | 
				
			|||||||
*.exe
 | 
					*.exe
 | 
				
			||||||
*.db
 | 
					*.db
 | 
				
			||||||
build
 | 
					build
 | 
				
			||||||
*.db-journal
 | 
					*.db-journal
 | 
				
			||||||
 | 
					logs
 | 
				
			||||||
							
								
								
									
										48
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										48
									
								
								README.md
									
									
									
									
									
								
							@@ -68,12 +68,13 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
				
			|||||||
   + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
 | 
					   + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
 | 
				
			||||||
   + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
 | 
					   + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
 | 
				
			||||||
   + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
 | 
					   + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
 | 
				
			||||||
 | 
					   + [x] [360 智脑](https://ai.360.cn)
 | 
				
			||||||
2. 支持配置镜像以及众多第三方代理服务:
 | 
					2. 支持配置镜像以及众多第三方代理服务:
 | 
				
			||||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
					   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
				
			||||||
 | 
					   + [x] [CloseAI](https://console.closeai-asia.com/r/2412)
 | 
				
			||||||
   + [x] [API2D](https://api2d.com/r/197971)
 | 
					   + [x] [API2D](https://api2d.com/r/197971)
 | 
				
			||||||
   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
 | 
					   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
 | 
				
			||||||
   + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
 | 
					   + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
 | 
				
			||||||
   + [x] [CloseAI](https://console.closeai-asia.com/r/2412)
 | 
					 | 
				
			||||||
   + [x] 自定义渠道:例如各种未收录的第三方代理服务
 | 
					   + [x] 自定义渠道:例如各种未收录的第三方代理服务
 | 
				
			||||||
3. 支持通过**负载均衡**的方式访问多个渠道。
 | 
					3. 支持通过**负载均衡**的方式访问多个渠道。
 | 
				
			||||||
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
					4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
				
			||||||
@@ -108,6 +109,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
 | 
					数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
 | 
					如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
 | 
					如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
 | 
				
			||||||
@@ -208,6 +211,13 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。
 | 
					注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#### QChatGPT - QQ机器人
 | 
				
			||||||
 | 
					项目主页:https://github.com/RockChinQ/QChatGPT
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### 部署到第三方平台
 | 
					### 部署到第三方平台
 | 
				
			||||||
<details>
 | 
					<details>
 | 
				
			||||||
<summary><strong>部署到 Sealos </strong></summary>
 | 
					<summary><strong>部署到 Sealos </strong></summary>
 | 
				
			||||||
@@ -259,6 +269,12 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
注意,具体的 API Base 的格式取决于你所使用的客户端。
 | 
					注意,具体的 API Base 的格式取决于你所使用的客户端。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					例如对于 OpenAI 的官方库:
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					OPENAI_API_KEY="sk-xxxxxx"
 | 
				
			||||||
 | 
					OPENAI_API_BASE="https://<HOST>:<PORT>/v1" 
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```mermaid
 | 
					```mermaid
 | 
				
			||||||
graph LR
 | 
					graph LR
 | 
				
			||||||
    A(用户)
 | 
					    A(用户)
 | 
				
			||||||
@@ -274,8 +290,9 @@ graph LR
 | 
				
			|||||||
不加的话将会使用负载均衡的方式使用多个渠道。
 | 
					不加的话将会使用负载均衡的方式使用多个渠道。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### 环境变量
 | 
					### 环境变量
 | 
				
			||||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储。
 | 
					1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
 | 
				
			||||||
   + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
 | 
					   + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
 | 
				
			||||||
 | 
					   + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。
 | 
				
			||||||
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
 | 
					2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
 | 
				
			||||||
   + 例子:`SESSION_SECRET=random_string`
 | 
					   + 例子:`SESSION_SECRET=random_string`
 | 
				
			||||||
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。
 | 
					3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。
 | 
				
			||||||
@@ -292,21 +309,31 @@ graph LR
 | 
				
			|||||||
     + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
 | 
					     + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
 | 
				
			||||||
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
 | 
					4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
 | 
				
			||||||
   + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
 | 
					   + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
 | 
				
			||||||
5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。
 | 
					5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
 | 
				
			||||||
 | 
					   + 例子:`MEMORY_CACHE_ENABLED=true`
 | 
				
			||||||
 | 
					6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
 | 
				
			||||||
   + 例子:`SYNC_FREQUENCY=60`
 | 
					   + 例子:`SYNC_FREQUENCY=60`
 | 
				
			||||||
6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
 | 
					7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
 | 
				
			||||||
   + 例子:`NODE_TYPE=slave`
 | 
					   + 例子:`NODE_TYPE=slave`
 | 
				
			||||||
7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
 | 
					8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
 | 
				
			||||||
   + 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
 | 
					   + 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
 | 
				
			||||||
8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
 | 
					9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
 | 
				
			||||||
   + 例子:`CHANNEL_TEST_FREQUENCY=1440`
 | 
					   + 例子:`CHANNEL_TEST_FREQUENCY=1440`
 | 
				
			||||||
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
 | 
					10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
 | 
				
			||||||
   + 例子:`POLLING_INTERVAL=5`
 | 
					    + 例子:`POLLING_INTERVAL=5`
 | 
				
			||||||
 | 
					11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
 | 
				
			||||||
 | 
					    + 例子:`BATCH_UPDATE_ENABLED=true`
 | 
				
			||||||
 | 
					    + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
 | 
				
			||||||
 | 
					12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
 | 
				
			||||||
 | 
					    + 例子:`BATCH_UPDATE_INTERVAL=5`
 | 
				
			||||||
 | 
					13. 请求频率限制:
 | 
				
			||||||
 | 
					    + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
 | 
				
			||||||
 | 
					    + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### 命令行参数
 | 
					### 命令行参数
 | 
				
			||||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 | 
					1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 | 
				
			||||||
   + 例子:`--port 3000`
 | 
					   + 例子:`--port 3000`
 | 
				
			||||||
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存。
 | 
					2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。
 | 
				
			||||||
   + 例子:`--log-dir ./logs`
 | 
					   + 例子:`--log-dir ./logs`
 | 
				
			||||||
3. `--version`: 打印系统版本号并退出。
 | 
					3. `--version`: 打印系统版本号并退出。
 | 
				
			||||||
4. `--help`: 查看命令的使用帮助和参数说明。
 | 
					4. `--help`: 查看命令的使用帮助和参数说明。
 | 
				
			||||||
@@ -338,6 +365,7 @@ https://openai.justsong.cn
 | 
				
			|||||||
5. ChatGPT Next Web 报错:`Failed to fetch`
 | 
					5. ChatGPT Next Web 报错:`Failed to fetch`
 | 
				
			||||||
   + 部署的时候不要设置 `BASE_URL`。
 | 
					   + 部署的时候不要设置 `BASE_URL`。
 | 
				
			||||||
   + 检查你的接口地址和 API Key 有没有填对。
 | 
					   + 检查你的接口地址和 API Key 有没有填对。
 | 
				
			||||||
 | 
					   + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
 | 
				
			||||||
6. 报错:`当前分组负载已饱和,请稍后再试`
 | 
					6. 报错:`当前分组负载已饱和,请稍后再试`
 | 
				
			||||||
   + 上游通道 429 了。
 | 
					   + 上游通道 429 了。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -351,4 +379,4 @@ https://openai.justsong.cn
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
同样适用于基于本项目的二开项目。
 | 
					同样适用于基于本项目的二开项目。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。
 | 
					依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -56,6 +56,7 @@ var EmailDomainWhitelist = []string{
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var DebugEnabled = os.Getenv("DEBUG") == "true"
 | 
					var DebugEnabled = os.Getenv("DEBUG") == "true"
 | 
				
			||||||
 | 
					var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var LogConsumeEnabled = true
 | 
					var LogConsumeEnabled = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -92,7 +93,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
 | 
				
			|||||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
 | 
					var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
 | 
				
			||||||
var RequestInterval = time.Duration(requestInterval) * time.Second
 | 
					var RequestInterval = time.Duration(requestInterval) * time.Second
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
 | 
					var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var BatchUpdateEnabled = false
 | 
				
			||||||
 | 
					var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						RequestIdKey = "X-Oneapi-Request-Id"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	RoleGuestUser  = 0
 | 
						RoleGuestUser  = 0
 | 
				
			||||||
@@ -111,10 +119,10 @@ var (
 | 
				
			|||||||
// All duration's unit is seconds
 | 
					// All duration's unit is seconds
 | 
				
			||||||
// Shouldn't larger then RateLimitKeyExpirationDuration
 | 
					// Shouldn't larger then RateLimitKeyExpirationDuration
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	GlobalApiRateLimitNum            = 180
 | 
						GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
 | 
				
			||||||
	GlobalApiRateLimitDuration int64 = 3 * 60
 | 
						GlobalApiRateLimitDuration int64 = 3 * 60
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	GlobalWebRateLimitNum            = 60
 | 
						GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
 | 
				
			||||||
	GlobalWebRateLimitDuration int64 = 3 * 60
 | 
						GlobalWebRateLimitDuration int64 = 3 * 60
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	UploadRateLimitNum            = 10
 | 
						UploadRateLimitNum            = 10
 | 
				
			||||||
@@ -154,45 +162,53 @@ const (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	ChannelTypeUnknown   = 0
 | 
						ChannelTypeUnknown        = 0
 | 
				
			||||||
	ChannelTypeOpenAI    = 1
 | 
						ChannelTypeOpenAI         = 1
 | 
				
			||||||
	ChannelTypeAPI2D     = 2
 | 
						ChannelTypeAPI2D          = 2
 | 
				
			||||||
	ChannelTypeAzure     = 3
 | 
						ChannelTypeAzure          = 3
 | 
				
			||||||
	ChannelTypeCloseAI   = 4
 | 
						ChannelTypeCloseAI        = 4
 | 
				
			||||||
	ChannelTypeOpenAISB  = 5
 | 
						ChannelTypeOpenAISB       = 5
 | 
				
			||||||
	ChannelTypeOpenAIMax = 6
 | 
						ChannelTypeOpenAIMax      = 6
 | 
				
			||||||
	ChannelTypeOhMyGPT   = 7
 | 
						ChannelTypeOhMyGPT        = 7
 | 
				
			||||||
	ChannelTypeCustom    = 8
 | 
						ChannelTypeCustom         = 8
 | 
				
			||||||
	ChannelTypeAILS      = 9
 | 
						ChannelTypeAILS           = 9
 | 
				
			||||||
	ChannelTypeAIProxy   = 10
 | 
						ChannelTypeAIProxy        = 10
 | 
				
			||||||
	ChannelTypePaLM      = 11
 | 
						ChannelTypePaLM           = 11
 | 
				
			||||||
	ChannelTypeAPI2GPT   = 12
 | 
						ChannelTypeAPI2GPT        = 12
 | 
				
			||||||
	ChannelTypeAIGC2D    = 13
 | 
						ChannelTypeAIGC2D         = 13
 | 
				
			||||||
	ChannelTypeAnthropic = 14
 | 
						ChannelTypeAnthropic      = 14
 | 
				
			||||||
	ChannelTypeBaidu     = 15
 | 
						ChannelTypeBaidu          = 15
 | 
				
			||||||
	ChannelTypeZhipu     = 16
 | 
						ChannelTypeZhipu          = 16
 | 
				
			||||||
	ChannelTypeAli       = 17
 | 
						ChannelTypeAli            = 17
 | 
				
			||||||
	ChannelTypeXunfei    = 18
 | 
						ChannelTypeXunfei         = 18
 | 
				
			||||||
 | 
						ChannelType360            = 19
 | 
				
			||||||
 | 
						ChannelTypeOpenRouter     = 20
 | 
				
			||||||
 | 
						ChannelTypeAIProxyLibrary = 21
 | 
				
			||||||
 | 
						ChannelTypeFastGPT        = 22
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var ChannelBaseURLs = []string{
 | 
					var ChannelBaseURLs = []string{
 | 
				
			||||||
	"",                               // 0
 | 
						"",                                // 0
 | 
				
			||||||
	"https://api.openai.com",         // 1
 | 
						"https://api.openai.com",          // 1
 | 
				
			||||||
	"https://oa.api2d.net",           // 2
 | 
						"https://oa.api2d.net",            // 2
 | 
				
			||||||
	"",                               // 3
 | 
						"",                                // 3
 | 
				
			||||||
	"https://api.closeai-proxy.xyz",  // 4
 | 
						"https://api.closeai-proxy.xyz",   // 4
 | 
				
			||||||
	"https://api.openai-sb.com",      // 5
 | 
						"https://api.openai-sb.com",       // 5
 | 
				
			||||||
	"https://api.openaimax.com",      // 6
 | 
						"https://api.openaimax.com",       // 6
 | 
				
			||||||
	"https://api.ohmygpt.com",        // 7
 | 
						"https://api.ohmygpt.com",         // 7
 | 
				
			||||||
	"",                               // 8
 | 
						"",                                // 8
 | 
				
			||||||
	"https://api.caipacity.com",      // 9
 | 
						"https://api.caipacity.com",       // 9
 | 
				
			||||||
	"https://api.aiproxy.io",         // 10
 | 
						"https://api.aiproxy.io",          // 10
 | 
				
			||||||
	"",                               // 11
 | 
						"",                                // 11
 | 
				
			||||||
	"https://api.api2gpt.com",        // 12
 | 
						"https://api.api2gpt.com",         // 12
 | 
				
			||||||
	"https://api.aigc2d.com",         // 13
 | 
						"https://api.aigc2d.com",          // 13
 | 
				
			||||||
	"https://api.anthropic.com",      // 14
 | 
						"https://api.anthropic.com",       // 14
 | 
				
			||||||
	"https://aip.baidubce.com",       // 15
 | 
						"https://aip.baidubce.com",        // 15
 | 
				
			||||||
	"https://open.bigmodel.cn",       // 16
 | 
						"https://open.bigmodel.cn",        // 16
 | 
				
			||||||
	"https://dashscope.aliyuncs.com", // 17
 | 
						"https://dashscope.aliyuncs.com",  // 17
 | 
				
			||||||
	"",                               // 18
 | 
						"",                                // 18
 | 
				
			||||||
 | 
						"https://ai.360.cn",               // 19
 | 
				
			||||||
 | 
						"https://openrouter.ai/api",       // 20
 | 
				
			||||||
 | 
						"https://api.aiproxy.io",          // 21
 | 
				
			||||||
 | 
						"https://fastgpt.run/api/openapi", // 22
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,7 +12,7 @@ var (
 | 
				
			|||||||
	Port         = flag.Int("port", 3000, "the listening port")
 | 
						Port         = flag.Int("port", 3000, "the listening port")
 | 
				
			||||||
	PrintVersion = flag.Bool("version", false, "print version and exit")
 | 
						PrintVersion = flag.Bool("version", false, "print version and exit")
 | 
				
			||||||
	PrintHelp    = flag.Bool("help", false, "print help and exit")
 | 
						PrintHelp    = flag.Bool("help", false, "print help and exit")
 | 
				
			||||||
	LogDir       = flag.String("log-dir", "", "specify the log directory")
 | 
						LogDir       = flag.String("log-dir", "./logs", "specify the log directory")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func printHelp() {
 | 
					func printHelp() {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,29 +1,47 @@
 | 
				
			|||||||
package common
 | 
					package common
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"path/filepath"
 | 
						"path/filepath"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func SetupGinLog() {
 | 
					const (
 | 
				
			||||||
 | 
						loggerINFO  = "INFO"
 | 
				
			||||||
 | 
						loggerWarn  = "WARN"
 | 
				
			||||||
 | 
						loggerError = "ERR"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const maxLogCount = 1000000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var logCount int
 | 
				
			||||||
 | 
					var setupLogLock sync.Mutex
 | 
				
			||||||
 | 
					var setupLogWorking bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func SetupLogger() {
 | 
				
			||||||
	if *LogDir != "" {
 | 
						if *LogDir != "" {
 | 
				
			||||||
		commonLogPath := filepath.Join(*LogDir, "common.log")
 | 
							ok := setupLogLock.TryLock()
 | 
				
			||||||
		errorLogPath := filepath.Join(*LogDir, "error.log")
 | 
							if !ok {
 | 
				
			||||||
		commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
								log.Println("setup log is already working")
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer func() {
 | 
				
			||||||
 | 
								setupLogLock.Unlock()
 | 
				
			||||||
 | 
								setupLogWorking = false
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
							logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
 | 
				
			||||||
 | 
							fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Fatal("failed to open log file")
 | 
								log.Fatal("failed to open log file")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
							gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
 | 
				
			||||||
		if err != nil {
 | 
							gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
 | 
				
			||||||
			log.Fatal("failed to open log file")
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
 | 
					 | 
				
			||||||
		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -37,6 +55,36 @@ func SysError(s string) {
 | 
				
			|||||||
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 | 
						_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func LogInfo(ctx context.Context, msg string) {
 | 
				
			||||||
 | 
						logHelper(ctx, loggerINFO, msg)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func LogWarn(ctx context.Context, msg string) {
 | 
				
			||||||
 | 
						logHelper(ctx, loggerWarn, msg)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func LogError(ctx context.Context, msg string) {
 | 
				
			||||||
 | 
						logHelper(ctx, loggerError, msg)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func logHelper(ctx context.Context, level string, msg string) {
 | 
				
			||||||
 | 
						writer := gin.DefaultErrorWriter
 | 
				
			||||||
 | 
						if level == loggerINFO {
 | 
				
			||||||
 | 
							writer = gin.DefaultWriter
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						id := ctx.Value(RequestIdKey)
 | 
				
			||||||
 | 
						now := time.Now()
 | 
				
			||||||
 | 
						_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
 | 
				
			||||||
 | 
						logCount++ // we don't need accurate count, so no lock here
 | 
				
			||||||
 | 
						if logCount > maxLogCount && !setupLogWorking {
 | 
				
			||||||
 | 
							logCount = 0
 | 
				
			||||||
 | 
							setupLogWorking = true
 | 
				
			||||||
 | 
							go func() {
 | 
				
			||||||
 | 
								SetupLogger()
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func FatalLog(v ...any) {
 | 
					func FatalLog(v ...any) {
 | 
				
			||||||
	t := time.Now()
 | 
						t := time.Now()
 | 
				
			||||||
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
 | 
						_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -13,46 +13,53 @@ import (
 | 
				
			|||||||
// 1 === $0.002 / 1K tokens
 | 
					// 1 === $0.002 / 1K tokens
 | 
				
			||||||
// 1 === ¥0.014 / 1k tokens
 | 
					// 1 === ¥0.014 / 1k tokens
 | 
				
			||||||
var ModelRatio = map[string]float64{
 | 
					var ModelRatio = map[string]float64{
 | 
				
			||||||
	"gpt-4":                   15,
 | 
						"gpt-4":                     15,
 | 
				
			||||||
	"gpt-4-0314":              15,
 | 
						"gpt-4-0314":                15,
 | 
				
			||||||
	"gpt-4-0613":              15,
 | 
						"gpt-4-0613":                15,
 | 
				
			||||||
	"gpt-4-32k":               30,
 | 
						"gpt-4-32k":                 30,
 | 
				
			||||||
	"gpt-4-32k-0314":          30,
 | 
						"gpt-4-32k-0314":            30,
 | 
				
			||||||
	"gpt-4-32k-0613":          30,
 | 
						"gpt-4-32k-0613":            30,
 | 
				
			||||||
	"gpt-3.5-turbo":           0.75, // $0.0015 / 1K tokens
 | 
						"gpt-3.5-turbo":             0.75, // $0.0015 / 1K tokens
 | 
				
			||||||
	"gpt-3.5-turbo-0301":      0.75,
 | 
						"gpt-3.5-turbo-0301":        0.75,
 | 
				
			||||||
	"gpt-3.5-turbo-0613":      0.75,
 | 
						"gpt-3.5-turbo-0613":        0.75,
 | 
				
			||||||
	"gpt-3.5-turbo-16k":       1.5, // $0.003 / 1K tokens
 | 
						"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens
 | 
				
			||||||
	"gpt-3.5-turbo-16k-0613":  1.5,
 | 
						"gpt-3.5-turbo-16k-0613":    1.5,
 | 
				
			||||||
	"text-ada-001":            0.2,
 | 
						"gpt-3.5-turbo-instruct":    0.75, // $0.0015 / 1K tokens
 | 
				
			||||||
	"text-babbage-001":        0.25,
 | 
						"text-ada-001":              0.2,
 | 
				
			||||||
	"text-curie-001":          1,
 | 
						"text-babbage-001":          0.25,
 | 
				
			||||||
	"text-davinci-002":        10,
 | 
						"text-curie-001":            1,
 | 
				
			||||||
	"text-davinci-003":        10,
 | 
						"text-davinci-002":          10,
 | 
				
			||||||
	"text-davinci-edit-001":   10,
 | 
						"text-davinci-003":          10,
 | 
				
			||||||
	"code-davinci-edit-001":   10,
 | 
						"text-davinci-edit-001":     10,
 | 
				
			||||||
	"whisper-1":               10,
 | 
						"code-davinci-edit-001":     10,
 | 
				
			||||||
	"davinci":                 10,
 | 
						"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
 | 
				
			||||||
	"curie":                   10,
 | 
						"davinci":                   10,
 | 
				
			||||||
	"babbage":                 10,
 | 
						"curie":                     10,
 | 
				
			||||||
	"ada":                     10,
 | 
						"babbage":                   10,
 | 
				
			||||||
	"text-embedding-ada-002":  0.05,
 | 
						"ada":                       10,
 | 
				
			||||||
	"text-search-ada-doc-001": 10,
 | 
						"text-embedding-ada-002":    0.05,
 | 
				
			||||||
	"text-moderation-stable":  0.1,
 | 
						"text-search-ada-doc-001":   10,
 | 
				
			||||||
	"text-moderation-latest":  0.1,
 | 
						"text-moderation-stable":    0.1,
 | 
				
			||||||
	"dall-e":                  8,
 | 
						"text-moderation-latest":    0.1,
 | 
				
			||||||
	"claude-instant-1":        0.815,  // $1.63 / 1M tokens
 | 
						"dall-e":                    8,
 | 
				
			||||||
	"claude-2":                5.51,   // $11.02 / 1M tokens
 | 
						"claude-instant-1":          0.815,  // $1.63 / 1M tokens
 | 
				
			||||||
	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens
 | 
						"claude-2":                  5.51,   // $11.02 / 1M tokens
 | 
				
			||||||
	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens
 | 
						"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens
 | 
				
			||||||
	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens
 | 
						"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens
 | 
				
			||||||
	"PaLM-2":                  1,
 | 
						"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens
 | 
				
			||||||
	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens
 | 
						"PaLM-2":                    1,
 | 
				
			||||||
	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens
 | 
						"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 | 
				
			||||||
	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens
 | 
						"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
 | 
				
			||||||
	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
 | 
						"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens
 | 
				
			||||||
	"qwen-plus-v1":            0.5715, // Same as above
 | 
						"qwen-turbo":                0.8572, // ¥0.012 / 1k tokens
 | 
				
			||||||
	"SparkDesk":               0.8572, // TBD
 | 
						"qwen-plus":                 10,     // ¥0.14 / 1k tokens
 | 
				
			||||||
 | 
						"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
 | 
				
			||||||
 | 
						"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
 | 
				
			||||||
 | 
						"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
 | 
				
			||||||
 | 
						"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
 | 
				
			||||||
 | 
						"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
 | 
				
			||||||
 | 
						"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
 | 
				
			||||||
 | 
						"360GPT_S2_V9.4":            0.8572, // ¥0.012 / 1k tokens
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func ModelRatio2JSONString() string {
 | 
					func ModelRatio2JSONString() string {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -171,6 +171,11 @@ func GetTimestamp() int64 {
 | 
				
			|||||||
	return time.Now().Unix()
 | 
						return time.Now().Unix()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func GetTimeString() string {
 | 
				
			||||||
 | 
						now := time.Now()
 | 
				
			||||||
 | 
						return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Max(a int, b int) int {
 | 
					func Max(a int, b int) int {
 | 
				
			||||||
	if a >= b {
 | 
						if a >= b {
 | 
				
			||||||
		return a
 | 
							return a
 | 
				
			||||||
@@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return num
 | 
						return num
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func MessageWithRequestId(message string, id string) string {
 | 
				
			||||||
 | 
						return fmt.Sprintf("%s (request id: %s)", message, id)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		openAIError := OpenAIError{
 | 
							openAIError := OpenAIError{
 | 
				
			||||||
			Message: err.Error(),
 | 
								Message: err.Error(),
 | 
				
			||||||
			Type:    "one_api_error",
 | 
								Type:    "upstream_error",
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(200, gin.H{
 | 
				
			||||||
			"error": openAIError,
 | 
								"error": openAIError,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
 | 
					func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
 | 
				
			||||||
	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
 | 
						url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
 | 
				
			||||||
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 | 
						body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
					func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
				
			||||||
	baseURL := common.ChannelBaseURLs[channel.Type]
 | 
						baseURL := common.ChannelBaseURLs[channel.Type]
 | 
				
			||||||
	if channel.BaseURL == "" {
 | 
						if channel.GetBaseURL() == "" {
 | 
				
			||||||
		channel.BaseURL = baseURL
 | 
							channel.BaseURL = &baseURL
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	switch channel.Type {
 | 
						switch channel.Type {
 | 
				
			||||||
	case common.ChannelTypeOpenAI:
 | 
						case common.ChannelTypeOpenAI:
 | 
				
			||||||
		if channel.BaseURL != "" {
 | 
							if channel.GetBaseURL() != "" {
 | 
				
			||||||
			baseURL = channel.BaseURL
 | 
								baseURL = channel.GetBaseURL()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case common.ChannelTypeAzure:
 | 
						case common.ChannelTypeAzure:
 | 
				
			||||||
		return 0, errors.New("尚未实现")
 | 
							return 0, errors.New("尚未实现")
 | 
				
			||||||
	case common.ChannelTypeCustom:
 | 
						case common.ChannelTypeCustom:
 | 
				
			||||||
		baseURL = channel.BaseURL
 | 
							baseURL = channel.GetBaseURL()
 | 
				
			||||||
	case common.ChannelTypeCloseAI:
 | 
						case common.ChannelTypeCloseAI:
 | 
				
			||||||
		return updateChannelCloseAIBalance(channel)
 | 
							return updateChannelCloseAIBalance(channel)
 | 
				
			||||||
	case common.ChannelTypeOpenAISB:
 | 
						case common.ChannelTypeOpenAISB:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -14,7 +14,7 @@ import (
 | 
				
			|||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
 | 
					func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
 | 
				
			||||||
	switch channel.Type {
 | 
						switch channel.Type {
 | 
				
			||||||
	case common.ChannelTypePaLM:
 | 
						case common.ChannelTypePaLM:
 | 
				
			||||||
		fallthrough
 | 
							fallthrough
 | 
				
			||||||
@@ -24,19 +24,28 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
 | 
				
			|||||||
		fallthrough
 | 
							fallthrough
 | 
				
			||||||
	case common.ChannelTypeZhipu:
 | 
						case common.ChannelTypeZhipu:
 | 
				
			||||||
		fallthrough
 | 
							fallthrough
 | 
				
			||||||
 | 
						case common.ChannelTypeAli:
 | 
				
			||||||
 | 
							fallthrough
 | 
				
			||||||
 | 
						case common.ChannelType360:
 | 
				
			||||||
 | 
							fallthrough
 | 
				
			||||||
	case common.ChannelTypeXunfei:
 | 
						case common.ChannelTypeXunfei:
 | 
				
			||||||
		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
 | 
							return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
 | 
				
			||||||
	case common.ChannelTypeAzure:
 | 
						case common.ChannelTypeAzure:
 | 
				
			||||||
		request.Model = "gpt-35-turbo"
 | 
							request.Model = "gpt-35-turbo"
 | 
				
			||||||
 | 
							defer func() {
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		request.Model = "gpt-3.5-turbo"
 | 
							request.Model = "gpt-3.5-turbo"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	requestURL := common.ChannelBaseURLs[channel.Type]
 | 
						requestURL := common.ChannelBaseURLs[channel.Type]
 | 
				
			||||||
	if channel.Type == common.ChannelTypeAzure {
 | 
						if channel.Type == common.ChannelTypeAzure {
 | 
				
			||||||
		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
 | 
							requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		if channel.BaseURL != "" {
 | 
							if channel.GetBaseURL() != "" {
 | 
				
			||||||
			requestURL = channel.BaseURL
 | 
								requestURL = channel.GetBaseURL()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		requestURL += "/v1/chat/completions"
 | 
							requestURL += "/v1/chat/completions"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	channel.CreatedTime = common.GetTimestamp()
 | 
						channel.CreatedTime = common.GetTimestamp()
 | 
				
			||||||
	keys := strings.Split(channel.Key, "\n")
 | 
						keys := strings.Split(channel.Key, "\n")
 | 
				
			||||||
	channels := make([]model.Channel, 0)
 | 
						channels := make([]model.Channel, 0, len(keys))
 | 
				
			||||||
	for _, key := range keys {
 | 
						for _, key := range keys {
 | 
				
			||||||
		if key == "" {
 | 
							if key == "" {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -79,6 +79,14 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func GitHubOAuth(c *gin.Context) {
 | 
					func GitHubOAuth(c *gin.Context) {
 | 
				
			||||||
	session := sessions.Default(c)
 | 
						session := sessions.Default(c)
 | 
				
			||||||
 | 
						state := c.Query("state")
 | 
				
			||||||
 | 
						if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
 | 
				
			||||||
 | 
							c.JSON(http.StatusForbidden, gin.H{
 | 
				
			||||||
 | 
								"success": false,
 | 
				
			||||||
 | 
								"message": "state is empty or not same",
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	username := session.Get("username")
 | 
						username := session.Get("username")
 | 
				
			||||||
	if username != nil {
 | 
						if username != nil {
 | 
				
			||||||
		GitHubBind(c)
 | 
							GitHubBind(c)
 | 
				
			||||||
@@ -205,3 +213,22 @@ func GitHubBind(c *gin.Context) {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func GenerateOAuthCode(c *gin.Context) {
 | 
				
			||||||
 | 
						session := sessions.Default(c)
 | 
				
			||||||
 | 
						state := common.GetRandomString(12)
 | 
				
			||||||
 | 
						session.Set("oauth_state", state)
 | 
				
			||||||
 | 
						err := session.Save()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
 | 
								"success": false,
 | 
				
			||||||
 | 
								"message": err.Error(),
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
 | 
							"success": true,
 | 
				
			||||||
 | 
							"message": "",
 | 
				
			||||||
 | 
							"data":    state,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package controller
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
	"one-api/model"
 | 
						"one-api/model"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
@@ -18,19 +19,21 @@ func GetAllLogs(c *gin.Context) {
 | 
				
			|||||||
	username := c.Query("username")
 | 
						username := c.Query("username")
 | 
				
			||||||
	tokenName := c.Query("token_name")
 | 
						tokenName := c.Query("token_name")
 | 
				
			||||||
	modelName := c.Query("model_name")
 | 
						modelName := c.Query("model_name")
 | 
				
			||||||
	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
 | 
						channel, _ := strconv.Atoi(c.Query("channel"))
 | 
				
			||||||
 | 
						logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
			"success": false,
 | 
								"success": false,
 | 
				
			||||||
			"message": err.Error(),
 | 
								"message": err.Error(),
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.JSON(200, gin.H{
 | 
						c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
		"success": true,
 | 
							"success": true,
 | 
				
			||||||
		"message": "",
 | 
							"message": "",
 | 
				
			||||||
		"data":    logs,
 | 
							"data":    logs,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetUserLogs(c *gin.Context) {
 | 
					func GetUserLogs(c *gin.Context) {
 | 
				
			||||||
@@ -46,34 +49,36 @@ func GetUserLogs(c *gin.Context) {
 | 
				
			|||||||
	modelName := c.Query("model_name")
 | 
						modelName := c.Query("model_name")
 | 
				
			||||||
	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
 | 
						logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
			"success": false,
 | 
								"success": false,
 | 
				
			||||||
			"message": err.Error(),
 | 
								"message": err.Error(),
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.JSON(200, gin.H{
 | 
						c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
		"success": true,
 | 
							"success": true,
 | 
				
			||||||
		"message": "",
 | 
							"message": "",
 | 
				
			||||||
		"data":    logs,
 | 
							"data":    logs,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func SearchAllLogs(c *gin.Context) {
 | 
					func SearchAllLogs(c *gin.Context) {
 | 
				
			||||||
	keyword := c.Query("keyword")
 | 
						keyword := c.Query("keyword")
 | 
				
			||||||
	logs, err := model.SearchAllLogs(keyword)
 | 
						logs, err := model.SearchAllLogs(keyword)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
			"success": false,
 | 
								"success": false,
 | 
				
			||||||
			"message": err.Error(),
 | 
								"message": err.Error(),
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.JSON(200, gin.H{
 | 
						c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
		"success": true,
 | 
							"success": true,
 | 
				
			||||||
		"message": "",
 | 
							"message": "",
 | 
				
			||||||
		"data":    logs,
 | 
							"data":    logs,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func SearchUserLogs(c *gin.Context) {
 | 
					func SearchUserLogs(c *gin.Context) {
 | 
				
			||||||
@@ -81,17 +86,18 @@ func SearchUserLogs(c *gin.Context) {
 | 
				
			|||||||
	userId := c.GetInt("id")
 | 
						userId := c.GetInt("id")
 | 
				
			||||||
	logs, err := model.SearchUserLogs(userId, keyword)
 | 
						logs, err := model.SearchUserLogs(userId, keyword)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
			"success": false,
 | 
								"success": false,
 | 
				
			||||||
			"message": err.Error(),
 | 
								"message": err.Error(),
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	c.JSON(200, gin.H{
 | 
						c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
		"success": true,
 | 
							"success": true,
 | 
				
			||||||
		"message": "",
 | 
							"message": "",
 | 
				
			||||||
		"data":    logs,
 | 
							"data":    logs,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetLogsStat(c *gin.Context) {
 | 
					func GetLogsStat(c *gin.Context) {
 | 
				
			||||||
@@ -101,9 +107,10 @@ func GetLogsStat(c *gin.Context) {
 | 
				
			|||||||
	tokenName := c.Query("token_name")
 | 
						tokenName := c.Query("token_name")
 | 
				
			||||||
	username := c.Query("username")
 | 
						username := c.Query("username")
 | 
				
			||||||
	modelName := c.Query("model_name")
 | 
						modelName := c.Query("model_name")
 | 
				
			||||||
	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 | 
						channel, _ := strconv.Atoi(c.Query("channel"))
 | 
				
			||||||
 | 
						quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
 | 
				
			||||||
	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
 | 
						//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
 | 
				
			||||||
	c.JSON(200, gin.H{
 | 
						c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
		"success": true,
 | 
							"success": true,
 | 
				
			||||||
		"message": "",
 | 
							"message": "",
 | 
				
			||||||
		"data": gin.H{
 | 
							"data": gin.H{
 | 
				
			||||||
@@ -111,6 +118,7 @@ func GetLogsStat(c *gin.Context) {
 | 
				
			|||||||
			//"token": tokenNum,
 | 
								//"token": tokenNum,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetLogsSelfStat(c *gin.Context) {
 | 
					func GetLogsSelfStat(c *gin.Context) {
 | 
				
			||||||
@@ -120,9 +128,10 @@ func GetLogsSelfStat(c *gin.Context) {
 | 
				
			|||||||
	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 | 
						endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 | 
				
			||||||
	tokenName := c.Query("token_name")
 | 
						tokenName := c.Query("token_name")
 | 
				
			||||||
	modelName := c.Query("model_name")
 | 
						modelName := c.Query("model_name")
 | 
				
			||||||
	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 | 
						channel, _ := strconv.Atoi(c.Query("channel"))
 | 
				
			||||||
 | 
						quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
 | 
				
			||||||
	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 | 
						//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 | 
				
			||||||
	c.JSON(200, gin.H{
 | 
						c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
		"success": true,
 | 
							"success": true,
 | 
				
			||||||
		"message": "",
 | 
							"message": "",
 | 
				
			||||||
		"data": gin.H{
 | 
							"data": gin.H{
 | 
				
			||||||
@@ -130,4 +139,30 @@ func GetLogsSelfStat(c *gin.Context) {
 | 
				
			|||||||
			//"token": tokenNum,
 | 
								//"token": tokenNum,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func DeleteHistoryLogs(c *gin.Context) {
 | 
				
			||||||
 | 
						targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
 | 
				
			||||||
 | 
						if targetTimestamp == 0 {
 | 
				
			||||||
 | 
							c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
 | 
								"success": false,
 | 
				
			||||||
 | 
								"message": "target timestamp is required",
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						count, err := model.DeleteOldLog(targetTimestamp)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
 | 
								"success": false,
 | 
				
			||||||
 | 
								"message": err.Error(),
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
 | 
							"success": true,
 | 
				
			||||||
 | 
							"message": "",
 | 
				
			||||||
 | 
							"data":    count,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -63,6 +63,15 @@ func init() {
 | 
				
			|||||||
			Root:       "dall-e",
 | 
								Root:       "dall-e",
 | 
				
			||||||
			Parent:     nil,
 | 
								Parent:     nil,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "whisper-1",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "openai",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "whisper-1",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			Id:         "gpt-3.5-turbo",
 | 
								Id:         "gpt-3.5-turbo",
 | 
				
			||||||
			Object:     "model",
 | 
								Object:     "model",
 | 
				
			||||||
@@ -108,6 +117,15 @@ func init() {
 | 
				
			|||||||
			Root:       "gpt-3.5-turbo-16k-0613",
 | 
								Root:       "gpt-3.5-turbo-16k-0613",
 | 
				
			||||||
			Parent:     nil,
 | 
								Parent:     nil,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "gpt-3.5-turbo-instruct",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "openai",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "gpt-3.5-turbo-instruct",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			Id:         "gpt-4",
 | 
								Id:         "gpt-4",
 | 
				
			||||||
			Object:     "model",
 | 
								Object:     "model",
 | 
				
			||||||
@@ -334,21 +352,30 @@ func init() {
 | 
				
			|||||||
			Parent:     nil,
 | 
								Parent:     nil,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			Id:         "qwen-v1",
 | 
								Id:         "qwen-turbo",
 | 
				
			||||||
			Object:     "model",
 | 
								Object:     "model",
 | 
				
			||||||
			Created:    1677649963,
 | 
								Created:    1677649963,
 | 
				
			||||||
			OwnedBy:    "ali",
 | 
								OwnedBy:    "ali",
 | 
				
			||||||
			Permission: permission,
 | 
								Permission: permission,
 | 
				
			||||||
			Root:       "qwen-v1",
 | 
								Root:       "qwen-turbo",
 | 
				
			||||||
			Parent:     nil,
 | 
								Parent:     nil,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			Id:         "qwen-plus-v1",
 | 
								Id:         "qwen-plus",
 | 
				
			||||||
			Object:     "model",
 | 
								Object:     "model",
 | 
				
			||||||
			Created:    1677649963,
 | 
								Created:    1677649963,
 | 
				
			||||||
			OwnedBy:    "ali",
 | 
								OwnedBy:    "ali",
 | 
				
			||||||
			Permission: permission,
 | 
								Permission: permission,
 | 
				
			||||||
			Root:       "qwen-plus-v1",
 | 
								Root:       "qwen-plus",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "text-embedding-v1",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "ali",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "text-embedding-v1",
 | 
				
			||||||
			Parent:     nil,
 | 
								Parent:     nil,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
@@ -360,6 +387,51 @@ func init() {
 | 
				
			|||||||
			Root:       "SparkDesk",
 | 
								Root:       "SparkDesk",
 | 
				
			||||||
			Parent:     nil,
 | 
								Parent:     nil,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "360GPT_S2_V9",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "360",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "360GPT_S2_V9",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "embedding-bert-512-v1",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "360",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "embedding-bert-512-v1",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "embedding_s1_v1",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "360",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "embedding_s1_v1",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "semantic_similarity_s1_v1",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "360",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "semantic_similarity_s1_v1",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "360GPT_S2_V9.4",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "360",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "360GPT_S2_V9.4",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	openAIModelsMap = make(map[string]OpenAIModels)
 | 
						openAIModelsMap = make(map[string]OpenAIModels)
 | 
				
			||||||
	for _, model := range openAIModels {
 | 
						for _, model := range openAIModels {
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										220
									
								
								controller/relay-aiproxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										220
									
								
								controller/relay-aiproxy.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,220 @@
 | 
				
			|||||||
 | 
					package controller
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bufio"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AIProxyLibraryRequest struct {
 | 
				
			||||||
 | 
						Model     string `json:"model"`
 | 
				
			||||||
 | 
						Query     string `json:"query"`
 | 
				
			||||||
 | 
						LibraryId string `json:"libraryId"`
 | 
				
			||||||
 | 
						Stream    bool   `json:"stream"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AIProxyLibraryError struct {
 | 
				
			||||||
 | 
						ErrCode int    `json:"errCode"`
 | 
				
			||||||
 | 
						Message string `json:"message"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AIProxyLibraryDocument struct {
 | 
				
			||||||
 | 
						Title string `json:"title"`
 | 
				
			||||||
 | 
						URL   string `json:"url"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AIProxyLibraryResponse struct {
 | 
				
			||||||
 | 
						Success   bool                     `json:"success"`
 | 
				
			||||||
 | 
						Answer    string                   `json:"answer"`
 | 
				
			||||||
 | 
						Documents []AIProxyLibraryDocument `json:"documents"`
 | 
				
			||||||
 | 
						AIProxyLibraryError
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AIProxyLibraryStreamResponse struct {
 | 
				
			||||||
 | 
						Content   string                   `json:"content"`
 | 
				
			||||||
 | 
						Finish    bool                     `json:"finish"`
 | 
				
			||||||
 | 
						Model     string                   `json:"model"`
 | 
				
			||||||
 | 
						Documents []AIProxyLibraryDocument `json:"documents"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
 | 
				
			||||||
 | 
						query := ""
 | 
				
			||||||
 | 
						if len(request.Messages) != 0 {
 | 
				
			||||||
 | 
							query = request.Messages[len(request.Messages)-1].Content
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &AIProxyLibraryRequest{
 | 
				
			||||||
 | 
							Model:  request.Model,
 | 
				
			||||||
 | 
							Stream: request.Stream,
 | 
				
			||||||
 | 
							Query:  query,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
 | 
				
			||||||
 | 
						if len(documents) == 0 {
 | 
				
			||||||
 | 
							return ""
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						content := "\n\n参考文档:\n"
 | 
				
			||||||
 | 
						for i, document := range documents {
 | 
				
			||||||
 | 
							content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return content
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
 | 
				
			||||||
 | 
						content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
 | 
				
			||||||
 | 
						choice := OpenAITextResponseChoice{
 | 
				
			||||||
 | 
							Index: 0,
 | 
				
			||||||
 | 
							Message: Message{
 | 
				
			||||||
 | 
								Role:    "assistant",
 | 
				
			||||||
 | 
								Content: content,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							FinishReason: "stop",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						fullTextResponse := OpenAITextResponse{
 | 
				
			||||||
 | 
							Id:      common.GetUUID(),
 | 
				
			||||||
 | 
							Object:  "chat.completion",
 | 
				
			||||||
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
 | 
							Choices: []OpenAITextResponseChoice{choice},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &fullTextResponse
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
 | 
				
			||||||
 | 
						var choice ChatCompletionsStreamResponseChoice
 | 
				
			||||||
 | 
						choice.Delta.Content = aiProxyDocuments2Markdown(documents)
 | 
				
			||||||
 | 
						choice.FinishReason = &stopFinishReason
 | 
				
			||||||
 | 
						return &ChatCompletionsStreamResponse{
 | 
				
			||||||
 | 
							Id:      common.GetUUID(),
 | 
				
			||||||
 | 
							Object:  "chat.completion.chunk",
 | 
				
			||||||
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
 | 
							Model:   "",
 | 
				
			||||||
 | 
							Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
 | 
				
			||||||
 | 
						var choice ChatCompletionsStreamResponseChoice
 | 
				
			||||||
 | 
						choice.Delta.Content = response.Content
 | 
				
			||||||
 | 
						return &ChatCompletionsStreamResponse{
 | 
				
			||||||
 | 
							Id:      common.GetUUID(),
 | 
				
			||||||
 | 
							Object:  "chat.completion.chunk",
 | 
				
			||||||
 | 
							Created: common.GetTimestamp(),
 | 
				
			||||||
 | 
							Model:   response.Model,
 | 
				
			||||||
 | 
							Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
				
			||||||
 | 
						var usage Usage
 | 
				
			||||||
 | 
						scanner := bufio.NewScanner(resp.Body)
 | 
				
			||||||
 | 
						scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
				
			||||||
 | 
							if atEOF && len(data) == 0 {
 | 
				
			||||||
 | 
								return 0, nil, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
				
			||||||
 | 
								return i + 1, data[0:i], nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if atEOF {
 | 
				
			||||||
 | 
								return len(data), data, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return 0, nil, nil
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						dataChan := make(chan string)
 | 
				
			||||||
 | 
						stopChan := make(chan bool)
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							for scanner.Scan() {
 | 
				
			||||||
 | 
								data := scanner.Text()
 | 
				
			||||||
 | 
								if len(data) < 5 { // ignore blank line or wrong format
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if data[:5] != "data:" {
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								data = data[5:]
 | 
				
			||||||
 | 
								dataChan <- data
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							stopChan <- true
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						setEventStreamHeaders(c)
 | 
				
			||||||
 | 
						var documents []AIProxyLibraryDocument
 | 
				
			||||||
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case data := <-dataChan:
 | 
				
			||||||
 | 
								var AIProxyLibraryResponse AIProxyLibraryStreamResponse
 | 
				
			||||||
 | 
								err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									common.SysError("error unmarshalling stream response: " + err.Error())
 | 
				
			||||||
 | 
									return true
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if len(AIProxyLibraryResponse.Documents) != 0 {
 | 
				
			||||||
 | 
									documents = AIProxyLibraryResponse.Documents
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
				
			||||||
 | 
								jsonResponse, err := json.Marshal(response)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									common.SysError("error marshalling stream response: " + err.Error())
 | 
				
			||||||
 | 
									return true
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							case <-stopChan:
 | 
				
			||||||
 | 
								response := documentsAIProxyLibrary(documents)
 | 
				
			||||||
 | 
								jsonResponse, err := json.Marshal(response)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									common.SysError("error marshalling stream response: " + err.Error())
 | 
				
			||||||
 | 
									return true
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
				
			||||||
 | 
								c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						err := resp.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil, &usage
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
				
			||||||
 | 
						var AIProxyLibraryResponse AIProxyLibraryResponse
 | 
				
			||||||
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = resp.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if AIProxyLibraryResponse.ErrCode != 0 {
 | 
				
			||||||
 | 
							return &OpenAIErrorWithStatusCode{
 | 
				
			||||||
 | 
								OpenAIError: OpenAIError{
 | 
				
			||||||
 | 
									Message: AIProxyLibraryResponse.Message,
 | 
				
			||||||
 | 
									Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
 | 
				
			||||||
 | 
									Code:    AIProxyLibraryResponse.ErrCode,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								StatusCode: resp.StatusCode,
 | 
				
			||||||
 | 
							}, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
				
			||||||
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
 | 
						_, err = c.Writer.Write(jsonResponse)
 | 
				
			||||||
 | 
						return nil, &fullTextResponse.Usage
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -35,6 +35,29 @@ type AliChatRequest struct {
 | 
				
			|||||||
	Parameters AliParameters `json:"parameters,omitempty"`
 | 
						Parameters AliParameters `json:"parameters,omitempty"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AliEmbeddingRequest struct {
 | 
				
			||||||
 | 
						Model string `json:"model"`
 | 
				
			||||||
 | 
						Input struct {
 | 
				
			||||||
 | 
							Texts []string `json:"texts"`
 | 
				
			||||||
 | 
						} `json:"input"`
 | 
				
			||||||
 | 
						Parameters *struct {
 | 
				
			||||||
 | 
							TextType string `json:"text_type,omitempty"`
 | 
				
			||||||
 | 
						} `json:"parameters,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AliEmbedding struct {
 | 
				
			||||||
 | 
						Embedding []float64 `json:"embedding"`
 | 
				
			||||||
 | 
						TextIndex int       `json:"text_index"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AliEmbeddingResponse struct {
 | 
				
			||||||
 | 
						Output struct {
 | 
				
			||||||
 | 
							Embeddings []AliEmbedding `json:"embeddings"`
 | 
				
			||||||
 | 
						} `json:"output"`
 | 
				
			||||||
 | 
						Usage AliUsage `json:"usage"`
 | 
				
			||||||
 | 
						AliError
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type AliError struct {
 | 
					type AliError struct {
 | 
				
			||||||
	Code      string `json:"code"`
 | 
						Code      string `json:"code"`
 | 
				
			||||||
	Message   string `json:"message"`
 | 
						Message   string `json:"message"`
 | 
				
			||||||
@@ -44,6 +67,7 @@ type AliError struct {
 | 
				
			|||||||
type AliUsage struct {
 | 
					type AliUsage struct {
 | 
				
			||||||
	InputTokens  int `json:"input_tokens"`
 | 
						InputTokens  int `json:"input_tokens"`
 | 
				
			||||||
	OutputTokens int `json:"output_tokens"`
 | 
						OutputTokens int `json:"output_tokens"`
 | 
				
			||||||
 | 
						TotalTokens  int `json:"total_tokens"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type AliOutput struct {
 | 
					type AliOutput struct {
 | 
				
			||||||
@@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
 | 
				
			||||||
 | 
						return &AliEmbeddingRequest{
 | 
				
			||||||
 | 
							Model: "text-embedding-v1",
 | 
				
			||||||
 | 
							Input: struct {
 | 
				
			||||||
 | 
								Texts []string `json:"texts"`
 | 
				
			||||||
 | 
							}{
 | 
				
			||||||
 | 
								Texts: request.ParseInput(),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
				
			||||||
 | 
						var aliResponse AliEmbeddingResponse
 | 
				
			||||||
 | 
						err := json.NewDecoder(resp.Body).Decode(&aliResponse)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = resp.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if aliResponse.Code != "" {
 | 
				
			||||||
 | 
							return &OpenAIErrorWithStatusCode{
 | 
				
			||||||
 | 
								OpenAIError: OpenAIError{
 | 
				
			||||||
 | 
									Message: aliResponse.Message,
 | 
				
			||||||
 | 
									Type:    aliResponse.Code,
 | 
				
			||||||
 | 
									Param:   aliResponse.RequestId,
 | 
				
			||||||
 | 
									Code:    aliResponse.Code,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								StatusCode: resp.StatusCode,
 | 
				
			||||||
 | 
							}, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
 | 
				
			||||||
 | 
						jsonResponse, err := json.Marshal(fullTextResponse)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
 | 
						_, err = c.Writer.Write(jsonResponse)
 | 
				
			||||||
 | 
						return nil, &fullTextResponse.Usage
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
 | 
				
			||||||
 | 
						openAIEmbeddingResponse := OpenAIEmbeddingResponse{
 | 
				
			||||||
 | 
							Object: "list",
 | 
				
			||||||
 | 
							Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
 | 
				
			||||||
 | 
							Model:  "text-embedding-v1",
 | 
				
			||||||
 | 
							Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, item := range response.Output.Embeddings {
 | 
				
			||||||
 | 
							openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
 | 
				
			||||||
 | 
								Object:    `embedding`,
 | 
				
			||||||
 | 
								Index:     item.TextIndex,
 | 
				
			||||||
 | 
								Embedding: item.Embedding,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &openAIEmbeddingResponse
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 | 
					func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 | 
				
			||||||
	choice := OpenAITextResponseChoice{
 | 
						choice := OpenAITextResponseChoice{
 | 
				
			||||||
		Index: 0,
 | 
							Index: 0,
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										149
									
								
								controller/relay-audio.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										149
									
								
								controller/relay-audio.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,149 @@
 | 
				
			|||||||
 | 
					package controller
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"one-api/model"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			||||||
 | 
						audioModel := "whisper-1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tokenId := c.GetInt("token_id")
 | 
				
			||||||
 | 
						channelType := c.GetInt("channel")
 | 
				
			||||||
 | 
						channelId := c.GetInt("channel_id")
 | 
				
			||||||
 | 
						userId := c.GetInt("id")
 | 
				
			||||||
 | 
						group := c.GetString("group")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						preConsumedTokens := common.PreConsumedQuota
 | 
				
			||||||
 | 
						modelRatio := common.GetModelRatio(audioModel)
 | 
				
			||||||
 | 
						groupRatio := common.GetGroupRatio(group)
 | 
				
			||||||
 | 
						ratio := modelRatio * groupRatio
 | 
				
			||||||
 | 
						preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 | 
				
			||||||
 | 
						userQuota, err := model.CacheGetUserQuota(userId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if userQuota > 100*preConsumedQuota {
 | 
				
			||||||
 | 
							// in this case, we do not pre-consume quota
 | 
				
			||||||
 | 
							// because the user has enough quota
 | 
				
			||||||
 | 
							preConsumedQuota = 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if preConsumedQuota > 0 {
 | 
				
			||||||
 | 
							err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// map model name
 | 
				
			||||||
 | 
						modelMapping := c.GetString("model_mapping")
 | 
				
			||||||
 | 
						if modelMapping != "" {
 | 
				
			||||||
 | 
							modelMap := make(map[string]string)
 | 
				
			||||||
 | 
							err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if modelMap[audioModel] != "" {
 | 
				
			||||||
 | 
								audioModel = modelMap[audioModel]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						baseURL := common.ChannelBaseURLs[channelType]
 | 
				
			||||||
 | 
						requestURL := c.Request.URL.String()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if c.GetString("base_url") != "" {
 | 
				
			||||||
 | 
							baseURL = c.GetString("base_url")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
				
			||||||
 | 
						requestBody := c.Request.Body
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
				
			||||||
 | 
						req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
				
			||||||
 | 
						req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp, err := httpClient.Do(req)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = req.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = c.Request.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var audioResponse AudioResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer func(ctx context.Context) {
 | 
				
			||||||
 | 
							go func() {
 | 
				
			||||||
 | 
								quota := countTokenText(audioResponse.Text, audioModel)
 | 
				
			||||||
 | 
								quotaDelta := quota - preConsumedQuota
 | 
				
			||||||
 | 
								err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									common.SysError("error consuming token remain quota: " + err.Error())
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								err = model.CacheUpdateUserQuota(userId)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									common.SysError("error update user quota cache: " + err.Error())
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if quota != 0 {
 | 
				
			||||||
 | 
									tokenName := c.GetString("token_name")
 | 
				
			||||||
 | 
									logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
				
			||||||
 | 
									model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
 | 
				
			||||||
 | 
									model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
				
			||||||
 | 
									channelId := c.GetInt("channel_id")
 | 
				
			||||||
 | 
									model.UpdateChannelUsedQuota(channelId, quota)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
						}(c.Request.Context())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = resp.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = json.Unmarshal(responseBody, &audioResponse)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for k, v := range resp.Header {
 | 
				
			||||||
 | 
							c.Writer.Header().Set(k, v[0])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = io.Copy(c.Writer, resp.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = resp.Body.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
 | 
					func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
 | 
				
			||||||
	baiduEmbeddingRequest := BaiduEmbeddingRequest{
 | 
						return &BaiduEmbeddingRequest{
 | 
				
			||||||
		Input: nil,
 | 
							Input: request.ParseInput(),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	switch request.Input.(type) {
 | 
					 | 
				
			||||||
	case string:
 | 
					 | 
				
			||||||
		baiduEmbeddingRequest.Input = []string{request.Input.(string)}
 | 
					 | 
				
			||||||
	case []any:
 | 
					 | 
				
			||||||
		for _, item := range request.Input.([]any) {
 | 
					 | 
				
			||||||
			if str, ok := item.(string); ok {
 | 
					 | 
				
			||||||
				baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return &baiduEmbeddingRequest
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
 | 
					func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package controller
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
@@ -18,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	tokenId := c.GetInt("token_id")
 | 
						tokenId := c.GetInt("token_id")
 | 
				
			||||||
	channelType := c.GetInt("channel")
 | 
						channelType := c.GetInt("channel")
 | 
				
			||||||
 | 
						channelId := c.GetInt("channel_id")
 | 
				
			||||||
	userId := c.GetInt("id")
 | 
						userId := c.GetInt("id")
 | 
				
			||||||
	consumeQuota := c.GetBool("consume_quota")
 | 
						consumeQuota := c.GetBool("consume_quota")
 | 
				
			||||||
	group := c.GetString("group")
 | 
						group := c.GetString("group")
 | 
				
			||||||
@@ -124,7 +126,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	var textResponse ImageResponse
 | 
						var textResponse ImageResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	defer func() {
 | 
						defer func(ctx context.Context) {
 | 
				
			||||||
		if consumeQuota {
 | 
							if consumeQuota {
 | 
				
			||||||
			err := model.PostConsumeTokenQuota(tokenId, quota)
 | 
								err := model.PostConsumeTokenQuota(tokenId, quota)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
@@ -137,13 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
				
			|||||||
			if quota != 0 {
 | 
								if quota != 0 {
 | 
				
			||||||
				tokenName := c.GetString("token_name")
 | 
									tokenName := c.GetString("token_name")
 | 
				
			||||||
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
									logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
				
			||||||
				model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
 | 
									model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
 | 
				
			||||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
									model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
				
			||||||
				channelId := c.GetInt("channel_id")
 | 
									channelId := c.GetInt("channel_id")
 | 
				
			||||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
									model.UpdateChannelUsedQuota(channelId, quota)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}(c.Request.Context())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if consumeQuota {
 | 
						if consumeQuota {
 | 
				
			||||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
							responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package controller
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
@@ -22,6 +23,7 @@ const (
 | 
				
			|||||||
	APITypeZhipu
 | 
						APITypeZhipu
 | 
				
			||||||
	APITypeAli
 | 
						APITypeAli
 | 
				
			||||||
	APITypeXunfei
 | 
						APITypeXunfei
 | 
				
			||||||
 | 
						APITypeAIProxyLibrary
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var httpClient *http.Client
 | 
					var httpClient *http.Client
 | 
				
			||||||
@@ -36,6 +38,7 @@ func init() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
					func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			||||||
	channelType := c.GetInt("channel")
 | 
						channelType := c.GetInt("channel")
 | 
				
			||||||
 | 
						channelId := c.GetInt("channel_id")
 | 
				
			||||||
	tokenId := c.GetInt("token_id")
 | 
						tokenId := c.GetInt("token_id")
 | 
				
			||||||
	userId := c.GetInt("id")
 | 
						userId := c.GetInt("id")
 | 
				
			||||||
	consumeQuota := c.GetBool("consume_quota")
 | 
						consumeQuota := c.GetBool("consume_quota")
 | 
				
			||||||
@@ -104,6 +107,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		apiType = APITypeAli
 | 
							apiType = APITypeAli
 | 
				
			||||||
	case common.ChannelTypeXunfei:
 | 
						case common.ChannelTypeXunfei:
 | 
				
			||||||
		apiType = APITypeXunfei
 | 
							apiType = APITypeXunfei
 | 
				
			||||||
 | 
						case common.ChannelTypeAIProxyLibrary:
 | 
				
			||||||
 | 
							apiType = APITypeAIProxyLibrary
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
						baseURL := common.ChannelBaseURLs[channelType]
 | 
				
			||||||
	requestURL := c.Request.URL.String()
 | 
						requestURL := c.Request.URL.String()
 | 
				
			||||||
@@ -171,6 +176,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 | 
							fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 | 
				
			||||||
	case APITypeAli:
 | 
						case APITypeAli:
 | 
				
			||||||
		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
 | 
							fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
 | 
				
			||||||
 | 
							if relayMode == RelayModeEmbeddings {
 | 
				
			||||||
 | 
								fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case APITypeAIProxyLibrary:
 | 
				
			||||||
 | 
							fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var promptTokens int
 | 
						var promptTokens int
 | 
				
			||||||
	var completionTokens int
 | 
						var completionTokens int
 | 
				
			||||||
@@ -202,6 +212,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		// in this case, we do not pre-consume quota
 | 
							// in this case, we do not pre-consume quota
 | 
				
			||||||
		// because the user has enough quota
 | 
							// because the user has enough quota
 | 
				
			||||||
		preConsumedQuota = 0
 | 
							preConsumedQuota = 0
 | 
				
			||||||
 | 
							common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if consumeQuota && preConsumedQuota > 0 {
 | 
						if consumeQuota && preConsumedQuota > 0 {
 | 
				
			||||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
							err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
				
			||||||
@@ -257,8 +268,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
							requestBody = bytes.NewBuffer(jsonStr)
 | 
				
			||||||
	case APITypeAli:
 | 
						case APITypeAli:
 | 
				
			||||||
		aliRequest := requestOpenAI2Ali(textRequest)
 | 
							var jsonStr []byte
 | 
				
			||||||
		jsonStr, err := json.Marshal(aliRequest)
 | 
							var err error
 | 
				
			||||||
 | 
							switch relayMode {
 | 
				
			||||||
 | 
							case RelayModeEmbeddings:
 | 
				
			||||||
 | 
								aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
 | 
				
			||||||
 | 
								jsonStr, err = json.Marshal(aliEmbeddingRequest)
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								aliRequest := requestOpenAI2Ali(textRequest)
 | 
				
			||||||
 | 
								jsonStr, err = json.Marshal(aliRequest)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							requestBody = bytes.NewBuffer(jsonStr)
 | 
				
			||||||
 | 
						case APITypeAIProxyLibrary:
 | 
				
			||||||
 | 
							aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
 | 
				
			||||||
 | 
							aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
 | 
				
			||||||
 | 
							jsonStr, err := json.Marshal(aiProxyLibraryRequest)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
								return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -282,6 +309,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
				req.Header.Set("api-key", apiKey)
 | 
									req.Header.Set("api-key", apiKey)
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
									req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
				
			||||||
 | 
									if channelType == common.ChannelTypeOpenRouter {
 | 
				
			||||||
 | 
										req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
 | 
				
			||||||
 | 
										req.Header.Set("X-Title", "One API")
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		case APITypeClaude:
 | 
							case APITypeClaude:
 | 
				
			||||||
			req.Header.Set("x-api-key", apiKey)
 | 
								req.Header.Set("x-api-key", apiKey)
 | 
				
			||||||
@@ -298,6 +329,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			if textRequest.Stream {
 | 
								if textRequest.Stream {
 | 
				
			||||||
				req.Header.Set("X-DashScope-SSE", "enable")
 | 
									req.Header.Set("X-DashScope-SSE", "enable")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								req.Header.Set("Authorization", "Bearer "+apiKey)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
							req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
				
			||||||
		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
							req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
				
			||||||
@@ -317,15 +350,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
							isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if resp.StatusCode != http.StatusOK {
 | 
							if resp.StatusCode != http.StatusOK {
 | 
				
			||||||
 | 
								if preConsumedQuota != 0 {
 | 
				
			||||||
 | 
									go func(ctx context.Context) {
 | 
				
			||||||
 | 
										// return pre-consumed quota
 | 
				
			||||||
 | 
										err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
 | 
				
			||||||
 | 
										if err != nil {
 | 
				
			||||||
 | 
											common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}(c.Request.Context())
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
			return relayErrorHandler(resp)
 | 
								return relayErrorHandler(resp)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var textResponse TextResponse
 | 
						var textResponse TextResponse
 | 
				
			||||||
	tokenName := c.GetString("token_name")
 | 
						tokenName := c.GetString("token_name")
 | 
				
			||||||
	channelId := c.GetInt("channel_id")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	defer func() {
 | 
						defer func(ctx context.Context) {
 | 
				
			||||||
		// c.Writer.Flush()
 | 
							// c.Writer.Flush()
 | 
				
			||||||
		go func() {
 | 
							go func() {
 | 
				
			||||||
			if consumeQuota {
 | 
								if consumeQuota {
 | 
				
			||||||
@@ -348,22 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
				quotaDelta := quota - preConsumedQuota
 | 
									quotaDelta := quota - preConsumedQuota
 | 
				
			||||||
				err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
									err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					common.SysError("error consuming token remain quota: " + err.Error())
 | 
										common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				err = model.CacheUpdateUserQuota(userId)
 | 
									err = model.CacheUpdateUserQuota(userId)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					common.SysError("error update user quota cache: " + err.Error())
 | 
										common.LogError(ctx, "error update user quota cache: "+err.Error())
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if quota != 0 {
 | 
									if quota != 0 {
 | 
				
			||||||
					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
										logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
				
			||||||
					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 | 
										model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 | 
				
			||||||
					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
										model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
				
			||||||
 | 
					 | 
				
			||||||
					model.UpdateChannelUsedQuota(channelId, quota)
 | 
										model.UpdateChannelUsedQuota(channelId, quota)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
	}()
 | 
						}(c.Request.Context())
 | 
				
			||||||
	switch apiType {
 | 
						switch apiType {
 | 
				
			||||||
	case APITypeOpenAI:
 | 
						case APITypeOpenAI:
 | 
				
			||||||
		if isStream {
 | 
							if isStream {
 | 
				
			||||||
@@ -484,7 +524,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			err, usage := aliHandler(c, resp)
 | 
								var err *OpenAIErrorWithStatusCode
 | 
				
			||||||
 | 
								var usage *Usage
 | 
				
			||||||
 | 
								switch relayMode {
 | 
				
			||||||
 | 
								case RelayModeEmbeddings:
 | 
				
			||||||
 | 
									err, usage = aliEmbeddingHandler(c, resp)
 | 
				
			||||||
 | 
								default:
 | 
				
			||||||
 | 
									err, usage = aliHandler(c, resp)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -494,14 +541,29 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case APITypeXunfei:
 | 
						case APITypeXunfei:
 | 
				
			||||||
 | 
							auth := c.Request.Header.Get("Authorization")
 | 
				
			||||||
 | 
							auth = strings.TrimPrefix(auth, "Bearer ")
 | 
				
			||||||
 | 
							splits := strings.Split(auth, "|")
 | 
				
			||||||
 | 
							if len(splits) != 3 {
 | 
				
			||||||
 | 
								return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							var err *OpenAIErrorWithStatusCode
 | 
				
			||||||
 | 
							var usage *Usage
 | 
				
			||||||
		if isStream {
 | 
							if isStream {
 | 
				
			||||||
			auth := c.Request.Header.Get("Authorization")
 | 
								err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
				
			||||||
			auth = strings.TrimPrefix(auth, "Bearer ")
 | 
							} else {
 | 
				
			||||||
			splits := strings.Split(auth, "|")
 | 
								err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
				
			||||||
			if len(splits) != 3 {
 | 
							}
 | 
				
			||||||
				return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 | 
							if err != nil {
 | 
				
			||||||
			}
 | 
								return err
 | 
				
			||||||
			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
							}
 | 
				
			||||||
 | 
							if usage != nil {
 | 
				
			||||||
 | 
								textResponse.Usage = *usage
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						case APITypeAIProxyLibrary:
 | 
				
			||||||
 | 
							if isStream {
 | 
				
			||||||
 | 
								err, usage := aiProxyLibraryStreamHandler(c, resp)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -510,7 +572,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
 | 
								err, usage := aiProxyLibraryHandler(c, resp)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if usage != nil {
 | 
				
			||||||
 | 
									textResponse.Usage = *usage
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 | 
							return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,6 +15,24 @@ var stopFinishReason = "stop"
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 | 
					var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func InitTokenEncoders() {
 | 
				
			||||||
 | 
						common.SysLog("initializing token encoders")
 | 
				
			||||||
 | 
						fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for model, _ := range common.ModelRatio {
 | 
				
			||||||
 | 
							tokenEncoder, err := tiktoken.EncodingForModel(model)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
 | 
				
			||||||
 | 
								tokenEncoderMap[model] = fallbackTokenEncoder
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							tokenEncoderMap[model] = tokenEncoder
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						common.SysLog("token encoders initialized")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
 | 
					func getTokenEncoder(model string) *tiktoken.Tiktoken {
 | 
				
			||||||
	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
 | 
						if tokenEncoder, ok := tokenEncoderMap[model]; ok {
 | 
				
			||||||
		return tokenEncoder
 | 
							return tokenEncoder
 | 
				
			||||||
@@ -128,7 +146,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
 | 
				
			|||||||
		StatusCode: resp.StatusCode,
 | 
							StatusCode: resp.StatusCode,
 | 
				
			||||||
		OpenAIError: OpenAIError{
 | 
							OpenAIError: OpenAIError{
 | 
				
			||||||
			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
 | 
								Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
 | 
				
			||||||
			Type:    "one_api_error",
 | 
								Type:    "upstream_error",
 | 
				
			||||||
			Code:    "bad_response_status_code",
 | 
								Code:    "bad_response_status_code",
 | 
				
			||||||
			Param:   strconv.Itoa(resp.StatusCode),
 | 
								Param:   strconv.Itoa(resp.StatusCode),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
 | 
				
			|||||||
			Role:    "assistant",
 | 
								Role:    "assistant",
 | 
				
			||||||
			Content: response.Payload.Choices.Text[0].Content,
 | 
								Content: response.Payload.Choices.Text[0].Content,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
							FinishReason: stopFinishReason,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	fullTextResponse := OpenAITextResponse{
 | 
						fullTextResponse := OpenAITextResponse{
 | 
				
			||||||
		Object:  "chat.completion",
 | 
							Object:  "chat.completion",
 | 
				
			||||||
@@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
				
			||||||
 | 
						domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 | 
				
			||||||
 | 
						dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						setEventStreamHeaders(c)
 | 
				
			||||||
	var usage Usage
 | 
						var usage Usage
 | 
				
			||||||
	query := c.Request.URL.Query()
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
	apiVersion := query.Get("api-version")
 | 
							select {
 | 
				
			||||||
	if apiVersion == "" {
 | 
							case xunfeiResponse := <-dataChan:
 | 
				
			||||||
		apiVersion = c.GetString("api_version")
 | 
								usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
 | 
				
			||||||
 | 
								usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
 | 
				
			||||||
 | 
								usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
 | 
				
			||||||
 | 
								response := streamResponseXunfei2OpenAI(&xunfeiResponse)
 | 
				
			||||||
 | 
								jsonResponse, err := json.Marshal(response)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									common.SysError("error marshalling stream response: " + err.Error())
 | 
				
			||||||
 | 
									return true
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							case <-stopChan:
 | 
				
			||||||
 | 
								c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						return nil, &usage
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
				
			||||||
 | 
						domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 | 
				
			||||||
 | 
						dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if apiVersion == "" {
 | 
						var usage Usage
 | 
				
			||||||
		apiVersion = "v1.1"
 | 
						var content string
 | 
				
			||||||
		common.SysLog("api_version not found, use default: " + apiVersion)
 | 
						var xunfeiResponse XunfeiChatResponse
 | 
				
			||||||
 | 
						stop := false
 | 
				
			||||||
 | 
						for !stop {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case xunfeiResponse = <-dataChan:
 | 
				
			||||||
 | 
								content += xunfeiResponse.Payload.Choices.Text[0].Content
 | 
				
			||||||
 | 
								usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
 | 
				
			||||||
 | 
								usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
 | 
				
			||||||
 | 
								usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
 | 
				
			||||||
 | 
							case stop = <-stopChan:
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	domain := "general"
 | 
					
 | 
				
			||||||
	if apiVersion == "v2.1" {
 | 
						xunfeiResponse.Payload.Choices.Text[0].Content = content
 | 
				
			||||||
		domain = "generalv2"
 | 
					
 | 
				
			||||||
 | 
						response := responseXunfei2OpenAI(&xunfeiResponse)
 | 
				
			||||||
 | 
						jsonResponse, err := json.Marshal(response)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
 | 
						c.Writer.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
 | 
						_, _ = c.Writer.Write(jsonResponse)
 | 
				
			||||||
 | 
						return nil, &usage
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
 | 
				
			||||||
	d := websocket.Dialer{
 | 
						d := websocket.Dialer{
 | 
				
			||||||
		HandshakeTimeout: 5 * time.Second,
 | 
							HandshakeTimeout: 5 * time.Second,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
 | 
						conn, resp, err := d.Dial(authUrl, nil)
 | 
				
			||||||
	if err != nil || resp.StatusCode != 101 {
 | 
						if err != nil || resp.StatusCode != 101 {
 | 
				
			||||||
		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
 | 
							return nil, nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	data := requestOpenAI2Xunfei(textRequest, appId, domain)
 | 
						data := requestOpenAI2Xunfei(textRequest, appId, domain)
 | 
				
			||||||
	err = conn.WriteJSON(data)
 | 
						err = conn.WriteJSON(data)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
 | 
							return nil, nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	dataChan := make(chan XunfeiChatResponse)
 | 
						dataChan := make(chan XunfeiChatResponse)
 | 
				
			||||||
	stopChan := make(chan bool)
 | 
						stopChan := make(chan bool)
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
@@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		stopChan <- true
 | 
							stopChan <- true
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
	setEventStreamHeaders(c)
 | 
					
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						return dataChan, stopChan, nil
 | 
				
			||||||
		select {
 | 
					 | 
				
			||||||
		case xunfeiResponse := <-dataChan:
 | 
					 | 
				
			||||||
			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
 | 
					 | 
				
			||||||
			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
 | 
					 | 
				
			||||||
			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
 | 
					 | 
				
			||||||
			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
 | 
					 | 
				
			||||||
			jsonResponse, err := json.Marshal(response)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
					 | 
				
			||||||
				return true
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
					 | 
				
			||||||
			return true
 | 
					 | 
				
			||||||
		case <-stopChan:
 | 
					 | 
				
			||||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
					 | 
				
			||||||
			return false
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
	return nil, &usage
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
					func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
 | 
				
			||||||
	var xunfeiResponse XunfeiChatResponse
 | 
						query := c.Request.URL.Query()
 | 
				
			||||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
						apiVersion := query.Get("api-version")
 | 
				
			||||||
	if err != nil {
 | 
						if apiVersion == "" {
 | 
				
			||||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
							apiVersion = c.GetString("api_version")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = resp.Body.Close()
 | 
						if apiVersion == "" {
 | 
				
			||||||
	if err != nil {
 | 
							apiVersion = "v1.1"
 | 
				
			||||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
							common.SysLog("api_version not found, use default: " + apiVersion)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = json.Unmarshal(responseBody, &xunfeiResponse)
 | 
						domain := "general"
 | 
				
			||||||
	if err != nil {
 | 
						if apiVersion == "v2.1" {
 | 
				
			||||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
							domain = "generalv2"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if xunfeiResponse.Header.Code != 0 {
 | 
						authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
 | 
				
			||||||
		return &OpenAIErrorWithStatusCode{
 | 
						return domain, authUrl
 | 
				
			||||||
			OpenAIError: OpenAIError{
 | 
					 | 
				
			||||||
				Message: xunfeiResponse.Header.Message,
 | 
					 | 
				
			||||||
				Type:    "xunfei_error",
 | 
					 | 
				
			||||||
				Param:   "",
 | 
					 | 
				
			||||||
				Code:    xunfeiResponse.Header.Code,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
			StatusCode: resp.StatusCode,
 | 
					 | 
				
			||||||
		}, nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
 | 
					 | 
				
			||||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
					 | 
				
			||||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
					 | 
				
			||||||
	_, err = c.Writer.Write(jsonResponse)
 | 
					 | 
				
			||||||
	return nil, &fullTextResponse.Usage
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -24,6 +24,7 @@ const (
 | 
				
			|||||||
	RelayModeModerations
 | 
						RelayModeModerations
 | 
				
			||||||
	RelayModeImagesGenerations
 | 
						RelayModeImagesGenerations
 | 
				
			||||||
	RelayModeEdits
 | 
						RelayModeEdits
 | 
				
			||||||
 | 
						RelayModeAudio
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// https://platform.openai.com/docs/api-reference/chat
 | 
					// https://platform.openai.com/docs/api-reference/chat
 | 
				
			||||||
@@ -40,6 +41,26 @@ type GeneralOpenAIRequest struct {
 | 
				
			|||||||
	Input       any       `json:"input,omitempty"`
 | 
						Input       any       `json:"input,omitempty"`
 | 
				
			||||||
	Instruction string    `json:"instruction,omitempty"`
 | 
						Instruction string    `json:"instruction,omitempty"`
 | 
				
			||||||
	Size        string    `json:"size,omitempty"`
 | 
						Size        string    `json:"size,omitempty"`
 | 
				
			||||||
 | 
						Functions   any       `json:"functions,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r GeneralOpenAIRequest) ParseInput() []string {
 | 
				
			||||||
 | 
						if r.Input == nil {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var input []string
 | 
				
			||||||
 | 
						switch r.Input.(type) {
 | 
				
			||||||
 | 
						case string:
 | 
				
			||||||
 | 
							input = []string{r.Input.(string)}
 | 
				
			||||||
 | 
						case []any:
 | 
				
			||||||
 | 
							input = make([]string, 0, len(r.Input.([]any)))
 | 
				
			||||||
 | 
							for _, item := range r.Input.([]any) {
 | 
				
			||||||
 | 
								if str, ok := item.(string); ok {
 | 
				
			||||||
 | 
									input = append(input, str)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return input
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ChatRequest struct {
 | 
					type ChatRequest struct {
 | 
				
			||||||
@@ -62,6 +83,10 @@ type ImageRequest struct {
 | 
				
			|||||||
	Size   string `json:"size"`
 | 
						Size   string `json:"size"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AudioResponse struct {
 | 
				
			||||||
 | 
						Text string `json:"text,omitempty"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Usage struct {
 | 
					type Usage struct {
 | 
				
			||||||
	PromptTokens     int `json:"prompt_tokens"`
 | 
						PromptTokens     int `json:"prompt_tokens"`
 | 
				
			||||||
	CompletionTokens int `json:"completion_tokens"`
 | 
						CompletionTokens int `json:"completion_tokens"`
 | 
				
			||||||
@@ -158,15 +183,20 @@ func Relay(c *gin.Context) {
 | 
				
			|||||||
		relayMode = RelayModeImagesGenerations
 | 
							relayMode = RelayModeImagesGenerations
 | 
				
			||||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 | 
						} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 | 
				
			||||||
		relayMode = RelayModeEdits
 | 
							relayMode = RelayModeEdits
 | 
				
			||||||
 | 
						} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
				
			||||||
 | 
							relayMode = RelayModeAudio
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var err *OpenAIErrorWithStatusCode
 | 
						var err *OpenAIErrorWithStatusCode
 | 
				
			||||||
	switch relayMode {
 | 
						switch relayMode {
 | 
				
			||||||
	case RelayModeImagesGenerations:
 | 
						case RelayModeImagesGenerations:
 | 
				
			||||||
		err = relayImageHelper(c, relayMode)
 | 
							err = relayImageHelper(c, relayMode)
 | 
				
			||||||
 | 
						case RelayModeAudio:
 | 
				
			||||||
 | 
							err = relayAudioHelper(c, relayMode)
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		err = relayTextHelper(c, relayMode)
 | 
							err = relayTextHelper(c, relayMode)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
							requestId := c.GetString(common.RequestIdKey)
 | 
				
			||||||
		retryTimesStr := c.Query("retry")
 | 
							retryTimesStr := c.Query("retry")
 | 
				
			||||||
		retryTimes, _ := strconv.Atoi(retryTimesStr)
 | 
							retryTimes, _ := strconv.Atoi(retryTimesStr)
 | 
				
			||||||
		if retryTimesStr == "" {
 | 
							if retryTimesStr == "" {
 | 
				
			||||||
@@ -178,12 +208,13 @@ func Relay(c *gin.Context) {
 | 
				
			|||||||
			if err.StatusCode == http.StatusTooManyRequests {
 | 
								if err.StatusCode == http.StatusTooManyRequests {
 | 
				
			||||||
				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
 | 
									err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
 | 
				
			||||||
			c.JSON(err.StatusCode, gin.H{
 | 
								c.JSON(err.StatusCode, gin.H{
 | 
				
			||||||
				"error": err.OpenAIError,
 | 
									"error": err.OpenAIError,
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		channelId := c.GetInt("channel_id")
 | 
							channelId := c.GetInt("channel_id")
 | 
				
			||||||
		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
							common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
				
			||||||
		// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
							// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
				
			||||||
		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
 | 
							if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
 | 
				
			||||||
			channelId := c.GetInt("channel_id")
 | 
								channelId := c.GetInt("channel_id")
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -523,5 +523,6 @@
 | 
				
			|||||||
  "按照如下格式输入:": "Enter in the following format:",
 | 
					  "按照如下格式输入:": "Enter in the following format:",
 | 
				
			||||||
  "模型版本": "Model version",
 | 
					  "模型版本": "Model version",
 | 
				
			||||||
  "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
 | 
					  "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
 | 
				
			||||||
  "点击查看": "click to view"
 | 
					  "点击查看": "click to view",
 | 
				
			||||||
 | 
					  "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										35
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										35
									
								
								main.go
									
									
									
									
									
								
							@@ -2,6 +2,7 @@ package main
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"embed"
 | 
						"embed"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"github.com/gin-contrib/sessions"
 | 
						"github.com/gin-contrib/sessions"
 | 
				
			||||||
	"github.com/gin-contrib/sessions/cookie"
 | 
						"github.com/gin-contrib/sessions/cookie"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
@@ -21,7 +22,7 @@ var buildFS embed.FS
 | 
				
			|||||||
var indexPage []byte
 | 
					var indexPage []byte
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func main() {
 | 
					func main() {
 | 
				
			||||||
	common.SetupGinLog()
 | 
						common.SetupLogger()
 | 
				
			||||||
	common.SysLog("One API " + common.Version + " started")
 | 
						common.SysLog("One API " + common.Version + " started")
 | 
				
			||||||
	if os.Getenv("GIN_MODE") != "debug" {
 | 
						if os.Getenv("GIN_MODE") != "debug" {
 | 
				
			||||||
		gin.SetMode(gin.ReleaseMode)
 | 
							gin.SetMode(gin.ReleaseMode)
 | 
				
			||||||
@@ -50,18 +51,17 @@ func main() {
 | 
				
			|||||||
	// Initialize options
 | 
						// Initialize options
 | 
				
			||||||
	model.InitOptionMap()
 | 
						model.InitOptionMap()
 | 
				
			||||||
	if common.RedisEnabled {
 | 
						if common.RedisEnabled {
 | 
				
			||||||
 | 
							// for compatibility with old versions
 | 
				
			||||||
 | 
							common.MemoryCacheEnabled = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if common.MemoryCacheEnabled {
 | 
				
			||||||
 | 
							common.SysLog("memory cache enabled")
 | 
				
			||||||
 | 
							common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
 | 
				
			||||||
		model.InitChannelCache()
 | 
							model.InitChannelCache()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if os.Getenv("SYNC_FREQUENCY") != "" {
 | 
						if common.MemoryCacheEnabled {
 | 
				
			||||||
		frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
 | 
							go model.SyncOptions(common.SyncFrequency)
 | 
				
			||||||
		if err != nil {
 | 
							go model.SyncChannelCache(common.SyncFrequency)
 | 
				
			||||||
			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		common.SyncFrequency = frequency
 | 
					 | 
				
			||||||
		go model.SyncOptions(frequency)
 | 
					 | 
				
			||||||
		if common.RedisEnabled {
 | 
					 | 
				
			||||||
			go model.SyncChannelCache(frequency)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
 | 
						if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
 | 
				
			||||||
		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
 | 
							frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
 | 
				
			||||||
@@ -77,13 +77,20 @@ func main() {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		go controller.AutomaticallyTestChannels(frequency)
 | 
							go controller.AutomaticallyTestChannels(frequency)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
 | 
				
			||||||
 | 
							common.BatchUpdateEnabled = true
 | 
				
			||||||
 | 
							common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
 | 
				
			||||||
 | 
							model.InitBatchUpdater()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						controller.InitTokenEncoders()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Initialize HTTP server
 | 
						// Initialize HTTP server
 | 
				
			||||||
	server := gin.Default()
 | 
						server := gin.New()
 | 
				
			||||||
 | 
						server.Use(gin.Recovery())
 | 
				
			||||||
	// This will cause SSE not to work!!!
 | 
						// This will cause SSE not to work!!!
 | 
				
			||||||
	//server.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
						//server.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
				
			||||||
	server.Use(middleware.CORS())
 | 
						server.Use(middleware.RequestId())
 | 
				
			||||||
 | 
						middleware.SetUpLogger(server)
 | 
				
			||||||
	// Initialize session store
 | 
						// Initialize session store
 | 
				
			||||||
	store := cookie.NewStore([]byte(common.SessionSecret))
 | 
						store := cookie.NewStore([]byte(common.SessionSecret))
 | 
				
			||||||
	server.Use(sessions.Sessions("session", store))
 | 
						server.Use(sessions.Sessions("session", store))
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -91,23 +91,16 @@ func TokenAuth() func(c *gin.Context) {
 | 
				
			|||||||
		key = parts[0]
 | 
							key = parts[0]
 | 
				
			||||||
		token, err := model.ValidateUserToken(key)
 | 
							token, err := model.ValidateUserToken(key)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			c.JSON(http.StatusUnauthorized, gin.H{
 | 
								abortWithMessage(c, http.StatusUnauthorized, err.Error())
 | 
				
			||||||
				"error": gin.H{
 | 
					 | 
				
			||||||
					"message": err.Error(),
 | 
					 | 
				
			||||||
					"type":    "one_api_error",
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			})
 | 
					 | 
				
			||||||
			c.Abort()
 | 
					 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if !model.CacheIsUserEnabled(token.UserId) {
 | 
							userEnabled, err := model.CacheIsUserEnabled(token.UserId)
 | 
				
			||||||
			c.JSON(http.StatusForbidden, gin.H{
 | 
							if err != nil {
 | 
				
			||||||
				"error": gin.H{
 | 
								abortWithMessage(c, http.StatusInternalServerError, err.Error())
 | 
				
			||||||
					"message": "用户已被封禁",
 | 
								return
 | 
				
			||||||
					"type":    "one_api_error",
 | 
							}
 | 
				
			||||||
				},
 | 
							if !userEnabled {
 | 
				
			||||||
			})
 | 
								abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
 | 
				
			||||||
			c.Abort()
 | 
					 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		c.Set("id", token.UserId)
 | 
							c.Set("id", token.UserId)
 | 
				
			||||||
@@ -123,13 +116,7 @@ func TokenAuth() func(c *gin.Context) {
 | 
				
			|||||||
			if model.IsAdmin(token.UserId) {
 | 
								if model.IsAdmin(token.UserId) {
 | 
				
			||||||
				c.Set("channelId", parts[1])
 | 
									c.Set("channelId", parts[1])
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				c.JSON(http.StatusForbidden, gin.H{
 | 
									abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 | 
				
			||||||
					"error": gin.H{
 | 
					 | 
				
			||||||
						"message": "普通用户不支持指定渠道",
 | 
					 | 
				
			||||||
						"type":    "one_api_error",
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				})
 | 
					 | 
				
			||||||
				c.Abort()
 | 
					 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -25,48 +25,27 @@ func Distribute() func(c *gin.Context) {
 | 
				
			|||||||
		if ok {
 | 
							if ok {
 | 
				
			||||||
			id, err := strconv.Atoi(channelId.(string))
 | 
								id, err := strconv.Atoi(channelId.(string))
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
									abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
 | 
				
			||||||
					"error": gin.H{
 | 
					 | 
				
			||||||
						"message": "无效的渠道 ID",
 | 
					 | 
				
			||||||
						"type":    "one_api_error",
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				})
 | 
					 | 
				
			||||||
				c.Abort()
 | 
					 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			channel, err = model.GetChannelById(id, true)
 | 
								channel, err = model.GetChannelById(id, true)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
									abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
 | 
				
			||||||
					"error": gin.H{
 | 
					 | 
				
			||||||
						"message": "无效的渠道 ID",
 | 
					 | 
				
			||||||
						"type":    "one_api_error",
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				})
 | 
					 | 
				
			||||||
				c.Abort()
 | 
					 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if channel.Status != common.ChannelStatusEnabled {
 | 
								if channel.Status != common.ChannelStatusEnabled {
 | 
				
			||||||
				c.JSON(http.StatusForbidden, gin.H{
 | 
									abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
 | 
				
			||||||
					"error": gin.H{
 | 
					 | 
				
			||||||
						"message": "该渠道已被禁用",
 | 
					 | 
				
			||||||
						"type":    "one_api_error",
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				})
 | 
					 | 
				
			||||||
				c.Abort()
 | 
					 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			// Select a channel for the user
 | 
								// Select a channel for the user
 | 
				
			||||||
			var modelRequest ModelRequest
 | 
								var modelRequest ModelRequest
 | 
				
			||||||
			err := common.UnmarshalBodyReusable(c, &modelRequest)
 | 
								var err error
 | 
				
			||||||
 | 
								if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
				
			||||||
 | 
									err = common.UnmarshalBodyReusable(c, &modelRequest)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
									abortWithMessage(c, http.StatusBadRequest, "无效的请求")
 | 
				
			||||||
					"error": gin.H{
 | 
					 | 
				
			||||||
						"message": "无效的请求",
 | 
					 | 
				
			||||||
						"type":    "one_api_error",
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				})
 | 
					 | 
				
			||||||
				c.Abort()
 | 
					 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
								if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
				
			||||||
@@ -84,6 +63,11 @@ func Distribute() func(c *gin.Context) {
 | 
				
			|||||||
					modelRequest.Model = "dall-e"
 | 
										modelRequest.Model = "dall-e"
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
				
			||||||
 | 
									if modelRequest.Model == "" {
 | 
				
			||||||
 | 
										modelRequest.Model = "whisper-1"
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 | 
								channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
 | 
									message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
 | 
				
			||||||
@@ -91,24 +75,23 @@ func Distribute() func(c *gin.Context) {
 | 
				
			|||||||
					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 | 
										common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 | 
				
			||||||
					message = "数据库一致性已被破坏,请联系管理员"
 | 
										message = "数据库一致性已被破坏,请联系管理员"
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				c.JSON(http.StatusServiceUnavailable, gin.H{
 | 
									abortWithMessage(c, http.StatusServiceUnavailable, message)
 | 
				
			||||||
					"error": gin.H{
 | 
					 | 
				
			||||||
						"message": message,
 | 
					 | 
				
			||||||
						"type":    "one_api_error",
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				})
 | 
					 | 
				
			||||||
				c.Abort()
 | 
					 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		c.Set("channel", channel.Type)
 | 
							c.Set("channel", channel.Type)
 | 
				
			||||||
		c.Set("channel_id", channel.Id)
 | 
							c.Set("channel_id", channel.Id)
 | 
				
			||||||
		c.Set("channel_name", channel.Name)
 | 
							c.Set("channel_name", channel.Name)
 | 
				
			||||||
		c.Set("model_mapping", channel.ModelMapping)
 | 
							c.Set("model_mapping", channel.GetModelMapping())
 | 
				
			||||||
		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
							c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
				
			||||||
		c.Set("base_url", channel.BaseURL)
 | 
							c.Set("base_url", channel.GetBaseURL())
 | 
				
			||||||
		if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
 | 
							switch channel.Type {
 | 
				
			||||||
 | 
							case common.ChannelTypeAzure:
 | 
				
			||||||
			c.Set("api_version", channel.Other)
 | 
								c.Set("api_version", channel.Other)
 | 
				
			||||||
 | 
							case common.ChannelTypeXunfei:
 | 
				
			||||||
 | 
								c.Set("api_version", channel.Other)
 | 
				
			||||||
 | 
							case common.ChannelTypeAIProxyLibrary:
 | 
				
			||||||
 | 
								c.Set("library_id", channel.Other)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		c.Next()
 | 
							c.Next()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
				
			|||||||
 | 
					package middleware
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"one-api/common"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func SetUpLogger(server *gin.Engine) {
 | 
				
			||||||
 | 
						server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
 | 
				
			||||||
 | 
							var requestID string
 | 
				
			||||||
 | 
							if param.Keys != nil {
 | 
				
			||||||
 | 
								requestID = param.Keys[common.RequestIdKey].(string)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
 | 
				
			||||||
 | 
								param.TimeStamp.Format("2006/01/02 - 15:04:05"),
 | 
				
			||||||
 | 
								requestID,
 | 
				
			||||||
 | 
								param.StatusCode,
 | 
				
			||||||
 | 
								param.Latency,
 | 
				
			||||||
 | 
								param.ClientIP,
 | 
				
			||||||
 | 
								param.Method,
 | 
				
			||||||
 | 
								param.Path,
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
				
			|||||||
 | 
					package middleware
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"one-api/common"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func RequestId() func(c *gin.Context) {
 | 
				
			||||||
 | 
						return func(c *gin.Context) {
 | 
				
			||||||
 | 
							id := common.GetTimeString() + common.GetRandomString(8)
 | 
				
			||||||
 | 
							c.Set(common.RequestIdKey, id)
 | 
				
			||||||
 | 
							ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
 | 
				
			||||||
 | 
							c.Request = c.Request.WithContext(ctx)
 | 
				
			||||||
 | 
							c.Header(common.RequestIdKey, id)
 | 
				
			||||||
 | 
							c.Next()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
				
			|||||||
 | 
					package middleware
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"one-api/common"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func abortWithMessage(c *gin.Context, statusCode int, message string) {
 | 
				
			||||||
 | 
						c.JSON(statusCode, gin.H{
 | 
				
			||||||
 | 
							"error": gin.H{
 | 
				
			||||||
 | 
								"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
 | 
				
			||||||
 | 
								"type":    "one_api_error",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						c.Abort()
 | 
				
			||||||
 | 
						common.LogError(c.Request.Context(), message)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -10,15 +10,18 @@ type Ability struct {
 | 
				
			|||||||
	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"`
 | 
						Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"`
 | 
				
			||||||
	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
 | 
						ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
 | 
				
			||||||
	Enabled   bool   `json:"enabled"`
 | 
						Enabled   bool   `json:"enabled"`
 | 
				
			||||||
 | 
						Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 | 
					func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 | 
				
			||||||
	ability := Ability{}
 | 
						ability := Ability{}
 | 
				
			||||||
	var err error = nil
 | 
						var err error = nil
 | 
				
			||||||
 | 
						maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
 | 
				
			||||||
 | 
						channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
 | 
				
			||||||
	if common.UsingSQLite {
 | 
						if common.UsingSQLite {
 | 
				
			||||||
		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
 | 
							err = channelQuery.Order("RANDOM()").First(&ability).Error
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
 | 
							err = channelQuery.Order("RAND()").First(&ability).Error
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
@@ -40,6 +43,7 @@ func (channel *Channel) AddAbilities() error {
 | 
				
			|||||||
				Model:     model,
 | 
									Model:     model,
 | 
				
			||||||
				ChannelId: channel.Id,
 | 
									ChannelId: channel.Id,
 | 
				
			||||||
				Enabled:   channel.Status == common.ChannelStatusEnabled,
 | 
									Enabled:   channel.Status == common.ChannelStatusEnabled,
 | 
				
			||||||
 | 
									Priority:  channel.Priority,
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			abilities = append(abilities, ability)
 | 
								abilities = append(abilities, ability)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,6 +6,7 @@ import (
 | 
				
			|||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"math/rand"
 | 
						"math/rand"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"sort"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
@@ -103,23 +104,28 @@ func CacheDecreaseUserQuota(id int, quota int) error {
 | 
				
			|||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func CacheIsUserEnabled(userId int) bool {
 | 
					func CacheIsUserEnabled(userId int) (bool, error) {
 | 
				
			||||||
	if !common.RedisEnabled {
 | 
						if !common.RedisEnabled {
 | 
				
			||||||
		return IsUserEnabled(userId)
 | 
							return IsUserEnabled(userId)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
 | 
						enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
 | 
				
			||||||
	if err != nil {
 | 
						if err == nil {
 | 
				
			||||||
		status := common.UserStatusDisabled
 | 
							return enabled == "1", nil
 | 
				
			||||||
		if IsUserEnabled(userId) {
 | 
					 | 
				
			||||||
			status = common.UserStatusEnabled
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		enabled = fmt.Sprintf("%d", status)
 | 
					 | 
				
			||||||
		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			common.SysError("Redis set user enabled error: " + err.Error())
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return enabled == "1"
 | 
					
 | 
				
			||||||
 | 
						userEnabled, err := IsUserEnabled(userId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						enabled = "0"
 | 
				
			||||||
 | 
						if userEnabled {
 | 
				
			||||||
 | 
							enabled = "1"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							common.SysError("Redis set user enabled error: " + err.Error())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return userEnabled, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var group2model2channels map[string]map[string][]*Channel
 | 
					var group2model2channels map[string]map[string][]*Channel
 | 
				
			||||||
@@ -154,6 +160,17 @@ func InitChannelCache() {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// sort by priority
 | 
				
			||||||
 | 
						for group, model2channels := range newGroup2model2channels {
 | 
				
			||||||
 | 
							for model, channels := range model2channels {
 | 
				
			||||||
 | 
								sort.Slice(channels, func(i, j int) bool {
 | 
				
			||||||
 | 
									return channels[i].GetPriority() > channels[j].GetPriority()
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
								newGroup2model2channels[group][model] = channels
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	channelSyncLock.Lock()
 | 
						channelSyncLock.Lock()
 | 
				
			||||||
	group2model2channels = newGroup2model2channels
 | 
						group2model2channels = newGroup2model2channels
 | 
				
			||||||
	channelSyncLock.Unlock()
 | 
						channelSyncLock.Unlock()
 | 
				
			||||||
@@ -169,7 +186,7 @@ func SyncChannelCache(frequency int) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 | 
					func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 | 
				
			||||||
	if !common.RedisEnabled {
 | 
						if !common.MemoryCacheEnabled {
 | 
				
			||||||
		return GetRandomSatisfiedChannel(group, model)
 | 
							return GetRandomSatisfiedChannel(group, model)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	channelSyncLock.RLock()
 | 
						channelSyncLock.RLock()
 | 
				
			||||||
@@ -178,6 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 | 
				
			|||||||
	if len(channels) == 0 {
 | 
						if len(channels) == 0 {
 | 
				
			||||||
		return nil, errors.New("channel not found")
 | 
							return nil, errors.New("channel not found")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	idx := rand.Intn(len(channels))
 | 
						endIdx := len(channels)
 | 
				
			||||||
 | 
						// choose by priority
 | 
				
			||||||
 | 
						firstChannel := channels[0]
 | 
				
			||||||
 | 
						if firstChannel.GetPriority() > 0 {
 | 
				
			||||||
 | 
							for i := range channels {
 | 
				
			||||||
 | 
								if channels[i].GetPriority() != firstChannel.GetPriority() {
 | 
				
			||||||
 | 
									endIdx = i
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						idx := rand.Intn(endIdx)
 | 
				
			||||||
	return channels[idx], nil
 | 
						return channels[idx], nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,14 +15,15 @@ type Channel struct {
 | 
				
			|||||||
	CreatedTime        int64   `json:"created_time" gorm:"bigint"`
 | 
						CreatedTime        int64   `json:"created_time" gorm:"bigint"`
 | 
				
			||||||
	TestTime           int64   `json:"test_time" gorm:"bigint"`
 | 
						TestTime           int64   `json:"test_time" gorm:"bigint"`
 | 
				
			||||||
	ResponseTime       int     `json:"response_time"` // in milliseconds
 | 
						ResponseTime       int     `json:"response_time"` // in milliseconds
 | 
				
			||||||
	BaseURL            string  `json:"base_url" gorm:"column:base_url"`
 | 
						BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"`
 | 
				
			||||||
	Other              string  `json:"other"`
 | 
						Other              string  `json:"other"`
 | 
				
			||||||
	Balance            float64 `json:"balance"` // in USD
 | 
						Balance            float64 `json:"balance"` // in USD
 | 
				
			||||||
	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
 | 
						BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
 | 
				
			||||||
	Models             string  `json:"models"`
 | 
						Models             string  `json:"models"`
 | 
				
			||||||
	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
 | 
						Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
 | 
				
			||||||
	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
 | 
						UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
 | 
				
			||||||
	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 | 
						ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 | 
				
			||||||
 | 
						Priority           *int64  `json:"priority" gorm:"bigint;default:0"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 | 
					func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 | 
				
			||||||
@@ -78,6 +79,27 @@ func BatchInsertChannels(channels []Channel) error {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (channel *Channel) GetPriority() int64 {
 | 
				
			||||||
 | 
						if channel.Priority == nil {
 | 
				
			||||||
 | 
							return 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return *channel.Priority
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (channel *Channel) GetBaseURL() string {
 | 
				
			||||||
 | 
						if channel.BaseURL == nil {
 | 
				
			||||||
 | 
							return ""
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return *channel.BaseURL
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (channel *Channel) GetModelMapping() string {
 | 
				
			||||||
 | 
						if channel.ModelMapping == nil {
 | 
				
			||||||
 | 
							return ""
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return *channel.ModelMapping
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (channel *Channel) Insert() error {
 | 
					func (channel *Channel) Insert() error {
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
	err = DB.Create(channel).Error
 | 
						err = DB.Create(channel).Error
 | 
				
			||||||
@@ -141,6 +163,14 @@ func UpdateChannelStatusById(id int, status int) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func UpdateChannelUsedQuota(id int, quota int) {
 | 
					func UpdateChannelUsedQuota(id int, quota int) {
 | 
				
			||||||
 | 
						if common.BatchUpdateEnabled {
 | 
				
			||||||
 | 
							addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						updateChannelUsedQuota(id, quota)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func updateChannelUsedQuota(id int, quota int) {
 | 
				
			||||||
	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
 | 
						err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		common.SysError("failed to update channel used quota: " + err.Error())
 | 
							common.SysError("failed to update channel used quota: " + err.Error())
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										29
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								model/log.go
									
									
									
									
									
								
							@@ -1,6 +1,8 @@
 | 
				
			|||||||
package model
 | 
					package model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -17,6 +19,7 @@ type Log struct {
 | 
				
			|||||||
	Quota            int    `json:"quota" gorm:"default:0"`
 | 
						Quota            int    `json:"quota" gorm:"default:0"`
 | 
				
			||||||
	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"`
 | 
						PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"`
 | 
				
			||||||
	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"`
 | 
						CompletionTokens int    `json:"completion_tokens" gorm:"default:0"`
 | 
				
			||||||
 | 
						Channel          int    `json:"channel" gorm:"default:0"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
@@ -44,7 +47,9 @@ func RecordLog(userId int, logType int, content string) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
 | 
					
 | 
				
			||||||
 | 
					func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
 | 
				
			||||||
 | 
						common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 | 
				
			||||||
	if !common.LogConsumeEnabled {
 | 
						if !common.LogConsumeEnabled {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -59,14 +64,15 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
 | 
				
			|||||||
		TokenName:        tokenName,
 | 
							TokenName:        tokenName,
 | 
				
			||||||
		ModelName:        modelName,
 | 
							ModelName:        modelName,
 | 
				
			||||||
		Quota:            quota,
 | 
							Quota:            quota,
 | 
				
			||||||
 | 
							Channel:          channelId,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err := DB.Create(log).Error
 | 
						err := DB.Create(log).Error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		common.SysError("failed to record log: " + err.Error())
 | 
							common.LogError(ctx, "failed to record log: "+err.Error())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
 | 
					func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
 | 
				
			||||||
	var tx *gorm.DB
 | 
						var tx *gorm.DB
 | 
				
			||||||
	if logType == LogTypeUnknown {
 | 
						if logType == LogTypeUnknown {
 | 
				
			||||||
		tx = DB
 | 
							tx = DB
 | 
				
			||||||
@@ -88,6 +94,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 | 
				
			|||||||
	if endTimestamp != 0 {
 | 
						if endTimestamp != 0 {
 | 
				
			||||||
		tx = tx.Where("created_at <= ?", endTimestamp)
 | 
							tx = tx.Where("created_at <= ?", endTimestamp)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if channel != 0 {
 | 
				
			||||||
 | 
							tx = tx.Where("channel = ?", channel)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
 | 
						err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
 | 
				
			||||||
	return logs, err
 | 
						return logs, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -125,8 +134,8 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
 | 
				
			|||||||
	return logs, err
 | 
						return logs, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
 | 
					func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
 | 
				
			||||||
	tx := DB.Table("logs").Select("sum(quota)")
 | 
						tx := DB.Table("logs").Select("ifnull(sum(quota),0)")
 | 
				
			||||||
	if username != "" {
 | 
						if username != "" {
 | 
				
			||||||
		tx = tx.Where("username = ?", username)
 | 
							tx = tx.Where("username = ?", username)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -142,12 +151,15 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 | 
				
			|||||||
	if modelName != "" {
 | 
						if modelName != "" {
 | 
				
			||||||
		tx = tx.Where("model_name = ?", modelName)
 | 
							tx = tx.Where("model_name = ?", modelName)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if channel != 0 {
 | 
				
			||||||
 | 
							tx = tx.Where("channel = ?", channel)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	tx.Where("type = ?", LogTypeConsume).Scan("a)
 | 
						tx.Where("type = ?", LogTypeConsume).Scan("a)
 | 
				
			||||||
	return quota
 | 
						return quota
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
 | 
					func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
 | 
				
			||||||
	tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)")
 | 
						tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
 | 
				
			||||||
	if username != "" {
 | 
						if username != "" {
 | 
				
			||||||
		tx = tx.Where("username = ?", username)
 | 
							tx = tx.Where("username = ?", username)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -166,3 +178,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
 | 
				
			|||||||
	tx.Where("type = ?", LogTypeConsume).Scan(&token)
 | 
						tx.Where("type = ?", LogTypeConsume).Scan(&token)
 | 
				
			||||||
	return token
 | 
						return token
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func DeleteOldLog(targetTimestamp int64) (int64, error) {
 | 
				
			||||||
 | 
						result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
 | 
				
			||||||
 | 
						return result.RowsAffected, result.Error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	token, err = CacheGetTokenByKey(key)
 | 
						token, err = CacheGetTokenByKey(key)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
 | 
							if token.Status == common.TokenStatusExhausted {
 | 
				
			||||||
 | 
								return nil, errors.New("该令牌额度已用尽")
 | 
				
			||||||
 | 
							} else if token.Status == common.TokenStatusExpired {
 | 
				
			||||||
 | 
								return nil, errors.New("该令牌已过期")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		if token.Status != common.TokenStatusEnabled {
 | 
							if token.Status != common.TokenStatusEnabled {
 | 
				
			||||||
			return nil, errors.New("该令牌状态不可用")
 | 
								return nil, errors.New("该令牌状态不可用")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
 | 
							if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
 | 
				
			||||||
			token.Status = common.TokenStatusExpired
 | 
								if !common.RedisEnabled {
 | 
				
			||||||
			err := token.SelectUpdate()
 | 
									token.Status = common.TokenStatusExpired
 | 
				
			||||||
			if err != nil {
 | 
									err := token.SelectUpdate()
 | 
				
			||||||
				common.SysError("failed to update token status" + err.Error())
 | 
									if err != nil {
 | 
				
			||||||
 | 
										common.SysError("failed to update token status" + err.Error())
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			return nil, errors.New("该令牌已过期")
 | 
								return nil, errors.New("该令牌已过期")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if !token.UnlimitedQuota && token.RemainQuota <= 0 {
 | 
							if !token.UnlimitedQuota && token.RemainQuota <= 0 {
 | 
				
			||||||
			token.Status = common.TokenStatusExhausted
 | 
								if !common.RedisEnabled {
 | 
				
			||||||
			err := token.SelectUpdate()
 | 
									// in this case, we can make sure the token is exhausted
 | 
				
			||||||
			if err != nil {
 | 
									token.Status = common.TokenStatusExhausted
 | 
				
			||||||
				common.SysError("failed to update token status" + err.Error())
 | 
									err := token.SelectUpdate()
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										common.SysError("failed to update token status" + err.Error())
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			return nil, errors.New("该令牌额度已用尽")
 | 
								return nil, errors.New("该令牌额度已用尽")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		go func() {
 | 
					 | 
				
			||||||
			token.AccessedTime = common.GetTimestamp()
 | 
					 | 
				
			||||||
			err := token.SelectUpdate()
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				common.SysError("failed to update token" + err.Error())
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}()
 | 
					 | 
				
			||||||
		return token, nil
 | 
							return token, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, errors.New("无效的令牌")
 | 
						return nil, errors.New("无效的令牌")
 | 
				
			||||||
@@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
 | 
				
			|||||||
	if quota < 0 {
 | 
						if quota < 0 {
 | 
				
			||||||
		return errors.New("quota 不能为负数!")
 | 
							return errors.New("quota 不能为负数!")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if common.BatchUpdateEnabled {
 | 
				
			||||||
 | 
							addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return increaseTokenQuota(id, quota)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func increaseTokenQuota(id int, quota int) (err error) {
 | 
				
			||||||
	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 | 
						err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 | 
				
			||||||
		map[string]interface{}{
 | 
							map[string]interface{}{
 | 
				
			||||||
			"remain_quota": gorm.Expr("remain_quota + ?", quota),
 | 
								"remain_quota":  gorm.Expr("remain_quota + ?", quota),
 | 
				
			||||||
			"used_quota":   gorm.Expr("used_quota - ?", quota),
 | 
								"used_quota":    gorm.Expr("used_quota - ?", quota),
 | 
				
			||||||
 | 
								"accessed_time": common.GetTimestamp(),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	).Error
 | 
						).Error
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
@@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
 | 
				
			|||||||
	if quota < 0 {
 | 
						if quota < 0 {
 | 
				
			||||||
		return errors.New("quota 不能为负数!")
 | 
							return errors.New("quota 不能为负数!")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if common.BatchUpdateEnabled {
 | 
				
			||||||
 | 
							addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return decreaseTokenQuota(id, quota)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func decreaseTokenQuota(id int, quota int) (err error) {
 | 
				
			||||||
	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 | 
						err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 | 
				
			||||||
		map[string]interface{}{
 | 
							map[string]interface{}{
 | 
				
			||||||
			"remain_quota": gorm.Expr("remain_quota - ?", quota),
 | 
								"remain_quota":  gorm.Expr("remain_quota - ?", quota),
 | 
				
			||||||
			"used_quota":   gorm.Expr("used_quota + ?", quota),
 | 
								"used_quota":    gorm.Expr("used_quota + ?", quota),
 | 
				
			||||||
 | 
								"accessed_time": common.GetTimestamp(),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	).Error
 | 
						).Error
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -226,17 +226,16 @@ func IsAdmin(userId int) bool {
 | 
				
			|||||||
	return user.Role >= common.RoleAdminUser
 | 
						return user.Role >= common.RoleAdminUser
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func IsUserEnabled(userId int) bool {
 | 
					func IsUserEnabled(userId int) (bool, error) {
 | 
				
			||||||
	if userId == 0 {
 | 
						if userId == 0 {
 | 
				
			||||||
		return false
 | 
							return false, errors.New("user id is empty")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var user User
 | 
						var user User
 | 
				
			||||||
	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
 | 
						err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		common.SysError("no such user " + err.Error())
 | 
							return false, err
 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return user.Status == common.UserStatusEnabled
 | 
						return user.Status == common.UserStatusEnabled, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func ValidateAccessToken(token string) (user *User) {
 | 
					func ValidateAccessToken(token string) (user *User) {
 | 
				
			||||||
@@ -275,6 +274,14 @@ func IncreaseUserQuota(id int, quota int) (err error) {
 | 
				
			|||||||
	if quota < 0 {
 | 
						if quota < 0 {
 | 
				
			||||||
		return errors.New("quota 不能为负数!")
 | 
							return errors.New("quota 不能为负数!")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if common.BatchUpdateEnabled {
 | 
				
			||||||
 | 
							addNewRecord(BatchUpdateTypeUserQuota, id, quota)
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return increaseUserQuota(id, quota)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func increaseUserQuota(id int, quota int) (err error) {
 | 
				
			||||||
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
 | 
						err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -283,6 +290,14 @@ func DecreaseUserQuota(id int, quota int) (err error) {
 | 
				
			|||||||
	if quota < 0 {
 | 
						if quota < 0 {
 | 
				
			||||||
		return errors.New("quota 不能为负数!")
 | 
							return errors.New("quota 不能为负数!")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if common.BatchUpdateEnabled {
 | 
				
			||||||
 | 
							addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return decreaseUserQuota(id, quota)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func decreaseUserQuota(id int, quota int) (err error) {
 | 
				
			||||||
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
 | 
						err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -293,10 +308,18 @@ func GetRootUserEmail() (email string) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
 | 
					func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
 | 
				
			||||||
 | 
						if common.BatchUpdateEnabled {
 | 
				
			||||||
 | 
							addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						updateUserUsedQuotaAndRequestCount(id, quota, 1)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
 | 
				
			||||||
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
 | 
						err := DB.Model(&User{}).Where("id = ?", id).Updates(
 | 
				
			||||||
		map[string]interface{}{
 | 
							map[string]interface{}{
 | 
				
			||||||
			"used_quota":    gorm.Expr("used_quota + ?", quota),
 | 
								"used_quota":    gorm.Expr("used_quota + ?", quota),
 | 
				
			||||||
			"request_count": gorm.Expr("request_count + ?", 1),
 | 
								"request_count": gorm.Expr("request_count + ?", count),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	).Error
 | 
						).Error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										75
									
								
								model/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								model/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,75 @@
 | 
				
			|||||||
 | 
					package model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						BatchUpdateTypeUserQuota = iota
 | 
				
			||||||
 | 
						BatchUpdateTypeTokenQuota
 | 
				
			||||||
 | 
						BatchUpdateTypeUsedQuotaAndRequestCount
 | 
				
			||||||
 | 
						BatchUpdateTypeChannelUsedQuota
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var batchUpdateStores []map[int]int
 | 
				
			||||||
 | 
					var batchUpdateLocks []sync.Mutex
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						for i := 0; i < BatchUpdateTypeCount; i++ {
 | 
				
			||||||
 | 
							batchUpdateStores = append(batchUpdateStores, make(map[int]int))
 | 
				
			||||||
 | 
							batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func InitBatchUpdater() {
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							for {
 | 
				
			||||||
 | 
								time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
 | 
				
			||||||
 | 
								batchUpdate()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func addNewRecord(type_ int, id int, value int) {
 | 
				
			||||||
 | 
						batchUpdateLocks[type_].Lock()
 | 
				
			||||||
 | 
						defer batchUpdateLocks[type_].Unlock()
 | 
				
			||||||
 | 
						if _, ok := batchUpdateStores[type_][id]; !ok {
 | 
				
			||||||
 | 
							batchUpdateStores[type_][id] = value
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							batchUpdateStores[type_][id] += value
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func batchUpdate() {
 | 
				
			||||||
 | 
						common.SysLog("batch update started")
 | 
				
			||||||
 | 
						for i := 0; i < BatchUpdateTypeCount; i++ {
 | 
				
			||||||
 | 
							batchUpdateLocks[i].Lock()
 | 
				
			||||||
 | 
							store := batchUpdateStores[i]
 | 
				
			||||||
 | 
							batchUpdateStores[i] = make(map[int]int)
 | 
				
			||||||
 | 
							batchUpdateLocks[i].Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for key, value := range store {
 | 
				
			||||||
 | 
								switch i {
 | 
				
			||||||
 | 
								case BatchUpdateTypeUserQuota:
 | 
				
			||||||
 | 
									err := increaseUserQuota(key, value)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										common.SysError("failed to batch update user quota: " + err.Error())
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case BatchUpdateTypeTokenQuota:
 | 
				
			||||||
 | 
									err := increaseTokenQuota(key, value)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										common.SysError("failed to batch update token quota: " + err.Error())
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case BatchUpdateTypeUsedQuotaAndRequestCount:
 | 
				
			||||||
 | 
									updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
 | 
				
			||||||
 | 
								case BatchUpdateTypeChannelUsedQuota:
 | 
				
			||||||
 | 
									updateChannelUsedQuota(key, value)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						common.SysLog("batch update finished")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) {
 | 
				
			|||||||
		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
 | 
							apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
 | 
				
			||||||
		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
 | 
							apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
 | 
				
			||||||
		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
 | 
							apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
 | 
				
			||||||
 | 
							apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
 | 
				
			||||||
		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
 | 
							apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
 | 
				
			||||||
		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
 | 
							apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
 | 
				
			||||||
		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
 | 
							apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
 | 
				
			||||||
@@ -97,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		logRoute := apiRouter.Group("/log")
 | 
							logRoute := apiRouter.Group("/log")
 | 
				
			||||||
		logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
 | 
							logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
 | 
				
			||||||
 | 
							logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
 | 
				
			||||||
		logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
 | 
							logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
 | 
				
			||||||
		logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
 | 
							logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
 | 
				
			||||||
		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
 | 
							logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,6 +8,7 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func SetRelayRouter(router *gin.Engine) {
 | 
					func SetRelayRouter(router *gin.Engine) {
 | 
				
			||||||
 | 
						router.Use(middleware.CORS())
 | 
				
			||||||
	// https://platform.openai.com/docs/api-reference/introduction
 | 
						// https://platform.openai.com/docs/api-reference/introduction
 | 
				
			||||||
	modelsRouter := router.Group("/v1/models")
 | 
						modelsRouter := router.Group("/v1/models")
 | 
				
			||||||
	modelsRouter.Use(middleware.TokenAuth())
 | 
						modelsRouter.Use(middleware.TokenAuth())
 | 
				
			||||||
@@ -26,8 +27,8 @@ func SetRelayRouter(router *gin.Engine) {
 | 
				
			|||||||
		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
 | 
							relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
 | 
				
			||||||
		relayV1Router.POST("/embeddings", controller.Relay)
 | 
							relayV1Router.POST("/embeddings", controller.Relay)
 | 
				
			||||||
		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
 | 
							relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
 | 
				
			||||||
		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
 | 
							relayV1Router.POST("/audio/transcriptions", controller.Relay)
 | 
				
			||||||
		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
 | 
							relayV1Router.POST("/audio/translations", controller.Relay)
 | 
				
			||||||
		relayV1Router.GET("/files", controller.RelayNotImplemented)
 | 
							relayV1Router.GET("/files", controller.RelayNotImplemented)
 | 
				
			||||||
		relayV1Router.POST("/files", controller.RelayNotImplemented)
 | 
							relayV1Router.POST("/files", controller.RelayNotImplemented)
 | 
				
			||||||
		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
 | 
							relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,7 @@
 | 
				
			|||||||
import React, { useEffect, useState } from 'react';
 | 
					import React, { useEffect, useState } from 'react';
 | 
				
			||||||
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
 | 
					import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
 | 
				
			||||||
import { Link } from 'react-router-dom';
 | 
					import { Link } from 'react-router-dom';
 | 
				
			||||||
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
 | 
					import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
 | 
					import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
 | 
				
			||||||
import { renderGroup, renderNumber } from '../helpers/render';
 | 
					import { renderGroup, renderNumber } from '../helpers/render';
 | 
				
			||||||
@@ -24,7 +24,7 @@ function renderType(type) {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
 | 
					    type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
 | 
					  return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
function renderBalance(type, balance) {
 | 
					function renderBalance(type, balance) {
 | 
				
			||||||
@@ -96,7 +96,7 @@ const ChannelsTable = () => {
 | 
				
			|||||||
      });
 | 
					      });
 | 
				
			||||||
  }, []);
 | 
					  }, []);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const manageChannel = async (id, action, idx) => {
 | 
					  const manageChannel = async (id, action, idx, priority) => {
 | 
				
			||||||
    let data = { id };
 | 
					    let data = { id };
 | 
				
			||||||
    let res;
 | 
					    let res;
 | 
				
			||||||
    switch (action) {
 | 
					    switch (action) {
 | 
				
			||||||
@@ -111,6 +111,13 @@ const ChannelsTable = () => {
 | 
				
			|||||||
        data.status = 2;
 | 
					        data.status = 2;
 | 
				
			||||||
        res = await API.put('/api/channel/', data);
 | 
					        res = await API.put('/api/channel/', data);
 | 
				
			||||||
        break;
 | 
					        break;
 | 
				
			||||||
 | 
					      case 'priority':
 | 
				
			||||||
 | 
					        if (priority === '') {
 | 
				
			||||||
 | 
					          return;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        data.priority = parseInt(priority);
 | 
				
			||||||
 | 
					        res = await API.put('/api/channel/', data);
 | 
				
			||||||
 | 
					        break;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    const { success, message } = res.data;
 | 
					    const { success, message } = res.data;
 | 
				
			||||||
    if (success) {
 | 
					    if (success) {
 | 
				
			||||||
@@ -195,6 +202,7 @@ const ChannelsTable = () => {
 | 
				
			|||||||
      showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
 | 
					      showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      showError(message);
 | 
					      showError(message);
 | 
				
			||||||
 | 
					      showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。")
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -334,6 +342,14 @@ const ChannelsTable = () => {
 | 
				
			|||||||
            >
 | 
					            >
 | 
				
			||||||
              余额
 | 
					              余额
 | 
				
			||||||
            </Table.HeaderCell>
 | 
					            </Table.HeaderCell>
 | 
				
			||||||
 | 
					            <Table.HeaderCell
 | 
				
			||||||
 | 
					                style={{ cursor: 'pointer' }}
 | 
				
			||||||
 | 
					                onClick={() => {
 | 
				
			||||||
 | 
					                  sortChannel('priority');
 | 
				
			||||||
 | 
					                }}
 | 
				
			||||||
 | 
					            >
 | 
				
			||||||
 | 
					              优先级
 | 
				
			||||||
 | 
					            </Table.HeaderCell>
 | 
				
			||||||
            <Table.HeaderCell>操作</Table.HeaderCell>
 | 
					            <Table.HeaderCell>操作</Table.HeaderCell>
 | 
				
			||||||
          </Table.Row>
 | 
					          </Table.Row>
 | 
				
			||||||
        </Table.Header>
 | 
					        </Table.Header>
 | 
				
			||||||
@@ -372,6 +388,22 @@ const ChannelsTable = () => {
 | 
				
			|||||||
                      basic
 | 
					                      basic
 | 
				
			||||||
                    />
 | 
					                    />
 | 
				
			||||||
                  </Table.Cell>
 | 
					                  </Table.Cell>
 | 
				
			||||||
 | 
					                  <Table.Cell>
 | 
				
			||||||
 | 
					                    <Popup
 | 
				
			||||||
 | 
					                        trigger={<Input type="number"  defaultValue={channel.priority} onBlur={(event) => {
 | 
				
			||||||
 | 
					                          manageChannel(
 | 
				
			||||||
 | 
					                              channel.id,
 | 
				
			||||||
 | 
					                              'priority',
 | 
				
			||||||
 | 
					                              idx,
 | 
				
			||||||
 | 
					                              event.target.value,
 | 
				
			||||||
 | 
					                          );
 | 
				
			||||||
 | 
					                        }}>
 | 
				
			||||||
 | 
					                          <input style={{maxWidth:'60px'}} />
 | 
				
			||||||
 | 
					                        </Input>}
 | 
				
			||||||
 | 
					                        content='渠道选择优先级,越高越优先'
 | 
				
			||||||
 | 
					                        basic
 | 
				
			||||||
 | 
					                    />
 | 
				
			||||||
 | 
					                  </Table.Cell>
 | 
				
			||||||
                  <Table.Cell>
 | 
					                  <Table.Cell>
 | 
				
			||||||
                    <div>
 | 
					                    <div>
 | 
				
			||||||
                      <Button
 | 
					                      <Button
 | 
				
			||||||
@@ -440,7 +472,7 @@ const ChannelsTable = () => {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        <Table.Footer>
 | 
					        <Table.Footer>
 | 
				
			||||||
          <Table.Row>
 | 
					          <Table.Row>
 | 
				
			||||||
            <Table.HeaderCell colSpan='8'>
 | 
					            <Table.HeaderCell colSpan='9'>
 | 
				
			||||||
              <Button size='small' as={Link} to='/channel/add' loading={loading}>
 | 
					              <Button size='small' as={Link} to='/channel/add' loading={loading}>
 | 
				
			||||||
                添加新的渠道
 | 
					                添加新的渠道
 | 
				
			||||||
              </Button>
 | 
					              </Button>
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -13,8 +13,8 @@ const GitHubOAuth = () => {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  let navigate = useNavigate();
 | 
					  let navigate = useNavigate();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const sendCode = async (code, count) => {
 | 
					  const sendCode = async (code, state, count) => {
 | 
				
			||||||
    const res = await API.get(`/api/oauth/github?code=${code}`);
 | 
					    const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
 | 
				
			||||||
    const { success, message, data } = res.data;
 | 
					    const { success, message, data } = res.data;
 | 
				
			||||||
    if (success) {
 | 
					    if (success) {
 | 
				
			||||||
      if (message === 'bind') {
 | 
					      if (message === 'bind') {
 | 
				
			||||||
@@ -36,13 +36,14 @@ const GitHubOAuth = () => {
 | 
				
			|||||||
      count++;
 | 
					      count++;
 | 
				
			||||||
      setPrompt(`出现错误,第 ${count} 次重试中...`);
 | 
					      setPrompt(`出现错误,第 ${count} 次重试中...`);
 | 
				
			||||||
      await new Promise((resolve) => setTimeout(resolve, count * 2000));
 | 
					      await new Promise((resolve) => setTimeout(resolve, count * 2000));
 | 
				
			||||||
      await sendCode(code, count);
 | 
					      await sendCode(code, state, count);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  useEffect(() => {
 | 
					  useEffect(() => {
 | 
				
			||||||
    let code = searchParams.get('code');
 | 
					    let code = searchParams.get('code');
 | 
				
			||||||
    sendCode(code, 0).then();
 | 
					    let state = searchParams.get('state');
 | 
				
			||||||
 | 
					    sendCode(code, state, 0).then();
 | 
				
			||||||
  }, []);
 | 
					  }, []);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return (
 | 
					  return (
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f
 | 
				
			|||||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
 | 
					import { Link, useNavigate, useSearchParams } from 'react-router-dom';
 | 
				
			||||||
import { UserContext } from '../context/User';
 | 
					import { UserContext } from '../context/User';
 | 
				
			||||||
import { API, getLogo, showError, showSuccess } from '../helpers';
 | 
					import { API, getLogo, showError, showSuccess } from '../helpers';
 | 
				
			||||||
 | 
					import { getOAuthState, onGitHubOAuthClicked } from './utils';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const LoginForm = () => {
 | 
					const LoginForm = () => {
 | 
				
			||||||
  const [inputs, setInputs] = useState({
 | 
					  const [inputs, setInputs] = useState({
 | 
				
			||||||
@@ -31,12 +32,6 @@ const LoginForm = () => {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
 | 
					  const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const onGitHubOAuthClicked = () => {
 | 
					 | 
				
			||||||
    window.open(
 | 
					 | 
				
			||||||
      `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
 | 
					 | 
				
			||||||
    );
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  const onWeChatLoginClicked = () => {
 | 
					  const onWeChatLoginClicked = () => {
 | 
				
			||||||
    setShowWeChatLoginModal(true);
 | 
					    setShowWeChatLoginModal(true);
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
@@ -131,7 +126,7 @@ const LoginForm = () => {
 | 
				
			|||||||
                circular
 | 
					                circular
 | 
				
			||||||
                color='black'
 | 
					                color='black'
 | 
				
			||||||
                icon='github'
 | 
					                icon='github'
 | 
				
			||||||
                onClick={onGitHubOAuthClicked}
 | 
					                onClick={()=>onGitHubOAuthClicked(status.github_client_id)}
 | 
				
			||||||
              />
 | 
					              />
 | 
				
			||||||
            ) : (
 | 
					            ) : (
 | 
				
			||||||
              <></>
 | 
					              <></>
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -56,9 +56,10 @@ const LogsTable = () => {
 | 
				
			|||||||
    token_name: '',
 | 
					    token_name: '',
 | 
				
			||||||
    model_name: '',
 | 
					    model_name: '',
 | 
				
			||||||
    start_timestamp: timestamp2string(0),
 | 
					    start_timestamp: timestamp2string(0),
 | 
				
			||||||
    end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
 | 
					    end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
 | 
				
			||||||
 | 
					    channel: ''
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
  const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
 | 
					  const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const [stat, setStat] = useState({
 | 
					  const [stat, setStat] = useState({
 | 
				
			||||||
    quota: 0,
 | 
					    quota: 0,
 | 
				
			||||||
@@ -84,7 +85,7 @@ const LogsTable = () => {
 | 
				
			|||||||
  const getLogStat = async () => {
 | 
					  const getLogStat = async () => {
 | 
				
			||||||
    let localStartTimestamp = Date.parse(start_timestamp) / 1000;
 | 
					    let localStartTimestamp = Date.parse(start_timestamp) / 1000;
 | 
				
			||||||
    let localEndTimestamp = Date.parse(end_timestamp) / 1000;
 | 
					    let localEndTimestamp = Date.parse(end_timestamp) / 1000;
 | 
				
			||||||
    let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
 | 
					    let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
 | 
				
			||||||
    const { success, message, data } = res.data;
 | 
					    const { success, message, data } = res.data;
 | 
				
			||||||
    if (success) {
 | 
					    if (success) {
 | 
				
			||||||
      setStat(data);
 | 
					      setStat(data);
 | 
				
			||||||
@@ -109,7 +110,7 @@ const LogsTable = () => {
 | 
				
			|||||||
    let localStartTimestamp = Date.parse(start_timestamp) / 1000;
 | 
					    let localStartTimestamp = Date.parse(start_timestamp) / 1000;
 | 
				
			||||||
    let localEndTimestamp = Date.parse(end_timestamp) / 1000;
 | 
					    let localEndTimestamp = Date.parse(end_timestamp) / 1000;
 | 
				
			||||||
    if (isAdminUser) {
 | 
					    if (isAdminUser) {
 | 
				
			||||||
      url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
 | 
					      url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
 | 
					      url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -205,16 +206,9 @@ const LogsTable = () => {
 | 
				
			|||||||
        </Header>
 | 
					        </Header>
 | 
				
			||||||
        <Form>
 | 
					        <Form>
 | 
				
			||||||
          <Form.Group>
 | 
					          <Form.Group>
 | 
				
			||||||
            {
 | 
					            <Form.Input fluid label={'令牌名称'} width={3} value={token_name}
 | 
				
			||||||
              isAdminUser && (
 | 
					 | 
				
			||||||
                <Form.Input fluid label={'用户名称'} width={2} value={username}
 | 
					 | 
				
			||||||
                            placeholder={'可选值'} name='username'
 | 
					 | 
				
			||||||
                            onChange={handleInputChange} />
 | 
					 | 
				
			||||||
              )
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            <Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name}
 | 
					 | 
				
			||||||
                        placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
 | 
					                        placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
 | 
				
			||||||
            <Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值'
 | 
					            <Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
 | 
				
			||||||
                        name='model_name'
 | 
					                        name='model_name'
 | 
				
			||||||
                        onChange={handleInputChange} />
 | 
					                        onChange={handleInputChange} />
 | 
				
			||||||
            <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
 | 
					            <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
 | 
				
			||||||
@@ -225,6 +219,19 @@ const LogsTable = () => {
 | 
				
			|||||||
                        onChange={handleInputChange} />
 | 
					                        onChange={handleInputChange} />
 | 
				
			||||||
            <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
 | 
					            <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
 | 
				
			||||||
          </Form.Group>
 | 
					          </Form.Group>
 | 
				
			||||||
 | 
					          {
 | 
				
			||||||
 | 
					            isAdminUser && <>
 | 
				
			||||||
 | 
					              <Form.Group>
 | 
				
			||||||
 | 
					                <Form.Input fluid label={'渠道 ID'} width={3} value={channel}
 | 
				
			||||||
 | 
					                            placeholder='可选值' name='channel'
 | 
				
			||||||
 | 
					                            onChange={handleInputChange} />
 | 
				
			||||||
 | 
					                <Form.Input fluid label={'用户名称'} width={3} value={username}
 | 
				
			||||||
 | 
					                            placeholder={'可选值'} name='username'
 | 
				
			||||||
 | 
					                            onChange={handleInputChange} />
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					              </Form.Group>
 | 
				
			||||||
 | 
					            </>
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
        </Form>
 | 
					        </Form>
 | 
				
			||||||
        <Table basic compact size='small'>
 | 
					        <Table basic compact size='small'>
 | 
				
			||||||
          <Table.Header>
 | 
					          <Table.Header>
 | 
				
			||||||
@@ -238,6 +245,17 @@ const LogsTable = () => {
 | 
				
			|||||||
              >
 | 
					              >
 | 
				
			||||||
                时间
 | 
					                时间
 | 
				
			||||||
              </Table.HeaderCell>
 | 
					              </Table.HeaderCell>
 | 
				
			||||||
 | 
					              {
 | 
				
			||||||
 | 
					                isAdminUser && <Table.HeaderCell
 | 
				
			||||||
 | 
					                  style={{ cursor: 'pointer' }}
 | 
				
			||||||
 | 
					                  onClick={() => {
 | 
				
			||||||
 | 
					                    sortLog('channel');
 | 
				
			||||||
 | 
					                  }}
 | 
				
			||||||
 | 
					                  width={1}
 | 
				
			||||||
 | 
					                >
 | 
				
			||||||
 | 
					                  渠道
 | 
				
			||||||
 | 
					                </Table.HeaderCell>
 | 
				
			||||||
 | 
					              }
 | 
				
			||||||
              {
 | 
					              {
 | 
				
			||||||
                isAdminUser && <Table.HeaderCell
 | 
					                isAdminUser && <Table.HeaderCell
 | 
				
			||||||
                  style={{ cursor: 'pointer' }}
 | 
					                  style={{ cursor: 'pointer' }}
 | 
				
			||||||
@@ -299,16 +317,16 @@ const LogsTable = () => {
 | 
				
			|||||||
                onClick={() => {
 | 
					                onClick={() => {
 | 
				
			||||||
                  sortLog('quota');
 | 
					                  sortLog('quota');
 | 
				
			||||||
                }}
 | 
					                }}
 | 
				
			||||||
                width={2}
 | 
					                width={1}
 | 
				
			||||||
              >
 | 
					              >
 | 
				
			||||||
                消耗额度
 | 
					                额度
 | 
				
			||||||
              </Table.HeaderCell>
 | 
					              </Table.HeaderCell>
 | 
				
			||||||
              <Table.HeaderCell
 | 
					              <Table.HeaderCell
 | 
				
			||||||
                style={{ cursor: 'pointer' }}
 | 
					                style={{ cursor: 'pointer' }}
 | 
				
			||||||
                onClick={() => {
 | 
					                onClick={() => {
 | 
				
			||||||
                  sortLog('content');
 | 
					                  sortLog('content');
 | 
				
			||||||
                }}
 | 
					                }}
 | 
				
			||||||
                width={isAdminUser ? 4 : 5}
 | 
					                width={isAdminUser ? 4 : 6}
 | 
				
			||||||
              >
 | 
					              >
 | 
				
			||||||
                详情
 | 
					                详情
 | 
				
			||||||
              </Table.HeaderCell>
 | 
					              </Table.HeaderCell>
 | 
				
			||||||
@@ -324,8 +342,13 @@ const LogsTable = () => {
 | 
				
			|||||||
              .map((log, idx) => {
 | 
					              .map((log, idx) => {
 | 
				
			||||||
                if (log.deleted) return <></>;
 | 
					                if (log.deleted) return <></>;
 | 
				
			||||||
                return (
 | 
					                return (
 | 
				
			||||||
                  <Table.Row key={log.created_at}>
 | 
					                  <Table.Row key={log.id}>
 | 
				
			||||||
                    <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
 | 
					                    <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
 | 
				
			||||||
 | 
					                    {
 | 
				
			||||||
 | 
					                      isAdminUser && (
 | 
				
			||||||
 | 
					                        <Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
 | 
				
			||||||
 | 
					                      )
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
                    {
 | 
					                    {
 | 
				
			||||||
                      isAdminUser && (
 | 
					                      isAdminUser && (
 | 
				
			||||||
                        <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
 | 
					                        <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
 | 
				
			||||||
@@ -345,7 +368,7 @@ const LogsTable = () => {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
          <Table.Footer>
 | 
					          <Table.Footer>
 | 
				
			||||||
            <Table.Row>
 | 
					            <Table.Row>
 | 
				
			||||||
              <Table.HeaderCell colSpan={'9'}>
 | 
					              <Table.HeaderCell colSpan={'10'}>
 | 
				
			||||||
                <Select
 | 
					                <Select
 | 
				
			||||||
                  placeholder='选择明细分类'
 | 
					                  placeholder='选择明细分类'
 | 
				
			||||||
                  options={LOG_OPTIONS}
 | 
					                  options={LOG_OPTIONS}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,8 +1,9 @@
 | 
				
			|||||||
import React, { useEffect, useState } from 'react';
 | 
					import React, { useEffect, useState } from 'react';
 | 
				
			||||||
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
 | 
					import { Divider, Form, Grid, Header } from 'semantic-ui-react';
 | 
				
			||||||
import { API, showError, verifyJSON } from '../helpers';
 | 
					import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const OperationSetting = () => {
 | 
					const OperationSetting = () => {
 | 
				
			||||||
 | 
					  let now = new Date();
 | 
				
			||||||
  let [inputs, setInputs] = useState({
 | 
					  let [inputs, setInputs] = useState({
 | 
				
			||||||
    QuotaForNewUser: 0,
 | 
					    QuotaForNewUser: 0,
 | 
				
			||||||
    QuotaForInviter: 0,
 | 
					    QuotaForInviter: 0,
 | 
				
			||||||
@@ -20,10 +21,11 @@ const OperationSetting = () => {
 | 
				
			|||||||
    DisplayInCurrencyEnabled: '',
 | 
					    DisplayInCurrencyEnabled: '',
 | 
				
			||||||
    DisplayTokenStatEnabled: '',
 | 
					    DisplayTokenStatEnabled: '',
 | 
				
			||||||
    ApproximateTokenEnabled: '',
 | 
					    ApproximateTokenEnabled: '',
 | 
				
			||||||
    RetryTimes: 0,
 | 
					    RetryTimes: 0
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
  const [originInputs, setOriginInputs] = useState({});
 | 
					  const [originInputs, setOriginInputs] = useState({});
 | 
				
			||||||
  let [loading, setLoading] = useState(false);
 | 
					  let [loading, setLoading] = useState(false);
 | 
				
			||||||
 | 
					  let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const getOptions = async () => {
 | 
					  const getOptions = async () => {
 | 
				
			||||||
    const res = await API.get('/api/option/');
 | 
					    const res = await API.get('/api/option/');
 | 
				
			||||||
@@ -130,6 +132,17 @@ const OperationSetting = () => {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const deleteHistoryLogs = async () => {
 | 
				
			||||||
 | 
					    console.log(inputs);
 | 
				
			||||||
 | 
					    const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
 | 
				
			||||||
 | 
					    const { success, message, data } = res.data;
 | 
				
			||||||
 | 
					    if (success) {
 | 
				
			||||||
 | 
					      showSuccess(`${data} 条日志已清理!`);
 | 
				
			||||||
 | 
					      return;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    showError('日志清理失败:' + message);
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return (
 | 
					  return (
 | 
				
			||||||
    <Grid columns={1}>
 | 
					    <Grid columns={1}>
 | 
				
			||||||
      <Grid.Column>
 | 
					      <Grid.Column>
 | 
				
			||||||
@@ -179,12 +192,6 @@ const OperationSetting = () => {
 | 
				
			|||||||
            />
 | 
					            />
 | 
				
			||||||
          </Form.Group>
 | 
					          </Form.Group>
 | 
				
			||||||
          <Form.Group inline>
 | 
					          <Form.Group inline>
 | 
				
			||||||
            <Form.Checkbox
 | 
					 | 
				
			||||||
              checked={inputs.LogConsumeEnabled === 'true'}
 | 
					 | 
				
			||||||
              label='启用额度消费日志记录'
 | 
					 | 
				
			||||||
              name='LogConsumeEnabled'
 | 
					 | 
				
			||||||
              onChange={handleInputChange}
 | 
					 | 
				
			||||||
            />
 | 
					 | 
				
			||||||
            <Form.Checkbox
 | 
					            <Form.Checkbox
 | 
				
			||||||
              checked={inputs.DisplayInCurrencyEnabled === 'true'}
 | 
					              checked={inputs.DisplayInCurrencyEnabled === 'true'}
 | 
				
			||||||
              label='以货币形式显示额度'
 | 
					              label='以货币形式显示额度'
 | 
				
			||||||
@@ -208,6 +215,28 @@ const OperationSetting = () => {
 | 
				
			|||||||
            submitConfig('general').then();
 | 
					            submitConfig('general').then();
 | 
				
			||||||
          }}>保存通用设置</Form.Button>
 | 
					          }}>保存通用设置</Form.Button>
 | 
				
			||||||
          <Divider />
 | 
					          <Divider />
 | 
				
			||||||
 | 
					          <Header as='h3'>
 | 
				
			||||||
 | 
					            日志设置
 | 
				
			||||||
 | 
					          </Header>
 | 
				
			||||||
 | 
					          <Form.Group inline>
 | 
				
			||||||
 | 
					            <Form.Checkbox
 | 
				
			||||||
 | 
					              checked={inputs.LogConsumeEnabled === 'true'}
 | 
				
			||||||
 | 
					              label='启用额度消费日志记录'
 | 
				
			||||||
 | 
					              name='LogConsumeEnabled'
 | 
				
			||||||
 | 
					              onChange={handleInputChange}
 | 
				
			||||||
 | 
					            />
 | 
				
			||||||
 | 
					          </Form.Group>
 | 
				
			||||||
 | 
					          <Form.Group widths={4}>
 | 
				
			||||||
 | 
					            <Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
 | 
				
			||||||
 | 
					                        name='history_timestamp'
 | 
				
			||||||
 | 
					                        onChange={(e, { name, value }) => {
 | 
				
			||||||
 | 
					                          setHistoryTimestamp(value);
 | 
				
			||||||
 | 
					                        }} />
 | 
				
			||||||
 | 
					          </Form.Group>
 | 
				
			||||||
 | 
					          <Form.Button onClick={() => {
 | 
				
			||||||
 | 
					            deleteHistoryLogs().then();
 | 
				
			||||||
 | 
					          }}>清理历史日志</Form.Button>
 | 
				
			||||||
 | 
					          <Divider />
 | 
				
			||||||
          <Header as='h3'>
 | 
					          <Header as='h3'>
 | 
				
			||||||
            监控设置
 | 
					            监控设置
 | 
				
			||||||
          </Header>
 | 
					          </Header>
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,6 +4,7 @@ import { Link, useNavigate } from 'react-router-dom';
 | 
				
			|||||||
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
 | 
					import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
 | 
				
			||||||
import Turnstile from 'react-turnstile';
 | 
					import Turnstile from 'react-turnstile';
 | 
				
			||||||
import { UserContext } from '../context/User';
 | 
					import { UserContext } from '../context/User';
 | 
				
			||||||
 | 
					import { onGitHubOAuthClicked } from './utils';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const PersonalSetting = () => {
 | 
					const PersonalSetting = () => {
 | 
				
			||||||
  const [userState, userDispatch] = useContext(UserContext);
 | 
					  const [userState, userDispatch] = useContext(UserContext);
 | 
				
			||||||
@@ -130,12 +131,6 @@ const PersonalSetting = () => {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const openGitHubOAuth = () => {
 | 
					 | 
				
			||||||
    window.open(
 | 
					 | 
				
			||||||
      `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
 | 
					 | 
				
			||||||
    );
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  const sendVerificationCode = async () => {
 | 
					  const sendVerificationCode = async () => {
 | 
				
			||||||
    setDisableButton(true);
 | 
					    setDisableButton(true);
 | 
				
			||||||
    if (inputs.email === '') return;
 | 
					    if (inputs.email === '') return;
 | 
				
			||||||
@@ -249,7 +244,7 @@ const PersonalSetting = () => {
 | 
				
			|||||||
      </Modal>
 | 
					      </Modal>
 | 
				
			||||||
      {
 | 
					      {
 | 
				
			||||||
        status.github_oauth && (
 | 
					        status.github_oauth && (
 | 
				
			||||||
          <Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
 | 
					          <Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button>
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      <Button
 | 
					      <Button
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -96,7 +96,7 @@ const TokensTable = () => {
 | 
				
			|||||||
    let nextUrl;
 | 
					    let nextUrl;
 | 
				
			||||||
  
 | 
					  
 | 
				
			||||||
    if (nextLink) {
 | 
					    if (nextLink) {
 | 
				
			||||||
      nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
 | 
					      nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
 | 
					      nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										20
									
								
								web/src/components/utils.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								web/src/components/utils.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
				
			|||||||
 | 
					import { API, showError } from '../helpers';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export async function getOAuthState() {
 | 
				
			||||||
 | 
					  const res = await API.get('/api/oauth/state');
 | 
				
			||||||
 | 
					  const { success, message, data } = res.data;
 | 
				
			||||||
 | 
					  if (success) {
 | 
				
			||||||
 | 
					    return data;
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    showError(message);
 | 
				
			||||||
 | 
					    return '';
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export async function onGitHubOAuthClicked(github_client_id) {
 | 
				
			||||||
 | 
					  const state = await getOAuthState();
 | 
				
			||||||
 | 
					  if (!state) return;
 | 
				
			||||||
 | 
					  window.open(
 | 
				
			||||||
 | 
					    `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -7,7 +7,11 @@ export const CHANNEL_OPTIONS = [
 | 
				
			|||||||
  { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
 | 
					  { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
 | 
				
			||||||
  { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
 | 
					  { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
 | 
				
			||||||
  { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
 | 
					  { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
 | 
				
			||||||
 | 
					  { key: 19, text: '360 智脑', value: 19, color: 'blue' },
 | 
				
			||||||
  { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
 | 
					  { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
 | 
				
			||||||
 | 
					  { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
 | 
				
			||||||
 | 
					  { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
 | 
				
			||||||
 | 
					  { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
 | 
				
			||||||
  { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
 | 
					  { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
 | 
				
			||||||
  { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
 | 
					  { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
 | 
				
			||||||
  { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },
 | 
					  { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,6 @@
 | 
				
			|||||||
import React, { useEffect, useState } from 'react';
 | 
					import React, { useEffect, useState } from 'react';
 | 
				
			||||||
import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
 | 
					import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
 | 
				
			||||||
import { useParams, useNavigate } from 'react-router-dom';
 | 
					import { useNavigate, useParams } from 'react-router-dom';
 | 
				
			||||||
import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
 | 
					import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
 | 
				
			||||||
import { CHANNEL_OPTIONS } from '../../constants';
 | 
					import { CHANNEL_OPTIONS } from '../../constants';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = {
 | 
				
			|||||||
  'gpt-4-32k-0314': 'gpt-4-32k'
 | 
					  'gpt-4-32k-0314': 'gpt-4-32k'
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					function type2secretPrompt(type) {
 | 
				
			||||||
 | 
					  // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
 | 
				
			||||||
 | 
					  switch (type) {
 | 
				
			||||||
 | 
					    case 15:
 | 
				
			||||||
 | 
					      return '按照如下格式输入:APIKey|SecretKey';
 | 
				
			||||||
 | 
					    case 18:
 | 
				
			||||||
 | 
					      return '按照如下格式输入:APPID|APISecret|APIKey';
 | 
				
			||||||
 | 
					    case 22:
 | 
				
			||||||
 | 
					      return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
 | 
				
			||||||
 | 
					    default:
 | 
				
			||||||
 | 
					      return '请输入渠道对应的鉴权密钥';
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const EditChannel = () => {
 | 
					const EditChannel = () => {
 | 
				
			||||||
  const params = useParams();
 | 
					  const params = useParams();
 | 
				
			||||||
  const navigate = useNavigate();
 | 
					  const navigate = useNavigate();
 | 
				
			||||||
@@ -19,7 +33,7 @@ const EditChannel = () => {
 | 
				
			|||||||
  const handleCancel = () => {
 | 
					  const handleCancel = () => {
 | 
				
			||||||
    navigate('/channel');
 | 
					    navigate('/channel');
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
  
 | 
					
 | 
				
			||||||
  const originInputs = {
 | 
					  const originInputs = {
 | 
				
			||||||
    name: '',
 | 
					    name: '',
 | 
				
			||||||
    type: 1,
 | 
					    type: 1,
 | 
				
			||||||
@@ -53,7 +67,7 @@ const EditChannel = () => {
 | 
				
			|||||||
          localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
 | 
					          localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
 | 
				
			||||||
          break;
 | 
					          break;
 | 
				
			||||||
        case 17:
 | 
					        case 17:
 | 
				
			||||||
          localModels = ['qwen-v1', 'qwen-plus-v1'];
 | 
					          localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
 | 
				
			||||||
          break;
 | 
					          break;
 | 
				
			||||||
        case 16:
 | 
					        case 16:
 | 
				
			||||||
          localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
 | 
					          localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
 | 
				
			||||||
@@ -61,6 +75,9 @@ const EditChannel = () => {
 | 
				
			|||||||
        case 18:
 | 
					        case 18:
 | 
				
			||||||
          localModels = ['SparkDesk'];
 | 
					          localModels = ['SparkDesk'];
 | 
				
			||||||
          break;
 | 
					          break;
 | 
				
			||||||
 | 
					        case 19:
 | 
				
			||||||
 | 
					          localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'];
 | 
				
			||||||
 | 
					          break;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      setInputs((inputs) => ({ ...inputs, models: localModels }));
 | 
					      setInputs((inputs) => ({ ...inputs, models: localModels }));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -157,7 +174,7 @@ const EditChannel = () => {
 | 
				
			|||||||
      return;
 | 
					      return;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    let localInputs = inputs;
 | 
					    let localInputs = inputs;
 | 
				
			||||||
    if (localInputs.base_url.endsWith('/')) {
 | 
					    if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
 | 
				
			||||||
      localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
 | 
					      localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    if (localInputs.type === 3 && localInputs.other === '') {
 | 
					    if (localInputs.type === 3 && localInputs.other === '') {
 | 
				
			||||||
@@ -166,9 +183,6 @@ const EditChannel = () => {
 | 
				
			|||||||
    if (localInputs.type === 18 && localInputs.other === '') {
 | 
					    if (localInputs.type === 18 && localInputs.other === '') {
 | 
				
			||||||
      localInputs.other = 'v2.1';
 | 
					      localInputs.other = 'v2.1';
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    if (localInputs.model_mapping === '') {
 | 
					 | 
				
			||||||
      localInputs.model_mapping = '{}';
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    let res;
 | 
					    let res;
 | 
				
			||||||
    localInputs.models = localInputs.models.join(',');
 | 
					    localInputs.models = localInputs.models.join(',');
 | 
				
			||||||
    localInputs.group = localInputs.groups.join(',');
 | 
					    localInputs.group = localInputs.groups.join(',');
 | 
				
			||||||
@@ -190,6 +204,24 @@ const EditChannel = () => {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const addCustomModel = () => {
 | 
				
			||||||
 | 
					    if (customModel.trim() === '') return;
 | 
				
			||||||
 | 
					    if (inputs.models.includes(customModel)) return;
 | 
				
			||||||
 | 
					    let localModels = [...inputs.models];
 | 
				
			||||||
 | 
					    localModels.push(customModel);
 | 
				
			||||||
 | 
					    let localModelOptions = [];
 | 
				
			||||||
 | 
					    localModelOptions.push({
 | 
				
			||||||
 | 
					      key: customModel,
 | 
				
			||||||
 | 
					      text: customModel,
 | 
				
			||||||
 | 
					      value: customModel
 | 
				
			||||||
 | 
					    });
 | 
				
			||||||
 | 
					    setModelOptions(modelOptions => {
 | 
				
			||||||
 | 
					      return [...modelOptions, ...localModelOptions];
 | 
				
			||||||
 | 
					    });
 | 
				
			||||||
 | 
					    setCustomModel('');
 | 
				
			||||||
 | 
					    handleInputChange(null, { name: 'models', value: localModels });
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return (
 | 
					  return (
 | 
				
			||||||
    <>
 | 
					    <>
 | 
				
			||||||
      <Segment loading={loading}>
 | 
					      <Segment loading={loading}>
 | 
				
			||||||
@@ -292,6 +324,20 @@ const EditChannel = () => {
 | 
				
			|||||||
              </Form.Field>
 | 
					              </Form.Field>
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
 | 
					          {
 | 
				
			||||||
 | 
					            inputs.type === 21 && (
 | 
				
			||||||
 | 
					              <Form.Field>
 | 
				
			||||||
 | 
					                <Form.Input
 | 
				
			||||||
 | 
					                  label='知识库 ID'
 | 
				
			||||||
 | 
					                  name='other'
 | 
				
			||||||
 | 
					                  placeholder={'请输入知识库 ID,例如:123456'}
 | 
				
			||||||
 | 
					                  onChange={handleInputChange}
 | 
				
			||||||
 | 
					                  value={inputs.other}
 | 
				
			||||||
 | 
					                  autoComplete='new-password'
 | 
				
			||||||
 | 
					                />
 | 
				
			||||||
 | 
					              </Form.Field>
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
          <Form.Field>
 | 
					          <Form.Field>
 | 
				
			||||||
            <Form.Dropdown
 | 
					            <Form.Dropdown
 | 
				
			||||||
              label='模型'
 | 
					              label='模型'
 | 
				
			||||||
@@ -319,29 +365,19 @@ const EditChannel = () => {
 | 
				
			|||||||
            }}>清除所有模型</Button>
 | 
					            }}>清除所有模型</Button>
 | 
				
			||||||
            <Input
 | 
					            <Input
 | 
				
			||||||
              action={
 | 
					              action={
 | 
				
			||||||
                <Button type={'button'} onClick={() => {
 | 
					                <Button type={'button'} onClick={addCustomModel}>填入</Button>
 | 
				
			||||||
                  if (customModel.trim() === '') return;
 | 
					 | 
				
			||||||
                  if (inputs.models.includes(customModel)) return;
 | 
					 | 
				
			||||||
                  let localModels = [...inputs.models];
 | 
					 | 
				
			||||||
                  localModels.push(customModel);
 | 
					 | 
				
			||||||
                  let localModelOptions = [];
 | 
					 | 
				
			||||||
                  localModelOptions.push({
 | 
					 | 
				
			||||||
                    key: customModel,
 | 
					 | 
				
			||||||
                    text: customModel,
 | 
					 | 
				
			||||||
                    value: customModel
 | 
					 | 
				
			||||||
                  });
 | 
					 | 
				
			||||||
                  setModelOptions(modelOptions => {
 | 
					 | 
				
			||||||
                    return [...modelOptions, ...localModelOptions];
 | 
					 | 
				
			||||||
                  });
 | 
					 | 
				
			||||||
                  setCustomModel('');
 | 
					 | 
				
			||||||
                  handleInputChange(null, { name: 'models', value: localModels });
 | 
					 | 
				
			||||||
                }}>填入</Button>
 | 
					 | 
				
			||||||
              }
 | 
					              }
 | 
				
			||||||
              placeholder='输入自定义模型名称'
 | 
					              placeholder='输入自定义模型名称'
 | 
				
			||||||
              value={customModel}
 | 
					              value={customModel}
 | 
				
			||||||
              onChange={(e, { value }) => {
 | 
					              onChange={(e, { value }) => {
 | 
				
			||||||
                setCustomModel(value);
 | 
					                setCustomModel(value);
 | 
				
			||||||
              }}
 | 
					              }}
 | 
				
			||||||
 | 
					              onKeyDown={(e) => {
 | 
				
			||||||
 | 
					                if (e.key === 'Enter') {
 | 
				
			||||||
 | 
					                  addCustomModel();
 | 
				
			||||||
 | 
					                  e.preventDefault();
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					              }}
 | 
				
			||||||
            />
 | 
					            />
 | 
				
			||||||
          </div>
 | 
					          </div>
 | 
				
			||||||
          <Form.Field>
 | 
					          <Form.Field>
 | 
				
			||||||
@@ -372,7 +408,7 @@ const EditChannel = () => {
 | 
				
			|||||||
                label='密钥'
 | 
					                label='密钥'
 | 
				
			||||||
                name='key'
 | 
					                name='key'
 | 
				
			||||||
                required
 | 
					                required
 | 
				
			||||||
                placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
 | 
					                placeholder={type2secretPrompt(inputs.type)}
 | 
				
			||||||
                onChange={handleInputChange}
 | 
					                onChange={handleInputChange}
 | 
				
			||||||
                value={inputs.key}
 | 
					                value={inputs.key}
 | 
				
			||||||
                autoComplete='new-password'
 | 
					                autoComplete='new-password'
 | 
				
			||||||
@@ -390,7 +426,7 @@ const EditChannel = () => {
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
          {
 | 
					          {
 | 
				
			||||||
            inputs.type !== 3 && inputs.type !== 8 && (
 | 
					            inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
 | 
				
			||||||
              <Form.Field>
 | 
					              <Form.Field>
 | 
				
			||||||
                <Form.Input
 | 
					                <Form.Input
 | 
				
			||||||
                  label='代理'
 | 
					                  label='代理'
 | 
				
			||||||
@@ -403,6 +439,20 @@ const EditChannel = () => {
 | 
				
			|||||||
              </Form.Field>
 | 
					              </Form.Field>
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
 | 
					          {
 | 
				
			||||||
 | 
					            inputs.type === 22 && (
 | 
				
			||||||
 | 
					              <Form.Field>
 | 
				
			||||||
 | 
					                <Form.Input
 | 
				
			||||||
 | 
					                  label='私有部署地址'
 | 
				
			||||||
 | 
					                  name='base_url'
 | 
				
			||||||
 | 
					                  placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
 | 
				
			||||||
 | 
					                  onChange={handleInputChange}
 | 
				
			||||||
 | 
					                  value={inputs.base_url}
 | 
				
			||||||
 | 
					                  autoComplete='new-password'
 | 
				
			||||||
 | 
					                />
 | 
				
			||||||
 | 
					              </Form.Field>
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
          <Button onClick={handleCancel}>取消</Button>
 | 
					          <Button onClick={handleCancel}>取消</Button>
 | 
				
			||||||
          <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
 | 
					          <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
 | 
				
			||||||
        </Form>
 | 
					        </Form>
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user