mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	Compare commits
	
		
			48 Commits
		
	
	
		
			v0.4.9-alp
			...
			v0.5.0-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					26c6719ea3 | ||
| 
						 | 
					c87e05bfc2 | ||
| 
						 | 
					e6938bd236 | ||
| 
						 | 
					8f721d67a5 | ||
| 
						 | 
					fcc1e2d568 | ||
| 
						 | 
					9a1db61675 | ||
| 
						 | 
					3c940113ab | ||
| 
						 | 
					0495b9a0d7 | ||
| 
						 | 
					12a0e7105e | ||
| 
						 | 
					e628b643cd | ||
| 
						 | 
					675847bf98 | ||
| 
						 | 
					2ff15baf66 | ||
| 
						 | 
					4139a7036f | ||
| 
						 | 
					02da0b51f8 | ||
| 
						 | 
					35cfebee12 | ||
| 
						 | 
					0e088f7c3e | ||
| 
						 | 
					f61d326721 | ||
| 
						 | 
					74b06b643a | ||
| 
						 | 
					ccf7709e23 | ||
| 
						 | 
					d592e2c8b8 | ||
| 
						 | 
					b520b54625 | ||
| 
						 | 
					81c5901123 | ||
| 
						 | 
					abc53cb208 | ||
| 
						 | 
					2b17bb8dd7 | ||
| 
						 | 
					ea73201b6f | ||
| 
						 | 
					6215d2e71c | ||
| 
						 | 
					d17bdc40a7 | ||
| 
						 | 
					280df27705 | ||
| 
						 | 
					991f5bf4ee | ||
| 
						 | 
					701aaba191 | ||
| 
						 | 
					3bab5b48bf | ||
| 
						 | 
					f3bccee3b5 | ||
| 
						 | 
					d84b0b0f5d | ||
| 
						 | 
					d383302e8a | ||
| 
						 | 
					04f40def2f | ||
| 
						 | 
					c48b7bc0f5 | ||
| 
						 | 
					b09daf5ec1 | ||
| 
						 | 
					c90c0ecef4 | ||
| 
						 | 
					1ab5fb7d2d | ||
| 
						 | 
					f769711c19 | ||
| 
						 | 
					edc5156693 | ||
| 
						 | 
					9ec6506c32 | ||
| 
						 | 
					f387cc5ead | ||
| 
						 | 
					569b68c43b | ||
| 
						 | 
					f0c40a6cd0 | ||
| 
						 | 
					0cea9e6a6f | ||
| 
						 | 
					b1b3651e84 | ||
| 
						 | 
					8f6bd51f58 | 
@@ -285,6 +285,10 @@ If the channel ID is not provided, load balancing will be used to distribute the
 | 
			
		||||
## Note
 | 
			
		||||
This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.
 | 
			
		||||
 | 
			
		||||
This project is open-sourced under the MIT license. One must somehow retain the copyright information of One API.
 | 
			
		||||
This project is released under the MIT license. Based on this, attribution and a link to this project must be included at the bottom of the page.
 | 
			
		||||
 | 
			
		||||
The same applies to derivative projects based on this project.
 | 
			
		||||
 | 
			
		||||
If you do not wish to include attribution, prior authorization must be obtained.
 | 
			
		||||
 | 
			
		||||
According to the MIT license, users should bear the risk and responsibility of using this project, and the developer of this open-source project is not responsible for this.
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										42
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										42
									
								
								README.md
									
									
									
									
									
								
							@@ -51,6 +51,8 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 | 
			
		||||
  <a href="https://iamazing.cn/page/reward">赞赏支持</a>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
> **Note**:本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
 | 
			
		||||
 | 
			
		||||
> **Note**:使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
 | 
			
		||||
 | 
			
		||||
> **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。
 | 
			
		||||
@@ -59,11 +61,16 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 | 
			
		||||
1. 支持多种 API 访问渠道:
 | 
			
		||||
   + [x] OpenAI 官方通道(支持配置镜像)
 | 
			
		||||
   + [x] **Azure OpenAI API**
 | 
			
		||||
   + [x] [Anthropic Claude 系列模型](https://anthropic.com)
 | 
			
		||||
   + [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
 | 
			
		||||
   + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
 | 
			
		||||
   + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
 | 
			
		||||
   + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
 | 
			
		||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
   + [x] [API2D](https://api2d.com/r/197971)
 | 
			
		||||
   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
 | 
			
		||||
   + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
 | 
			
		||||
   + [x] [CloseAI](https://console.closeai-asia.com/r/2412)
 | 
			
		||||
   + [x] 自定义渠道:例如各种未收录的第三方代理服务
 | 
			
		||||
2. 支持通过**负载均衡**的方式访问多个渠道。
 | 
			
		||||
3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
			
		||||
@@ -78,21 +85,26 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 | 
			
		||||
12. 支持以美元为单位显示额度。
 | 
			
		||||
13. 支持发布公告,设置充值链接,设置新用户初始额度。
 | 
			
		||||
14. 支持模型映射,重定向用户的请求模型。
 | 
			
		||||
15. 支持丰富的**自定义**设置,
 | 
			
		||||
15. 支持失败自动重试。
 | 
			
		||||
16. 支持绘图接口。
 | 
			
		||||
17. 支持丰富的**自定义**设置,
 | 
			
		||||
    1. 支持自定义系统名称,logo 以及页脚。
 | 
			
		||||
    2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
 | 
			
		||||
16. 支持通过系统访问令牌访问管理 API。
 | 
			
		||||
17. 支持 Cloudflare Turnstile 用户校验。
 | 
			
		||||
18. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
18. 支持通过系统访问令牌访问管理 API。
 | 
			
		||||
19. 支持 Cloudflare Turnstile 用户校验。
 | 
			
		||||
20. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
    + 邮箱登录注册以及通过邮箱进行密码重置。
 | 
			
		||||
    + [GitHub 开放授权](https://github.com/settings/applications/new)。
 | 
			
		||||
    + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
 | 
			
		||||
19. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
 | 
			
		||||
 | 
			
		||||
## 部署
 | 
			
		||||
### 基于 Docker 进行部署
 | 
			
		||||
部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api`
 | 
			
		||||
 | 
			
		||||
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
 | 
			
		||||
 | 
			
		||||
如果你的并发量较大,推荐设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
 | 
			
		||||
 | 
			
		||||
更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR`
 | 
			
		||||
 | 
			
		||||
`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
 | 
			
		||||
@@ -158,8 +170,8 @@ sudo service nginx restart
 | 
			
		||||
### 多机部署
 | 
			
		||||
1. 所有服务器 `SESSION_SECRET` 设置一样的值。
 | 
			
		||||
2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite,所有服务器连接同一个数据库。
 | 
			
		||||
3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`。
 | 
			
		||||
4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置。
 | 
			
		||||
3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。
 | 
			
		||||
4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis,无论主从。
 | 
			
		||||
5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。
 | 
			
		||||
6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟。
 | 
			
		||||
7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis,并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。
 | 
			
		||||
@@ -182,7 +194,7 @@ sudo service nginx restart
 | 
			
		||||
docker run --name chat-next-web -d -p 3001:3000 yidadaa/chatgpt-next-web
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
注意修改端口号和 `BASE_URL`。
 | 
			
		||||
注意修改端口号,之后在页面上设置接口地址(例如:https://openai.justsong.cn/ )和 API Key 即可。
 | 
			
		||||
 | 
			
		||||
#### ChatGPT Web
 | 
			
		||||
项目主页:https://github.com/Chanzhaoyu/chatgpt-web
 | 
			
		||||
@@ -231,6 +243,8 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 | 
			
		||||
 | 
			
		||||
等到系统启动后,使用 `root` 用户登录系统并做进一步的配置。
 | 
			
		||||
 | 
			
		||||
**Note**:如果你不知道某个配置项的含义,可以临时删掉值以看到进一步的提示文字。
 | 
			
		||||
 | 
			
		||||
## 使用方法
 | 
			
		||||
在`渠道`页面中添加你的 API Key,之后在`令牌`页面中新增访问令牌。
 | 
			
		||||
 | 
			
		||||
@@ -261,7 +275,10 @@ graph LR
 | 
			
		||||
   + 例子:`SESSION_SECRET=random_string`
 | 
			
		||||
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 8.0 版本。
 | 
			
		||||
   + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
 | 
			
		||||
4. `FRONTEND_BASE_URL`:设置之后将使用指定的前端地址,而非后端地址,仅限从服务器设置。
 | 
			
		||||
   + 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。
 | 
			
		||||
   + 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。
 | 
			
		||||
   + 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。
 | 
			
		||||
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
 | 
			
		||||
   + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
 | 
			
		||||
5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。
 | 
			
		||||
   + 例子:`SYNC_FREQUENCY=60`
 | 
			
		||||
@@ -308,13 +325,16 @@ https://openai.justsong.cn
 | 
			
		||||
5. ChatGPT Next Web 报错:`Failed to fetch`
 | 
			
		||||
   + 部署的时候不要设置 `BASE_URL`。
 | 
			
		||||
   + 检查你的接口地址和 API Key 有没有填对。
 | 
			
		||||
6. 报错:`当前分组负载已饱和,请稍后再试`
 | 
			
		||||
   + 上游通道 429 了。
 | 
			
		||||
 | 
			
		||||
## 相关项目
 | 
			
		||||
[FastGPT](https://github.com/c121914yu/FastGPT): 三分钟搭建 AI 知识库
 | 
			
		||||
 | 
			
		||||
## 注意
 | 
			
		||||
本项目为开源项目,请在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
 | 
			
		||||
 | 
			
		||||
本项目使用 MIT 协议进行开源,请以某种方式保留 One API 的版权信息。
 | 
			
		||||
本项目使用 MIT 协议进行开源,**在此基础上**,必须在页面底部保留署名以及指向本项目的链接。如果不想保留署名,必须首先获得授权。
 | 
			
		||||
 | 
			
		||||
同样适用于基于本项目的二开项目。
 | 
			
		||||
 | 
			
		||||
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。
 | 
			
		||||
@@ -1,13 +1,15 @@
 | 
			
		||||
#!/bin/bash
 | 
			
		||||
 | 
			
		||||
if [ $# -ne 3 ]; then
 | 
			
		||||
  echo "Usage: time_test.sh <domain> <key> <count>"
 | 
			
		||||
if [ $# -lt 3 ]; then
 | 
			
		||||
  echo "Usage: time_test.sh <domain> <key> <count> [<model>]"
 | 
			
		||||
  exit 1
 | 
			
		||||
fi
 | 
			
		||||
 | 
			
		||||
domain=$1
 | 
			
		||||
key=$2
 | 
			
		||||
count=$3
 | 
			
		||||
model=${4:-"gpt-3.5-turbo"} # 设置默认模型为 gpt-3.5-turbo
 | 
			
		||||
 | 
			
		||||
total_time=0
 | 
			
		||||
times=()
 | 
			
		||||
 | 
			
		||||
@@ -16,7 +18,7 @@ for ((i=1; i<=count; i++)); do
 | 
			
		||||
           https://"$domain"/v1/chat/completions \
 | 
			
		||||
           -H "Content-Type: application/json" \
 | 
			
		||||
           -H "Authorization: Bearer $key" \
 | 
			
		||||
           -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "gpt-3.5-turbo", "stream": false, "max_tokens": 1}')
 | 
			
		||||
           -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "'"$model"'", "stream": false, "max_tokens": 1}')
 | 
			
		||||
  http_code=$(echo "$result" | awk '{print $1}')
 | 
			
		||||
  time=$(echo "$result" | awk '{print $2}')
 | 
			
		||||
  echo "HTTP status code: $http_code, Time taken: $time"
 | 
			
		||||
 
 | 
			
		||||
@@ -67,6 +67,8 @@ var ChannelDisableThreshold = 5.0
 | 
			
		||||
var AutomaticDisableChannelEnabled = false
 | 
			
		||||
var QuotaRemindThreshold = 1000
 | 
			
		||||
var PreConsumedQuota = 500
 | 
			
		||||
var ApproximateTokenEnabled = false
 | 
			
		||||
var RetryTimes = 0
 | 
			
		||||
 | 
			
		||||
var RootUserEmail = ""
 | 
			
		||||
 | 
			
		||||
@@ -148,6 +150,10 @@ const (
 | 
			
		||||
	ChannelTypeAIProxy   = 10
 | 
			
		||||
	ChannelTypePaLM      = 11
 | 
			
		||||
	ChannelTypeAPI2GPT   = 12
 | 
			
		||||
	ChannelTypeAIGC2D    = 13
 | 
			
		||||
	ChannelTypeAnthropic = 14
 | 
			
		||||
	ChannelTypeBaidu     = 15
 | 
			
		||||
	ChannelTypeZhipu     = 16
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ChannelBaseURLs = []string{
 | 
			
		||||
@@ -155,7 +161,7 @@ var ChannelBaseURLs = []string{
 | 
			
		||||
	"https://api.openai.com",        // 1
 | 
			
		||||
	"https://oa.api2d.net",          // 2
 | 
			
		||||
	"",                              // 3
 | 
			
		||||
	"https://api.openai-proxy.org", // 4
 | 
			
		||||
	"https://api.closeai-proxy.xyz", // 4
 | 
			
		||||
	"https://api.openai-sb.com",     // 5
 | 
			
		||||
	"https://api.openaimax.com",     // 6
 | 
			
		||||
	"https://api.ohmygpt.com",       // 7
 | 
			
		||||
@@ -164,4 +170,8 @@ var ChannelBaseURLs = []string{
 | 
			
		||||
	"https://api.aiproxy.io",        // 10
 | 
			
		||||
	"",                              // 11
 | 
			
		||||
	"https://api.api2gpt.com",       // 12
 | 
			
		||||
	"https://api.aigc2d.com",        // 13
 | 
			
		||||
	"https://api.anthropic.com",     // 14
 | 
			
		||||
	"https://aip.baidubce.com",      // 15
 | 
			
		||||
	"https://open.bigmodel.cn",      // 16
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,9 +4,11 @@ import "encoding/json"
 | 
			
		||||
 | 
			
		||||
// ModelRatio
 | 
			
		||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
 | 
			
		||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
 | 
			
		||||
// https://openai.com/pricing
 | 
			
		||||
// TODO: when a new api is enabled, check the pricing here
 | 
			
		||||
// 1 === $0.002 / 1K tokens
 | 
			
		||||
// 1 === ¥0.014 / 1k tokens
 | 
			
		||||
var ModelRatio = map[string]float64{
 | 
			
		||||
	"gpt-4":                   15,
 | 
			
		||||
	"gpt-4-0314":              15,
 | 
			
		||||
@@ -35,6 +37,15 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"text-search-ada-doc-001": 10,
 | 
			
		||||
	"text-moderation-stable":  0.1,
 | 
			
		||||
	"text-moderation-latest":  0.1,
 | 
			
		||||
	"dall-e":                  8,
 | 
			
		||||
	"claude-instant-1":        0.75,
 | 
			
		||||
	"claude-2":                30,
 | 
			
		||||
	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens
 | 
			
		||||
	"PaLM-2":                  1,
 | 
			
		||||
	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens
 | 
			
		||||
	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ModelRatio2JSONString() string {
 | 
			
		||||
 
 | 
			
		||||
@@ -7,16 +7,19 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetSubscription(c *gin.Context) {
 | 
			
		||||
	var quota int
 | 
			
		||||
	var remainQuota int
 | 
			
		||||
	var usedQuota int
 | 
			
		||||
	var err error
 | 
			
		||||
	var token *model.Token
 | 
			
		||||
	if common.DisplayTokenStatEnabled {
 | 
			
		||||
		tokenId := c.GetInt("token_id")
 | 
			
		||||
		token, err = model.GetTokenById(tokenId)
 | 
			
		||||
		quota = token.RemainQuota
 | 
			
		||||
		remainQuota = token.RemainQuota
 | 
			
		||||
		usedQuota = token.UsedQuota
 | 
			
		||||
	} else {
 | 
			
		||||
		userId := c.GetInt("id")
 | 
			
		||||
		quota, err = model.GetUserQuota(userId)
 | 
			
		||||
		remainQuota, err = model.GetUserQuota(userId)
 | 
			
		||||
		usedQuota, err = model.GetUserUsedQuota(userId)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		openAIError := OpenAIError{
 | 
			
		||||
@@ -28,6 +31,7 @@ func GetSubscription(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	quota := remainQuota + usedQuota
 | 
			
		||||
	amount := float64(quota)
 | 
			
		||||
	if common.DisplayInCurrencyEnabled {
 | 
			
		||||
		amount /= common.QuotaPerUnit
 | 
			
		||||
 
 | 
			
		||||
@@ -32,6 +32,13 @@ type OpenAIUsageDailyCost struct {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAICreditGrants struct {
 | 
			
		||||
	Object         string  `json:"object"`
 | 
			
		||||
	TotalGranted   float64 `json:"total_granted"`
 | 
			
		||||
	TotalUsed      float64 `json:"total_used"`
 | 
			
		||||
	TotalAvailable float64 `json:"total_available"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIUsageResponse struct {
 | 
			
		||||
	Object string `json:"object"`
 | 
			
		||||
	//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
 | 
			
		||||
@@ -61,6 +68,14 @@ type API2GPTUsageResponse struct {
 | 
			
		||||
	TotalRemaining float64 `json:"total_remaining"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type APGC2DGPTUsageResponse struct {
 | 
			
		||||
	//Grants         interface{} `json:"grants"`
 | 
			
		||||
	Object         string  `json:"object"`
 | 
			
		||||
	TotalAvailable float64 `json:"total_available"`
 | 
			
		||||
	TotalGranted   float64 `json:"total_granted"`
 | 
			
		||||
	TotalUsed      float64 `json:"total_used"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAuthHeader get auth header
 | 
			
		||||
func GetAuthHeader(token string) http.Header {
 | 
			
		||||
	h := http.Header{}
 | 
			
		||||
@@ -92,6 +107,22 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
 | 
			
		||||
	return body, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
 | 
			
		||||
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	response := OpenAICreditGrants{}
 | 
			
		||||
	err = json.Unmarshal(body, &response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	channel.UpdateBalance(response.TotalAvailable)
 | 
			
		||||
	return response.TotalAvailable, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
 | 
			
		||||
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 | 
			
		||||
@@ -150,8 +181,26 @@ func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	return response.TotalRemaining, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
 | 
			
		||||
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	response := APGC2DGPTUsageResponse{}
 | 
			
		||||
	err = json.Unmarshal(body, &response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	channel.UpdateBalance(response.TotalAvailable)
 | 
			
		||||
	return response.TotalAvailable, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	if channel.BaseURL == "" {
 | 
			
		||||
		channel.BaseURL = baseURL
 | 
			
		||||
	}
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case common.ChannelTypeOpenAI:
 | 
			
		||||
		if channel.BaseURL != "" {
 | 
			
		||||
@@ -161,12 +210,16 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
		return 0, errors.New("尚未实现")
 | 
			
		||||
	case common.ChannelTypeCustom:
 | 
			
		||||
		baseURL = channel.BaseURL
 | 
			
		||||
	case common.ChannelTypeCloseAI:
 | 
			
		||||
		return updateChannelCloseAIBalance(channel)
 | 
			
		||||
	case common.ChannelTypeOpenAISB:
 | 
			
		||||
		return updateChannelOpenAISBBalance(channel)
 | 
			
		||||
	case common.ChannelTypeAIProxy:
 | 
			
		||||
		return updateChannelAIProxyBalance(channel)
 | 
			
		||||
	case common.ChannelTypeAPI2GPT:
 | 
			
		||||
		return updateChannelAPI2GPTBalance(channel)
 | 
			
		||||
	case common.ChannelTypeAIGC2D:
 | 
			
		||||
		return updateChannelAIGC2DBalance(channel)
 | 
			
		||||
	default:
 | 
			
		||||
		return 0, errors.New("尚未实现")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -14,7 +14,7 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func testChannel(channel *model.Channel, request ChatRequest) error {
 | 
			
		||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case common.ChannelTypeAzure:
 | 
			
		||||
		request.Model = "gpt-35-turbo"
 | 
			
		||||
@@ -33,11 +33,11 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
 | 
			
		||||
 | 
			
		||||
	jsonData, err := json.Marshal(request)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return err, nil
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return err, nil
 | 
			
		||||
	}
 | 
			
		||||
	if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
		req.Header.Set("api-key", channel.Key)
 | 
			
		||||
@@ -48,18 +48,18 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return err, nil
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	var response TextResponse
 | 
			
		||||
	err = json.NewDecoder(resp.Body).Decode(&response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return err, nil
 | 
			
		||||
	}
 | 
			
		||||
	if response.Usage.CompletionTokens == 0 {
 | 
			
		||||
		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
 | 
			
		||||
		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func buildTestRequest() *ChatRequest {
 | 
			
		||||
@@ -94,7 +94,7 @@ func TestChannel(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
	testRequest := buildTestRequest()
 | 
			
		||||
	tik := time.Now()
 | 
			
		||||
	err = testChannel(channel, *testRequest)
 | 
			
		||||
	err, _ = testChannel(channel, *testRequest)
 | 
			
		||||
	tok := time.Now()
 | 
			
		||||
	milliseconds := tok.Sub(tik).Milliseconds()
 | 
			
		||||
	go channel.UpdateResponseTime(milliseconds)
 | 
			
		||||
@@ -158,13 +158,14 @@ func testAllChannels(notify bool) error {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			tik := time.Now()
 | 
			
		||||
			err := testChannel(channel, *testRequest)
 | 
			
		||||
			err, openaiErr := testChannel(channel, *testRequest)
 | 
			
		||||
			tok := time.Now()
 | 
			
		||||
			milliseconds := tok.Sub(tik).Milliseconds()
 | 
			
		||||
			if err != nil || milliseconds > disableThreshold {
 | 
			
		||||
			if milliseconds > disableThreshold {
 | 
			
		||||
				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			if shouldDisableChannel(openaiErr) {
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			channel.UpdateResponseTime(milliseconds)
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -53,6 +54,15 @@ func init() {
 | 
			
		||||
	})
 | 
			
		||||
	// https://platform.openai.com/docs/models/model-endpoint-compatibility
 | 
			
		||||
	openAIModels = []OpenAIModels{
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "dall-e",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "dall-e",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-3.5-turbo",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -242,6 +252,78 @@ func init() {
 | 
			
		||||
			Root:       "code-davinci-edit-001",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "claude-instant-1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "anturopic",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "claude-instant-1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "claude-2",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "anturopic",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "claude-2",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "ERNIE-Bot",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "baidu",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "ERNIE-Bot",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "ERNIE-Bot-turbo",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "baidu",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "ERNIE-Bot-turbo",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "PaLM-2",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "google",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "PaLM-2",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "chatglm_pro",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "zhipu",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "chatglm_pro",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "chatglm_std",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "zhipu",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "chatglm_std",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "chatglm_lite",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "zhipu",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "chatglm_lite",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	openAIModelsMap = make(map[string]OpenAIModels)
 | 
			
		||||
	for _, model := range openAIModels {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										203
									
								
								controller/relay-baidu.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										203
									
								
								controller/relay-baidu.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,203 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
 | 
			
		||||
 | 
			
		||||
type BaiduTokenResponse struct {
 | 
			
		||||
	RefreshToken  string `json:"refresh_token"`
 | 
			
		||||
	ExpiresIn     int    `json:"expires_in"`
 | 
			
		||||
	SessionKey    string `json:"session_key"`
 | 
			
		||||
	AccessToken   string `json:"access_token"`
 | 
			
		||||
	Scope         string `json:"scope"`
 | 
			
		||||
	SessionSecret string `json:"session_secret"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaiduMessage struct {
 | 
			
		||||
	Role    string `json:"role"`
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaiduChatRequest struct {
 | 
			
		||||
	Messages []BaiduMessage `json:"messages"`
 | 
			
		||||
	Stream   bool           `json:"stream"`
 | 
			
		||||
	UserId   string         `json:"user_id,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaiduError struct {
 | 
			
		||||
	ErrorCode int    `json:"error_code"`
 | 
			
		||||
	ErrorMsg  string `json:"error_msg"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaiduChatResponse struct {
 | 
			
		||||
	Id               string `json:"id"`
 | 
			
		||||
	Object           string `json:"object"`
 | 
			
		||||
	Created          int64  `json:"created"`
 | 
			
		||||
	Result           string `json:"result"`
 | 
			
		||||
	IsTruncated      bool   `json:"is_truncated"`
 | 
			
		||||
	NeedClearHistory bool   `json:"need_clear_history"`
 | 
			
		||||
	Usage            Usage  `json:"usage"`
 | 
			
		||||
	BaiduError
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaiduChatStreamResponse struct {
 | 
			
		||||
	BaiduChatResponse
 | 
			
		||||
	SentenceId int  `json:"sentence_id"`
 | 
			
		||||
	IsEnd      bool `json:"is_end"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 | 
			
		||||
	messages := make([]BaiduMessage, 0, len(request.Messages))
 | 
			
		||||
	for _, message := range request.Messages {
 | 
			
		||||
		messages = append(messages, BaiduMessage{
 | 
			
		||||
			Role:    message.Role,
 | 
			
		||||
			Content: message.Content,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return &BaiduChatRequest{
 | 
			
		||||
		Messages: messages,
 | 
			
		||||
		Stream:   request.Stream,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
 | 
			
		||||
	choice := OpenAITextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
		Message: Message{
 | 
			
		||||
			Role:    "assistant",
 | 
			
		||||
			Content: response.Result,
 | 
			
		||||
		},
 | 
			
		||||
		FinishReason: "stop",
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := OpenAITextResponse{
 | 
			
		||||
		Id:      response.Id,
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
		Created: response.Created,
 | 
			
		||||
		Choices: []OpenAITextResponseChoice{choice},
 | 
			
		||||
		Usage:   response.Usage,
 | 
			
		||||
	}
 | 
			
		||||
	return &fullTextResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
 | 
			
		||||
	var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = baiduResponse.Result
 | 
			
		||||
	choice.FinishReason = "stop"
 | 
			
		||||
	response := ChatCompletionsStreamResponse{
 | 
			
		||||
		Id:      baiduResponse.Id,
 | 
			
		||||
		Object:  "chat.completion.chunk",
 | 
			
		||||
		Created: baiduResponse.Created,
 | 
			
		||||
		Model:   "ernie-bot",
 | 
			
		||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
	return &response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 6 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[6:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var baiduResponse BaiduChatStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &baiduResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			usage.PromptTokens += baiduResponse.Usage.PromptTokens
 | 
			
		||||
			usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
 | 
			
		||||
			usage.TotalTokens += baiduResponse.Usage.TotalTokens
 | 
			
		||||
			response := streamResponseBaidu2OpenAI(&baiduResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var baiduResponse BaiduChatResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &baiduResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if baiduResponse.ErrorMsg != "" {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: baiduResponse.ErrorMsg,
 | 
			
		||||
				Type:    "baidu_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
				Code:    baiduResponse.ErrorCode,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										221
									
								
								controller/relay-claude.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										221
									
								
								controller/relay-claude.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,221 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ClaudeMetadata struct {
 | 
			
		||||
	UserId string `json:"user_id"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ClaudeRequest struct {
 | 
			
		||||
	Model             string   `json:"model"`
 | 
			
		||||
	Prompt            string   `json:"prompt"`
 | 
			
		||||
	MaxTokensToSample int      `json:"max_tokens_to_sample"`
 | 
			
		||||
	StopSequences     []string `json:"stop_sequences,omitempty"`
 | 
			
		||||
	Temperature       float64  `json:"temperature,omitempty"`
 | 
			
		||||
	TopP              float64  `json:"top_p,omitempty"`
 | 
			
		||||
	TopK              int      `json:"top_k,omitempty"`
 | 
			
		||||
	//ClaudeMetadata    `json:"metadata,omitempty"`
 | 
			
		||||
	Stream bool `json:"stream,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ClaudeError struct {
 | 
			
		||||
	Type    string `json:"type"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ClaudeResponse struct {
 | 
			
		||||
	Completion string      `json:"completion"`
 | 
			
		||||
	StopReason string      `json:"stop_reason"`
 | 
			
		||||
	Model      string      `json:"model"`
 | 
			
		||||
	Error      ClaudeError `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func stopReasonClaude2OpenAI(reason string) string {
 | 
			
		||||
	switch reason {
 | 
			
		||||
	case "stop_sequence":
 | 
			
		||||
		return "stop"
 | 
			
		||||
	case "max_tokens":
 | 
			
		||||
		return "length"
 | 
			
		||||
	default:
 | 
			
		||||
		return reason
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
 | 
			
		||||
	claudeRequest := ClaudeRequest{
 | 
			
		||||
		Model:             textRequest.Model,
 | 
			
		||||
		Prompt:            "",
 | 
			
		||||
		MaxTokensToSample: textRequest.MaxTokens,
 | 
			
		||||
		StopSequences:     nil,
 | 
			
		||||
		Temperature:       textRequest.Temperature,
 | 
			
		||||
		TopP:              textRequest.TopP,
 | 
			
		||||
		Stream:            textRequest.Stream,
 | 
			
		||||
	}
 | 
			
		||||
	if claudeRequest.MaxTokensToSample == 0 {
 | 
			
		||||
		claudeRequest.MaxTokensToSample = 1000000
 | 
			
		||||
	}
 | 
			
		||||
	prompt := ""
 | 
			
		||||
	for _, message := range textRequest.Messages {
 | 
			
		||||
		if message.Role == "user" {
 | 
			
		||||
			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
 | 
			
		||||
		} else if message.Role == "assistant" {
 | 
			
		||||
			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
 | 
			
		||||
		} else {
 | 
			
		||||
			// ignore other roles
 | 
			
		||||
		}
 | 
			
		||||
		prompt += "\n\nAssistant:"
 | 
			
		||||
	}
 | 
			
		||||
	claudeRequest.Prompt = prompt
 | 
			
		||||
	return &claudeRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
 | 
			
		||||
	var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = claudeResponse.Completion
 | 
			
		||||
	choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason)
 | 
			
		||||
	var response ChatCompletionsStreamResponse
 | 
			
		||||
	response.Object = "chat.completion.chunk"
 | 
			
		||||
	response.Model = claudeResponse.Model
 | 
			
		||||
	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
 | 
			
		||||
	return &response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
 | 
			
		||||
	choice := OpenAITextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
		Message: Message{
 | 
			
		||||
			Role:    "assistant",
 | 
			
		||||
			Content: strings.TrimPrefix(claudeResponse.Completion, " "),
 | 
			
		||||
			Name:    nil,
 | 
			
		||||
		},
 | 
			
		||||
		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := OpenAITextResponse{
 | 
			
		||||
		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Choices: []OpenAITextResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
	return &fullTextResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 | 
			
		||||
	createdTime := common.GetTimestamp()
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
 | 
			
		||||
			return i + 4, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if !strings.HasPrefix(data, "event: completion") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
			var claudeResponse ClaudeResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &claudeResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			responseText += claudeResponse.Completion
 | 
			
		||||
			response := streamResponseClaude2OpenAI(&claudeResponse)
 | 
			
		||||
			response.Id = responseId
 | 
			
		||||
			response.Created = createdTime
 | 
			
		||||
			jsonStr, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	var claudeResponse ClaudeResponse
 | 
			
		||||
	err = json.Unmarshal(responseBody, &claudeResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if claudeResponse.Error.Type != "" {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: claudeResponse.Error.Message,
 | 
			
		||||
				Type:    claudeResponse.Error.Type,
 | 
			
		||||
				Param:   "",
 | 
			
		||||
				Code:    claudeResponse.Error.Type,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseClaude2OpenAI(&claudeResponse)
 | 
			
		||||
	completionTokens := countTokenText(claudeResponse.Completion, model)
 | 
			
		||||
	usage := Usage{
 | 
			
		||||
		PromptTokens:     promptTokens,
 | 
			
		||||
		CompletionTokens: completionTokens,
 | 
			
		||||
		TotalTokens:      promptTokens + completionTokens,
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse.Usage = usage
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
@@ -1,34 +1,181 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	// TODO: this part is not finished
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, c.Request.RequestURI, c.Request.Body)
 | 
			
		||||
	imageModel := "dall-e"
 | 
			
		||||
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	consumeQuota := c.GetBool("consume_quota")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
 | 
			
		||||
	var imageRequest ImageRequest
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		err := common.UnmarshalBodyReusable(c, &imageRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Prompt validation
 | 
			
		||||
	if imageRequest.Prompt == "" {
 | 
			
		||||
		return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Not "256x256", "512x512", or "1024x1024"
 | 
			
		||||
	if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
 | 
			
		||||
		return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// N should between 1 and 10
 | 
			
		||||
	if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
 | 
			
		||||
		return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// map model name
 | 
			
		||||
	modelMapping := c.GetString("model_mapping")
 | 
			
		||||
	isModelMapped := false
 | 
			
		||||
	if modelMapping != "" {
 | 
			
		||||
		modelMap := make(map[string]string)
 | 
			
		||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		if modelMap[imageModel] != "" {
 | 
			
		||||
			imageModel = modelMap[imageModel]
 | 
			
		||||
			isModelMapped = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
 | 
			
		||||
	if c.GetString("base_url") != "" {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
			
		||||
 | 
			
		||||
	var requestBody io.Reader
 | 
			
		||||
	if isModelMapped {
 | 
			
		||||
		jsonStr, err := json.Marshal(imageRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	} else {
 | 
			
		||||
		requestBody = c.Request.Body
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	modelRatio := common.GetModelRatio(imageModel)
 | 
			
		||||
	groupRatio := common.GetGroupRatio(group)
 | 
			
		||||
	ratio := modelRatio * groupRatio
 | 
			
		||||
	userQuota, err := model.CacheGetUserQuota(userId)
 | 
			
		||||
 | 
			
		||||
	sizeRatio := 1.0
 | 
			
		||||
	// Size
 | 
			
		||||
	if imageRequest.Size == "256x256" {
 | 
			
		||||
		sizeRatio = 1
 | 
			
		||||
	} else if imageRequest.Size == "512x512" {
 | 
			
		||||
		sizeRatio = 1.125
 | 
			
		||||
	} else if imageRequest.Size == "1024x1024" {
 | 
			
		||||
		sizeRatio = 1.25
 | 
			
		||||
	}
 | 
			
		||||
	quota := int(ratio*sizeRatio*1000) * imageRequest.N
 | 
			
		||||
 | 
			
		||||
	if consumeQuota && userQuota-quota < 0 {
 | 
			
		||||
		return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
			
		||||
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "do_request_failed", http.StatusOK)
 | 
			
		||||
		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = req.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusOK)
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = c.Request.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	var textResponse ImageResponse
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
			err := model.PostConsumeTokenQuota(tokenId, quota)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			err = model.CacheUpdateUserQuota(userId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				tokenName := c.GetString("token_name")
 | 
			
		||||
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
				model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
				channelId := c.GetInt("channel_id")
 | 
			
		||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
 | 
			
		||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
 | 
			
		||||
		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										133
									
								
								controller/relay-openai.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								controller/relay-openai.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,133 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 6 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
			data = data[6:]
 | 
			
		||||
			if !strings.HasPrefix(data, "[DONE]") {
 | 
			
		||||
				switch relayMode {
 | 
			
		||||
				case RelayModeChatCompletions:
 | 
			
		||||
					var streamResponse ChatCompletionsStreamResponse
 | 
			
		||||
					err := json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
					for _, choice := range streamResponse.Choices {
 | 
			
		||||
						responseText += choice.Delta.Content
 | 
			
		||||
					}
 | 
			
		||||
				case RelayModeCompletions:
 | 
			
		||||
					var streamResponse CompletionsStreamResponse
 | 
			
		||||
					err := json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
					for _, choice := range streamResponse.Choices {
 | 
			
		||||
						responseText += choice.Text
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			if strings.HasPrefix(data, "data: [DONE]") {
 | 
			
		||||
				data = data[:12]
 | 
			
		||||
			}
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: data})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var textResponse TextResponse
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		}
 | 
			
		||||
		if textResponse.Error.Type != "" {
 | 
			
		||||
			return &OpenAIErrorWithStatusCode{
 | 
			
		||||
				OpenAIError: textResponse.Error,
 | 
			
		||||
				StatusCode:  resp.StatusCode,
 | 
			
		||||
			}, nil
 | 
			
		||||
		}
 | 
			
		||||
		// Reset response body
 | 
			
		||||
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
	}
 | 
			
		||||
	// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
			
		||||
	// And then we will have to send an error response, but in this case, the header has already been set.
 | 
			
		||||
	// So the client will be confused by the response.
 | 
			
		||||
	// For example, Postman will report error, and we cannot check the response at all.
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err := io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &textResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
@@ -1,10 +1,17 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 | 
			
		||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
 | 
			
		||||
 | 
			
		||||
type PaLMChatMessage struct {
 | 
			
		||||
	Author  string `json:"author"`
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
@@ -15,45 +22,188 @@ type PaLMFilter struct {
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 | 
			
		||||
type PaLMChatRequest struct {
 | 
			
		||||
	Prompt         []Message `json:"prompt"`
 | 
			
		||||
	Temperature    float64   `json:"temperature"`
 | 
			
		||||
	CandidateCount int       `json:"candidateCount"`
 | 
			
		||||
	TopP           float64   `json:"topP"`
 | 
			
		||||
	TopK           int       `json:"topK"`
 | 
			
		||||
type PaLMPrompt struct {
 | 
			
		||||
	Messages []PaLMChatMessage `json:"messages"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PaLMChatRequest struct {
 | 
			
		||||
	Prompt         PaLMPrompt `json:"prompt"`
 | 
			
		||||
	Temperature    float64    `json:"temperature,omitempty"`
 | 
			
		||||
	CandidateCount int        `json:"candidateCount,omitempty"`
 | 
			
		||||
	TopP           float64    `json:"topP,omitempty"`
 | 
			
		||||
	TopK           int        `json:"topK,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PaLMError struct {
 | 
			
		||||
	Code    int    `json:"code"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
	Status  string `json:"status"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
 | 
			
		||||
type PaLMChatResponse struct {
 | 
			
		||||
	Candidates []Message    `json:"candidates"`
 | 
			
		||||
	Candidates []PaLMChatMessage `json:"candidates"`
 | 
			
		||||
	Messages   []Message         `json:"messages"`
 | 
			
		||||
	Filters    []PaLMFilter      `json:"filters"`
 | 
			
		||||
	Error      PaLMError         `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
 | 
			
		||||
	messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages))
 | 
			
		||||
	for _, message := range openAIRequest.Messages {
 | 
			
		||||
		var author string
 | 
			
		||||
		if message.Role == "user" {
 | 
			
		||||
			author = "0"
 | 
			
		||||
		} else {
 | 
			
		||||
			author = "1"
 | 
			
		||||
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
 | 
			
		||||
	palmRequest := PaLMChatRequest{
 | 
			
		||||
		Prompt: PaLMPrompt{
 | 
			
		||||
			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
 | 
			
		||||
		},
 | 
			
		||||
		Temperature:    textRequest.Temperature,
 | 
			
		||||
		CandidateCount: textRequest.N,
 | 
			
		||||
		TopP:           textRequest.TopP,
 | 
			
		||||
		TopK:           textRequest.MaxTokens,
 | 
			
		||||
	}
 | 
			
		||||
		messages = append(messages, PaLMChatMessage{
 | 
			
		||||
			Author:  author,
 | 
			
		||||
	for _, message := range textRequest.Messages {
 | 
			
		||||
		palmMessage := PaLMChatMessage{
 | 
			
		||||
			Content: message.Content,
 | 
			
		||||
		}
 | 
			
		||||
		if message.Role == "user" {
 | 
			
		||||
			palmMessage.Author = "0"
 | 
			
		||||
		} else {
 | 
			
		||||
			palmMessage.Author = "1"
 | 
			
		||||
		}
 | 
			
		||||
		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
 | 
			
		||||
	}
 | 
			
		||||
	return &palmRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
 | 
			
		||||
	fullTextResponse := OpenAITextResponse{
 | 
			
		||||
		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
 | 
			
		||||
	}
 | 
			
		||||
	for i, candidate := range response.Candidates {
 | 
			
		||||
		choice := OpenAITextResponseChoice{
 | 
			
		||||
			Index: i,
 | 
			
		||||
			Message: Message{
 | 
			
		||||
				Role:    "assistant",
 | 
			
		||||
				Content: candidate.Content,
 | 
			
		||||
			},
 | 
			
		||||
			FinishReason: "stop",
 | 
			
		||||
		}
 | 
			
		||||
		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 | 
			
		||||
	}
 | 
			
		||||
	return &fullTextResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
 | 
			
		||||
	var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
	if len(palmResponse.Candidates) > 0 {
 | 
			
		||||
		choice.Delta.Content = palmResponse.Candidates[0].Content
 | 
			
		||||
	}
 | 
			
		||||
	choice.FinishReason = "stop"
 | 
			
		||||
	var response ChatCompletionsStreamResponse
 | 
			
		||||
	response.Object = "chat.completion.chunk"
 | 
			
		||||
	response.Model = "palm2"
 | 
			
		||||
	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
 | 
			
		||||
	return &response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 | 
			
		||||
	createdTime := common.GetTimestamp()
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("error reading stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("error closing stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		var palmResponse PaLMChatResponse
 | 
			
		||||
		err = json.Unmarshal(responseBody, &palmResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
 | 
			
		||||
		fullTextResponse.Id = responseId
 | 
			
		||||
		fullTextResponse.Created = createdTime
 | 
			
		||||
		if len(palmResponse.Candidates) > 0 {
 | 
			
		||||
			responseText = palmResponse.Candidates[0].Content
 | 
			
		||||
		}
 | 
			
		||||
		jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		dataChan <- string(jsonResponse)
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + data})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
	request := PaLMChatRequest{
 | 
			
		||||
		Prompt:         nil,
 | 
			
		||||
		Temperature:    openAIRequest.Temperature,
 | 
			
		||||
		CandidateCount: openAIRequest.N,
 | 
			
		||||
		TopP:           openAIRequest.TopP,
 | 
			
		||||
		TopK:           openAIRequest.MaxTokens,
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
	// TODO: forward request to PaLM & convert response
 | 
			
		||||
	fmt.Print(request)
 | 
			
		||||
	return nil
 | 
			
		||||
 | 
			
		||||
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	var palmResponse PaLMChatResponse
 | 
			
		||||
	err = json.Unmarshal(responseBody, &palmResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: palmResponse.Error.Message,
 | 
			
		||||
				Type:    palmResponse.Error.Status,
 | 
			
		||||
				Param:   "",
 | 
			
		||||
				Code:    palmResponse.Error.Code,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
 | 
			
		||||
	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
 | 
			
		||||
	usage := Usage{
 | 
			
		||||
		PromptTokens:     promptTokens,
 | 
			
		||||
		CompletionTokens: completionTokens,
 | 
			
		||||
		TotalTokens:      promptTokens + completionTokens,
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse.Usage = usage
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,17 +1,25 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	APITypeOpenAI = iota
 | 
			
		||||
	APITypeClaude
 | 
			
		||||
	APITypePaLM
 | 
			
		||||
	APITypeBaidu
 | 
			
		||||
	APITypeZhipu
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
@@ -30,6 +38,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	if relayMode == RelayModeModerations && textRequest.Model == "" {
 | 
			
		||||
		textRequest.Model = "text-moderation-latest"
 | 
			
		||||
	}
 | 
			
		||||
	if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
 | 
			
		||||
		textRequest.Model = c.Param("model")
 | 
			
		||||
	}
 | 
			
		||||
	// request validation
 | 
			
		||||
	if textRequest.Model == "" {
 | 
			
		||||
		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
 | 
			
		||||
@@ -67,12 +78,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			isModelMapped = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	apiType := APITypeOpenAI
 | 
			
		||||
	if strings.HasPrefix(textRequest.Model, "claude") {
 | 
			
		||||
		apiType = APITypeClaude
 | 
			
		||||
	} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
 | 
			
		||||
		apiType = APITypeBaidu
 | 
			
		||||
	} else if strings.HasPrefix(textRequest.Model, "PaLM") {
 | 
			
		||||
		apiType = APITypePaLM
 | 
			
		||||
	} else if strings.HasPrefix(textRequest.Model, "chatglm_") {
 | 
			
		||||
		apiType = APITypeZhipu
 | 
			
		||||
	}
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
	if c.GetString("base_url") != "" {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	}
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	case APITypeOpenAI:
 | 
			
		||||
		if channelType == common.ChannelTypeAzure {
 | 
			
		||||
			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 | 
			
		||||
			query := c.Request.URL.Query()
 | 
			
		||||
@@ -91,9 +114,35 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			model_ = strings.TrimSuffix(model_, "-0314")
 | 
			
		||||
			model_ = strings.TrimSuffix(model_, "-0613")
 | 
			
		||||
			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 | 
			
		||||
	} else if channelType == common.ChannelTypePaLM {
 | 
			
		||||
		err := relayPaLM(textRequest, c)
 | 
			
		||||
		return err
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeClaude:
 | 
			
		||||
		fullRequestURL = "https://api.anthropic.com/v1/complete"
 | 
			
		||||
		if baseURL != "" {
 | 
			
		||||
			fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeBaidu:
 | 
			
		||||
		switch textRequest.Model {
 | 
			
		||||
		case "ERNIE-Bot":
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
 | 
			
		||||
		case "ERNIE-Bot-turbo":
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
 | 
			
		||||
		case "BLOOMZ-7B":
 | 
			
		||||
			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
 | 
			
		||||
		}
 | 
			
		||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
		fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
 | 
			
		||||
	case APITypePaLM:
 | 
			
		||||
		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
 | 
			
		||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
		fullRequestURL += "?key=" + apiKey
 | 
			
		||||
	case APITypeZhipu:
 | 
			
		||||
		method := "invoke"
 | 
			
		||||
		if textRequest.Stream {
 | 
			
		||||
			method = "sse-invoke"
 | 
			
		||||
		}
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 | 
			
		||||
	}
 | 
			
		||||
	var promptTokens int
 | 
			
		||||
	var completionTokens int
 | 
			
		||||
@@ -138,20 +187,63 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	} else {
 | 
			
		||||
		requestBody = c.Request.Body
 | 
			
		||||
	}
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	case APITypeClaude:
 | 
			
		||||
		claudeRequest := requestOpenAI2Claude(textRequest)
 | 
			
		||||
		jsonStr, err := json.Marshal(claudeRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case APITypeBaidu:
 | 
			
		||||
		baiduRequest := requestOpenAI2Baidu(textRequest)
 | 
			
		||||
		jsonStr, err := json.Marshal(baiduRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case APITypePaLM:
 | 
			
		||||
		palmRequest := requestOpenAI2PaLM(textRequest)
 | 
			
		||||
		jsonStr, err := json.Marshal(palmRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case APITypeZhipu:
 | 
			
		||||
		zhipuRequest := requestOpenAI2Zhipu(textRequest)
 | 
			
		||||
		jsonStr, err := json.Marshal(zhipuRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	case APITypeOpenAI:
 | 
			
		||||
		if channelType == common.ChannelTypeAzure {
 | 
			
		||||
		key := c.Request.Header.Get("Authorization")
 | 
			
		||||
		key = strings.TrimPrefix(key, "Bearer ")
 | 
			
		||||
		req.Header.Set("api-key", key)
 | 
			
		||||
			req.Header.Set("api-key", apiKey)
 | 
			
		||||
		} else {
 | 
			
		||||
			req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeClaude:
 | 
			
		||||
		req.Header.Set("x-api-key", apiKey)
 | 
			
		||||
		anthropicVersion := c.Request.Header.Get("anthropic-version")
 | 
			
		||||
		if anthropicVersion == "" {
 | 
			
		||||
			anthropicVersion = "2023-06-01"
 | 
			
		||||
		}
 | 
			
		||||
		req.Header.Set("anthropic-version", anthropicVersion)
 | 
			
		||||
	case APITypeZhipu:
 | 
			
		||||
		token := getZhipuToken(apiKey)
 | 
			
		||||
		req.Header.Set("Authorization", token)
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
	req.Header.Set("Connection", c.Request.Header.Get("Connection"))
 | 
			
		||||
	//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -179,11 +271,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			if strings.HasPrefix(textRequest.Model, "gpt-4") {
 | 
			
		||||
				completionRatio = 2
 | 
			
		||||
			}
 | 
			
		||||
			if isStream {
 | 
			
		||||
			if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu {
 | 
			
		||||
				completionTokens = countTokenText(streamResponseText, textRequest.Model)
 | 
			
		||||
			} else {
 | 
			
		||||
				promptTokens = textResponse.Usage.PromptTokens
 | 
			
		||||
				completionTokens = textResponse.Usage.CompletionTokens
 | 
			
		||||
				if apiType == APITypeZhipu {
 | 
			
		||||
					// zhipu's API does not return prompt tokens & completion tokens
 | 
			
		||||
					promptTokens = textResponse.Usage.TotalTokens
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			quota = promptTokens + int(float64(completionTokens)*completionRatio)
 | 
			
		||||
			quota = int(float64(quota) * ratio)
 | 
			
		||||
@@ -215,123 +311,102 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	case APITypeOpenAI:
 | 
			
		||||
		if isStream {
 | 
			
		||||
		scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
		scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
			if atEOF && len(data) == 0 {
 | 
			
		||||
				return 0, nil, nil
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if i := strings.Index(string(data), "\n\n"); i >= 0 {
 | 
			
		||||
				return i + 2, data[0:i], nil
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if atEOF {
 | 
			
		||||
				return len(data), data, nil
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		})
 | 
			
		||||
		dataChan := make(chan string)
 | 
			
		||||
		stopChan := make(chan bool)
 | 
			
		||||
		go func() {
 | 
			
		||||
			for scanner.Scan() {
 | 
			
		||||
				data := scanner.Text()
 | 
			
		||||
				if len(data) < 6 { // must be something wrong!
 | 
			
		||||
					common.SysError("invalid stream response: " + data)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				dataChan <- data
 | 
			
		||||
				data = data[6:]
 | 
			
		||||
				if !strings.HasPrefix(data, "[DONE]") {
 | 
			
		||||
					switch relayMode {
 | 
			
		||||
					case RelayModeChatCompletions:
 | 
			
		||||
						var streamResponse ChatCompletionsStreamResponse
 | 
			
		||||
						err = json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
			err, responseText := openaiStreamHandler(c, resp, relayMode)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
							common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
							return
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
						for _, choice := range streamResponse.Choices {
 | 
			
		||||
							streamResponseText += choice.Delta.Content
 | 
			
		||||
						}
 | 
			
		||||
					case RelayModeCompletions:
 | 
			
		||||
						var streamResponse CompletionsStreamResponse
 | 
			
		||||
						err = json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
			streamResponseText = responseText
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := openaiHandler(c, resp, consumeQuota)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
							common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
							return
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
						for _, choice := range streamResponse.Choices {
 | 
			
		||||
							streamResponseText += choice.Text
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
		}()
 | 
			
		||||
		c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
		c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
		c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
		c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
		c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
		c.Stream(func(w io.Writer) bool {
 | 
			
		||||
			select {
 | 
			
		||||
			case data := <-dataChan:
 | 
			
		||||
				if strings.HasPrefix(data, "data: [DONE]") {
 | 
			
		||||
					data = data[:12]
 | 
			
		||||
				}
 | 
			
		||||
				c.Render(-1, common.CustomEvent{Data: data})
 | 
			
		||||
				return true
 | 
			
		||||
			case <-stopChan:
 | 
			
		||||
				return false
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
	case APITypeClaude:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			err, responseText := claudeStreamHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			streamResponseText = responseText
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeBaidu:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			err, usage := baiduStreamHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
			responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
			err, usage := baiduHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			err = resp.Body.Close()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
			}
 | 
			
		||||
			err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
			}
 | 
			
		||||
			if textResponse.Error.Type != "" {
 | 
			
		||||
				return &OpenAIErrorWithStatusCode{
 | 
			
		||||
					OpenAIError: textResponse.Error,
 | 
			
		||||
					StatusCode:  resp.StatusCode,
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			// Reset response body
 | 
			
		||||
			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
		}
 | 
			
		||||
		// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
			
		||||
		// And then we will have to send an error response, but in this case, the header has already been set.
 | 
			
		||||
		// So the client will be confused by the response.
 | 
			
		||||
		// For example, Postman will report error, and we cannot check the response at all.
 | 
			
		||||
		for k, v := range resp.Header {
 | 
			
		||||
			c.Writer.Header().Set(k, v[0])
 | 
			
		||||
		}
 | 
			
		||||
		c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
		_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	case APITypePaLM:
 | 
			
		||||
		if textRequest.Stream { // PaLM2 API does not support stream
 | 
			
		||||
			err, responseText := palmStreamHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			streamResponseText = responseText
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeZhipu:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			err, usage := zhipuStreamHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := zhipuHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,6 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/pkoukk/tiktoken-go"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 | 
			
		||||
@@ -25,6 +24,13 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
 | 
			
		||||
	return tokenEncoder
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 | 
			
		||||
	if common.ApproximateTokenEnabled {
 | 
			
		||||
		return int(float64(len(text)) * 0.38)
 | 
			
		||||
	}
 | 
			
		||||
	return len(tokenEncoder.Encode(text, nil, nil))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func countTokenMessages(messages []Message, model string) int {
 | 
			
		||||
	tokenEncoder := getTokenEncoder(model)
 | 
			
		||||
	// Reference:
 | 
			
		||||
@@ -34,12 +40,9 @@ func countTokenMessages(messages []Message, model string) int {
 | 
			
		||||
	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
 | 
			
		||||
	var tokensPerMessage int
 | 
			
		||||
	var tokensPerName int
 | 
			
		||||
	if strings.HasPrefix(model, "gpt-3.5") {
 | 
			
		||||
	if model == "gpt-3.5-turbo-0301" {
 | 
			
		||||
		tokensPerMessage = 4
 | 
			
		||||
		tokensPerName = -1 // If there's a name, the role is omitted
 | 
			
		||||
	} else if strings.HasPrefix(model, "gpt-4") {
 | 
			
		||||
		tokensPerMessage = 3
 | 
			
		||||
		tokensPerName = 1
 | 
			
		||||
	} else {
 | 
			
		||||
		tokensPerMessage = 3
 | 
			
		||||
		tokensPerName = 1
 | 
			
		||||
@@ -47,11 +50,11 @@ func countTokenMessages(messages []Message, model string) int {
 | 
			
		||||
	tokenNum := 0
 | 
			
		||||
	for _, message := range messages {
 | 
			
		||||
		tokenNum += tokensPerMessage
 | 
			
		||||
		tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
 | 
			
		||||
		tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
 | 
			
		||||
		tokenNum += getTokenNum(tokenEncoder, message.Content)
 | 
			
		||||
		tokenNum += getTokenNum(tokenEncoder, message.Role)
 | 
			
		||||
		if message.Name != nil {
 | 
			
		||||
			tokenNum += tokensPerName
 | 
			
		||||
			tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil))
 | 
			
		||||
			tokenNum += getTokenNum(tokenEncoder, *message.Name)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
 | 
			
		||||
@@ -74,8 +77,7 @@ func countTokenInput(input any, model string) int {
 | 
			
		||||
 | 
			
		||||
func countTokenText(text string, model string) int {
 | 
			
		||||
	tokenEncoder := getTokenEncoder(model)
 | 
			
		||||
	token := tokenEncoder.Encode(text, nil, nil)
 | 
			
		||||
	return len(token)
 | 
			
		||||
	return getTokenNum(tokenEncoder, text)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
@@ -89,3 +91,16 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus
 | 
			
		||||
		StatusCode:  statusCode,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func shouldDisableChannel(err *OpenAIError) bool {
 | 
			
		||||
	if !common.AutomaticDisableChannelEnabled {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										290
									
								
								controller/relay-zhipu.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										290
									
								
								controller/relay-zhipu.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,290 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/golang-jwt/jwt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://open.bigmodel.cn/doc/api#chatglm_std
 | 
			
		||||
// chatglm_std, chatglm_lite
 | 
			
		||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
 | 
			
		||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
 | 
			
		||||
 | 
			
		||||
type ZhipuMessage struct {
 | 
			
		||||
	Role    string `json:"role"`
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ZhipuRequest struct {
 | 
			
		||||
	Prompt      []ZhipuMessage `json:"prompt"`
 | 
			
		||||
	Temperature float64        `json:"temperature,omitempty"`
 | 
			
		||||
	TopP        float64        `json:"top_p,omitempty"`
 | 
			
		||||
	RequestId   string         `json:"request_id,omitempty"`
 | 
			
		||||
	Incremental bool           `json:"incremental,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ZhipuResponseData struct {
 | 
			
		||||
	TaskId     string         `json:"task_id"`
 | 
			
		||||
	RequestId  string         `json:"request_id"`
 | 
			
		||||
	TaskStatus string         `json:"task_status"`
 | 
			
		||||
	Choices    []ZhipuMessage `json:"choices"`
 | 
			
		||||
	Usage      `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ZhipuResponse struct {
 | 
			
		||||
	Code    int               `json:"code"`
 | 
			
		||||
	Msg     string            `json:"msg"`
 | 
			
		||||
	Success bool              `json:"success"`
 | 
			
		||||
	Data    ZhipuResponseData `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ZhipuStreamMetaResponse struct {
 | 
			
		||||
	RequestId  string `json:"request_id"`
 | 
			
		||||
	TaskId     string `json:"task_id"`
 | 
			
		||||
	TaskStatus string `json:"task_status"`
 | 
			
		||||
	Usage      `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type zhipuTokenData struct {
 | 
			
		||||
	Token      string
 | 
			
		||||
	ExpiryTime time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var zhipuTokens sync.Map
 | 
			
		||||
var expSeconds int64 = 24 * 3600
 | 
			
		||||
 | 
			
		||||
func getZhipuToken(apikey string) string {
 | 
			
		||||
	data, ok := zhipuTokens.Load(apikey)
 | 
			
		||||
	if ok {
 | 
			
		||||
		tokenData := data.(zhipuTokenData)
 | 
			
		||||
		if time.Now().Before(tokenData.ExpiryTime) {
 | 
			
		||||
			return tokenData.Token
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	split := strings.Split(apikey, ".")
 | 
			
		||||
	if len(split) != 2 {
 | 
			
		||||
		common.SysError("invalid zhipu key: " + apikey)
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	id := split[0]
 | 
			
		||||
	secret := split[1]
 | 
			
		||||
 | 
			
		||||
	expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
 | 
			
		||||
	expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
 | 
			
		||||
 | 
			
		||||
	timestamp := time.Now().UnixNano() / 1e6
 | 
			
		||||
 | 
			
		||||
	payload := jwt.MapClaims{
 | 
			
		||||
		"api_key":   id,
 | 
			
		||||
		"exp":       expMillis,
 | 
			
		||||
		"timestamp": timestamp,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
 | 
			
		||||
 | 
			
		||||
	token.Header["alg"] = "HS256"
 | 
			
		||||
	token.Header["sign_type"] = "SIGN"
 | 
			
		||||
 | 
			
		||||
	tokenString, err := token.SignedString([]byte(secret))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	zhipuTokens.Store(apikey, zhipuTokenData{
 | 
			
		||||
		Token:      tokenString,
 | 
			
		||||
		ExpiryTime: expiryTime,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return tokenString
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 | 
			
		||||
	messages := make([]ZhipuMessage, 0, len(request.Messages))
 | 
			
		||||
	for _, message := range request.Messages {
 | 
			
		||||
		messages = append(messages, ZhipuMessage{
 | 
			
		||||
			Role:    message.Role,
 | 
			
		||||
			Content: message.Content,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return &ZhipuRequest{
 | 
			
		||||
		Prompt:      messages,
 | 
			
		||||
		Temperature: request.Temperature,
 | 
			
		||||
		TopP:        request.TopP,
 | 
			
		||||
		Incremental: false,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
 | 
			
		||||
	fullTextResponse := OpenAITextResponse{
 | 
			
		||||
		Id:      response.Data.TaskId,
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
 | 
			
		||||
		Usage:   response.Data.Usage,
 | 
			
		||||
	}
 | 
			
		||||
	for i, choice := range response.Data.Choices {
 | 
			
		||||
		openaiChoice := OpenAITextResponseChoice{
 | 
			
		||||
			Index: i,
 | 
			
		||||
			Message: Message{
 | 
			
		||||
				Role:    choice.Role,
 | 
			
		||||
				Content: strings.Trim(choice.Content, "\""),
 | 
			
		||||
			},
 | 
			
		||||
			FinishReason: "",
 | 
			
		||||
		}
 | 
			
		||||
		if i == len(response.Data.Choices)-1 {
 | 
			
		||||
			openaiChoice.FinishReason = "stop"
 | 
			
		||||
		}
 | 
			
		||||
		fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
 | 
			
		||||
	}
 | 
			
		||||
	return &fullTextResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
 | 
			
		||||
	var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = zhipuResponse
 | 
			
		||||
	choice.FinishReason = ""
 | 
			
		||||
	response := ChatCompletionsStreamResponse{
 | 
			
		||||
		Object:  "chat.completion.chunk",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Model:   "chatglm",
 | 
			
		||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
	return &response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
 | 
			
		||||
	var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = ""
 | 
			
		||||
	choice.FinishReason = "stop"
 | 
			
		||||
	response := ChatCompletionsStreamResponse{
 | 
			
		||||
		Id:      zhipuResponse.RequestId,
 | 
			
		||||
		Object:  "chat.completion.chunk",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Model:   "chatglm",
 | 
			
		||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
	return &response, &zhipuResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var usage *Usage
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	metaChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			data = strings.Trim(data, "\"")
 | 
			
		||||
			if len(data) < 5 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:5] == "data:" {
 | 
			
		||||
				dataChan <- data[5:]
 | 
			
		||||
			} else if data[:5] == "meta:" {
 | 
			
		||||
				metaChan <- data[5:]
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			response := streamResponseZhipu2OpenAI(data)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case data := <-metaChan:
 | 
			
		||||
			var zhipuResponse ZhipuStreamMetaResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &zhipuResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			usage = zhipuUsage
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var zhipuResponse ZhipuResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &zhipuResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if !zhipuResponse.Success {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: zhipuResponse.Msg,
 | 
			
		||||
				Type:    "zhipu_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
				Code:    zhipuResponse.Code,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
@@ -2,10 +2,12 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
@@ -37,6 +39,7 @@ type GeneralOpenAIRequest struct {
 | 
			
		||||
	N           int       `json:"n,omitempty"`
 | 
			
		||||
	Input       any       `json:"input,omitempty"`
 | 
			
		||||
	Instruction string    `json:"instruction,omitempty"`
 | 
			
		||||
	Size        string    `json:"size,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatRequest struct {
 | 
			
		||||
@@ -53,6 +56,12 @@ type TextRequest struct {
 | 
			
		||||
	//Stream   bool      `json:"stream"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageRequest struct {
 | 
			
		||||
	Prompt string `json:"prompt"`
 | 
			
		||||
	N      int    `json:"n"`
 | 
			
		||||
	Size   string `json:"size"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Usage struct {
 | 
			
		||||
	PromptTokens     int `json:"prompt_tokens"`
 | 
			
		||||
	CompletionTokens int `json:"completion_tokens"`
 | 
			
		||||
@@ -76,13 +85,40 @@ type TextResponse struct {
 | 
			
		||||
	Error OpenAIError `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatCompletionsStreamResponse struct {
 | 
			
		||||
	Choices []struct {
 | 
			
		||||
type OpenAITextResponseChoice struct {
 | 
			
		||||
	Index        int `json:"index"`
 | 
			
		||||
	Message      `json:"message"`
 | 
			
		||||
	FinishReason string `json:"finish_reason"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAITextResponse struct {
 | 
			
		||||
	Id      string                     `json:"id"`
 | 
			
		||||
	Object  string                     `json:"object"`
 | 
			
		||||
	Created int64                      `json:"created"`
 | 
			
		||||
	Choices []OpenAITextResponseChoice `json:"choices"`
 | 
			
		||||
	Usage   `json:"usage"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ImageResponse struct {
 | 
			
		||||
	Created int `json:"created"`
 | 
			
		||||
	Data    []struct {
 | 
			
		||||
		Url string `json:"url"`
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatCompletionsStreamResponseChoice struct {
 | 
			
		||||
	Delta struct {
 | 
			
		||||
		Content string `json:"content"`
 | 
			
		||||
	} `json:"delta"`
 | 
			
		||||
		FinishReason string `json:"finish_reason"`
 | 
			
		||||
	} `json:"choices"`
 | 
			
		||||
	FinishReason string `json:"finish_reason,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatCompletionsStreamResponse struct {
 | 
			
		||||
	Id      string                                `json:"id"`
 | 
			
		||||
	Object  string                                `json:"object"`
 | 
			
		||||
	Created int64                                 `json:"created"`
 | 
			
		||||
	Model   string                                `json:"model"`
 | 
			
		||||
	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CompletionsStreamResponse struct {
 | 
			
		||||
@@ -100,6 +136,8 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		relayMode = RelayModeCompletions
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
 | 
			
		||||
		relayMode = RelayModeEmbeddings
 | 
			
		||||
	} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
 | 
			
		||||
		relayMode = RelayModeEmbeddings
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
			
		||||
		relayMode = RelayModeModerations
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 | 
			
		||||
@@ -115,16 +153,25 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		err = relayTextHelper(c, relayMode)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		retryTimesStr := c.Query("retry")
 | 
			
		||||
		retryTimes, _ := strconv.Atoi(retryTimesStr)
 | 
			
		||||
		if retryTimesStr == "" {
 | 
			
		||||
			retryTimes = common.RetryTimes
 | 
			
		||||
		}
 | 
			
		||||
		if retryTimes > 0 {
 | 
			
		||||
			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
 | 
			
		||||
		} else {
 | 
			
		||||
			if err.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
				err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
 | 
			
		||||
			}
 | 
			
		||||
			c.JSON(err.StatusCode, gin.H{
 | 
			
		||||
				"error": err.OpenAIError,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
			
		||||
		if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key") {
 | 
			
		||||
		if shouldDisableChannel(&err.OpenAIError) {
 | 
			
		||||
			channelId := c.GetInt("channel_id")
 | 
			
		||||
			channelName := c.GetString("channel_name")
 | 
			
		||||
			disableChannel(channelId, channelName, err.Message)
 | 
			
		||||
 
 | 
			
		||||
@@ -180,10 +180,10 @@ func UpdateToken(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if token.Status == common.TokenStatusEnabled {
 | 
			
		||||
		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() {
 | 
			
		||||
		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "令牌已过期,无法启用,请先修改令牌过期时间",
 | 
			
		||||
				"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										3
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								go.mod
									
									
									
									
									
								
							@@ -11,6 +11,7 @@ require (
 | 
			
		||||
	github.com/gin-gonic/gin v1.9.1
 | 
			
		||||
	github.com/go-playground/validator/v10 v10.14.0
 | 
			
		||||
	github.com/go-redis/redis/v8 v8.11.5
 | 
			
		||||
	github.com/golang-jwt/jwt v3.2.2+incompatible
 | 
			
		||||
	github.com/google/uuid v1.3.0
 | 
			
		||||
	github.com/pkoukk/tiktoken-go v0.1.1
 | 
			
		||||
	golang.org/x/crypto v0.9.0
 | 
			
		||||
@@ -20,7 +21,6 @@ require (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect
 | 
			
		||||
	github.com/bytedance/sonic v1.9.1 // indirect
 | 
			
		||||
	github.com/cespare/xxhash/v2 v2.1.2 // indirect
 | 
			
		||||
	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
 | 
			
		||||
@@ -32,7 +32,6 @@ require (
 | 
			
		||||
	github.com/go-playground/universal-translator v0.18.1 // indirect
 | 
			
		||||
	github.com/go-sql-driver/mysql v1.6.0 // indirect
 | 
			
		||||
	github.com/goccy/go-json v0.10.2 // indirect
 | 
			
		||||
	github.com/gomodule/redigo v2.0.0+incompatible // indirect
 | 
			
		||||
	github.com/gorilla/context v1.1.1 // indirect
 | 
			
		||||
	github.com/gorilla/securecookie v1.1.1 // indirect
 | 
			
		||||
	github.com/gorilla/sessions v1.2.1 // indirect
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										7
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								go.sum
									
									
									
									
									
								
							@@ -1,5 +1,3 @@
 | 
			
		||||
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff h1:RmdPFa+slIr4SCBg4st/l/vZWVe9QJKMXGO60Bxbe04=
 | 
			
		||||
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw=
 | 
			
		||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 | 
			
		||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
 | 
			
		||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
 | 
			
		||||
@@ -54,10 +52,10 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB
 | 
			
		||||
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
 | 
			
		||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
 | 
			
		||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
 | 
			
		||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
 | 
			
		||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
 | 
			
		||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
 | 
			
		||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 | 
			
		||||
github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0=
 | 
			
		||||
github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
 | 
			
		||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
 | 
			
		||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 | 
			
		||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 | 
			
		||||
@@ -67,7 +65,6 @@ github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8
 | 
			
		||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
 | 
			
		||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
 | 
			
		||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
 | 
			
		||||
github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
 | 
			
		||||
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
 | 
			
		||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
 | 
			
		||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										49
									
								
								i18n/en.json
									
									
									
									
									
								
							
							
						
						
									
										49
									
								
								i18n/en.json
									
									
									
									
									
								
							@@ -36,7 +36,7 @@
 | 
			
		||||
  "通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)": "Using model %s with token %s consumes %s (model rate %.2f, group rate %.2f)",
 | 
			
		||||
  "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。": "The current group load is saturated, please try again later, or upgrade your account to improve service quality.",
 | 
			
		||||
  "令牌名称长度必须在1-20之间": "The length of the token name must be between 1-20",
 | 
			
		||||
  "令牌已过期,无法启用,请先修改令牌过期时间": "The token has expired and cannot be enabled. Please modify the token expiration time first",
 | 
			
		||||
  "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期": "The token has expired and cannot be enabled. Please modify the expiration time of the token, or set it to never expire.",
 | 
			
		||||
  "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度": "The available quota of the token has been used up and cannot be enabled. Please modify the remaining quota of the token, or set it to unlimited quota",
 | 
			
		||||
  "管理员关闭了密码登录": "The administrator has turned off password login",
 | 
			
		||||
  "无法保存会话信息,请重试": "Unable to save session information, please try again",
 | 
			
		||||
@@ -107,6 +107,11 @@
 | 
			
		||||
  "已禁用": "Disabled",
 | 
			
		||||
  "未知状态": "Unknown status",
 | 
			
		||||
  " 秒": "s",
 | 
			
		||||
  " 分钟 ": " m ",
 | 
			
		||||
  " 小时 ": " h ",
 | 
			
		||||
  " 天 ": " d ",
 | 
			
		||||
  " 个月 ": " M ",
 | 
			
		||||
  " 年 ": " y ",
 | 
			
		||||
  "未测试": "Not tested",
 | 
			
		||||
  "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
 | 
			
		||||
  "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
 | 
			
		||||
@@ -458,5 +463,45 @@
 | 
			
		||||
  "消耗额度": "Used Quota",
 | 
			
		||||
  "可选值": "Optional Values",
 | 
			
		||||
  "渠道不存在:%d": "Channel does not exist: %d",
 | 
			
		||||
  "数据库一致性已被破坏,请联系管理员": "Database consistency has been broken, please contact the administrator"
 | 
			
		||||
  "数据库一致性已被破坏,请联系管理员": "Database consistency has been broken, please contact the administrator",
 | 
			
		||||
  "使用近似的方式估算 token 数以减少计算量": "Estimate the number of tokens in an approximate way to reduce computational load",
 | 
			
		||||
  "请填写ChannelName和ChannelKey!": "Please fill in the ChannelName and ChannelKey!",
 | 
			
		||||
  "请至少选择一个Model!": "Please select at least one Model!",
 | 
			
		||||
  "加载首页内容失败": "Failed to load the homepage content",
 | 
			
		||||
  "加载关于内容失败": "Failed to load the About content",
 | 
			
		||||
  "兑换码更新成功!": "Redemption code updated successfully!",
 | 
			
		||||
  "兑换码创建成功!": "Redemption code created successfully!",
 | 
			
		||||
  "用户账户创建成功!": "User account created successfully!",
 | 
			
		||||
  "生成数量": "Generate quantity",
 | 
			
		||||
  "请输入生成数量": "Please enter the quantity to generate",
 | 
			
		||||
  "创建新用户账户": "Create new user account",
 | 
			
		||||
  "渠道更新成功!": "Channel updated successfully!",
 | 
			
		||||
  "渠道创建成功!": "Channel created successfully!",
 | 
			
		||||
  "请选择分组": "Please select a group",
 | 
			
		||||
  "更新兑换码信息": "Update redemption code information",
 | 
			
		||||
  "创建新的兑换码": "Create a new redemption code",
 | 
			
		||||
  "请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit the group ratio in the system settings page to add a new group:",
 | 
			
		||||
  "未找到所请求的页面": "The requested page was not found",
 | 
			
		||||
  "过期时间格式错误!": "Expiration time format error!",
 | 
			
		||||
  "请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss,-1 表示无限制": "Please enter the expiration time, the format is yyyy-MM-dd HH:mm:ss, -1 means no limit",
 | 
			
		||||
  "此项可选,为一个 JSON 文本,键为用户请求的模型名称,值为要替换的模型名称,例如:": "This is optional, it's a JSON text, the key is the model name requested by the user, and the value is the model name to be replaced, for example:",
 | 
			
		||||
  "此项可选,输入镜像站地址,格式为:": "This is optional, enter the mirror site address, the format is:",
 | 
			
		||||
  "模型映射": "Model mapping",
 | 
			
		||||
  "请输入默认 API 版本,例如:2023-03-15-preview,该配置可以被实际的请求查询参数所覆盖": "Please enter the default API version, for example: 2023-03-15-preview, this configuration can be overridden by the actual request query parameters",
 | 
			
		||||
  "默认": "Default",
 | 
			
		||||
  "图片演示": "Image demo",
 | 
			
		||||
  "参数替换为你的部署名称(模型名称中的点会被剔除)": "Replace the parameter with your deployment name (dots in the model name will be removed)",
 | 
			
		||||
  "模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!",
 | 
			
		||||
  "取消无限额度": "Cancel unlimited quota",
 | 
			
		||||
  "请输入新的剩余额度": "Please enter the new remaining quota",
 | 
			
		||||
  "请输入单个兑换码中包含的额度": "Please enter the quota included in a single redemption code",
 | 
			
		||||
  "请输入用户名": "Please enter username",
 | 
			
		||||
  "请输入显示名称": "Please enter display name",
 | 
			
		||||
  "请输入密码": "Please enter password",
 | 
			
		||||
  "模型部署名称必须和模型名称保持一致": "The model deployment name must be consistent with the model name",
 | 
			
		||||
  ",因为 One API 会把请求体中的 model": ", because One API will take the model in the request body",
 | 
			
		||||
  "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT",
 | 
			
		||||
  "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
 | 
			
		||||
  "Homepage URL 填": "Fill in the Homepage URL",
 | 
			
		||||
  "Authorization callback URL 填": "Fill in the Authorization callback URL"
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,12 +2,13 @@ package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ModelRequest struct {
 | 
			
		||||
@@ -73,9 +74,19 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
					modelRequest.Model = "text-moderation-stable"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					modelRequest.Model = c.Param("model")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					modelRequest.Model = "dall-e"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				message := "无可用渠道"
 | 
			
		||||
				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
 | 
			
		||||
				if channel != nil {
 | 
			
		||||
					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 | 
			
		||||
					message = "数据库一致性已被破坏,请联系管理员"
 | 
			
		||||
 
 | 
			
		||||
@@ -37,7 +37,7 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchChannels(keyword string) (channels []*Channel, err error) {
 | 
			
		||||
	err = DB.Omit("key").Where("id = ? or name LIKE ? or key = ?", keyword, keyword+"%", keyword).Find(&channels).Error
 | 
			
		||||
	err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
 | 
			
		||||
	return channels, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -34,6 +34,7 @@ func InitOptionMap() {
 | 
			
		||||
	common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
 | 
			
		||||
	common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
 | 
			
		||||
	common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
 | 
			
		||||
	common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
 | 
			
		||||
	common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
 | 
			
		||||
	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
 | 
			
		||||
	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
 | 
			
		||||
@@ -67,6 +68,7 @@ func InitOptionMap() {
 | 
			
		||||
	common.OptionMap["TopUpLink"] = common.TopUpLink
 | 
			
		||||
	common.OptionMap["ChatLink"] = common.ChatLink
 | 
			
		||||
	common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
 | 
			
		||||
	common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
 | 
			
		||||
	common.OptionMapRWMutex.Unlock()
 | 
			
		||||
	loadOptionsFromDatabase()
 | 
			
		||||
}
 | 
			
		||||
@@ -141,6 +143,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
			common.RegisterEnabled = boolValue
 | 
			
		||||
		case "AutomaticDisableChannelEnabled":
 | 
			
		||||
			common.AutomaticDisableChannelEnabled = boolValue
 | 
			
		||||
		case "ApproximateTokenEnabled":
 | 
			
		||||
			common.ApproximateTokenEnabled = boolValue
 | 
			
		||||
		case "LogConsumeEnabled":
 | 
			
		||||
			common.LogConsumeEnabled = boolValue
 | 
			
		||||
		case "DisplayInCurrencyEnabled":
 | 
			
		||||
@@ -193,6 +197,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
		common.QuotaRemindThreshold, _ = strconv.Atoi(value)
 | 
			
		||||
	case "PreConsumedQuota":
 | 
			
		||||
		common.PreConsumedQuota, _ = strconv.Atoi(value)
 | 
			
		||||
	case "RetryTimes":
 | 
			
		||||
		common.RetryTimes, _ = strconv.Atoi(value)
 | 
			
		||||
	case "ModelRatio":
 | 
			
		||||
		err = common.UpdateModelRatioByJSONString(value)
 | 
			
		||||
	case "GroupRatio":
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package model
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -48,26 +49,27 @@ func Redeem(key string, userId int) (quota int, err error) {
 | 
			
		||||
		return 0, errors.New("无效的 user id")
 | 
			
		||||
	}
 | 
			
		||||
	redemption := &Redemption{}
 | 
			
		||||
	err = DB.Where("`key` = ?", key).First(redemption).Error
 | 
			
		||||
 | 
			
		||||
	err = DB.Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		err := DB.Where("`key` = ?", key).First(redemption).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
		return 0, errors.New("无效的兑换码")
 | 
			
		||||
			return errors.New("无效的兑换码")
 | 
			
		||||
		}
 | 
			
		||||
		if redemption.Status != common.RedemptionCodeStatusEnabled {
 | 
			
		||||
		return 0, errors.New("该兑换码已被使用")
 | 
			
		||||
			return errors.New("该兑换码已被使用")
 | 
			
		||||
		}
 | 
			
		||||
	err = IncreaseUserQuota(userId, redemption.Quota)
 | 
			
		||||
		err = DB.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	go func() {
 | 
			
		||||
		redemption.RedeemedTime = common.GetTimestamp()
 | 
			
		||||
		redemption.Status = common.RedemptionCodeStatusUsed
 | 
			
		||||
		err := redemption.SelectUpdate()
 | 
			
		||||
		return redemption.SelectUpdate()
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
			common.SysError("failed to update redemption status: " + err.Error())
 | 
			
		||||
		return 0, errors.New("兑换失败," + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
 | 
			
		||||
	}()
 | 
			
		||||
	return redemption.Quota, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,11 @@
 | 
			
		||||
package router
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-contrib/gzip"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/controller"
 | 
			
		||||
	"one-api/middleware"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-contrib/gzip"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
@@ -14,6 +15,10 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 | 
			
		||||
	SetDashboardRouter(router)
 | 
			
		||||
	SetRelayRouter(router)
 | 
			
		||||
	frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
 | 
			
		||||
	if common.IsMasterNode && frontendBaseUrl != "" {
 | 
			
		||||
		frontendBaseUrl = ""
 | 
			
		||||
		common.SysLog("FRONTEND_BASE_URL is ignored on master node")
 | 
			
		||||
	}
 | 
			
		||||
	if frontendBaseUrl == "" {
 | 
			
		||||
		SetWebRouter(router, buildFS, indexPage)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,10 @@
 | 
			
		||||
package router
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/controller"
 | 
			
		||||
	"one-api/middleware"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
@@ -20,10 +21,11 @@ func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
		relayV1Router.POST("/completions", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/chat/completions", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/edits", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/images/generations", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/images/generations", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/embeddings", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.GET("/files", controller.RelayNotImplemented)
 | 
			
		||||
 
 | 
			
		||||
@@ -30,6 +30,9 @@ function renderType(type) {
 | 
			
		||||
function renderBalance(type, balance) {
 | 
			
		||||
  switch (type) {
 | 
			
		||||
    case 1: // OpenAI
 | 
			
		||||
      return <span>${balance.toFixed(2)}</span>;
 | 
			
		||||
    case 4: // CloseAI
 | 
			
		||||
      return <span>¥{balance.toFixed(2)}</span>;
 | 
			
		||||
    case 8: // 自定义
 | 
			
		||||
      return <span>${balance.toFixed(2)}</span>;
 | 
			
		||||
    case 5: // OpenAI-SB
 | 
			
		||||
@@ -38,6 +41,8 @@ function renderBalance(type, balance) {
 | 
			
		||||
      return <span>{renderNumber(balance)}</span>;
 | 
			
		||||
    case 12: // API2GPT
 | 
			
		||||
      return <span>¥{balance.toFixed(2)}</span>;
 | 
			
		||||
    case 13: // AIGC2D
 | 
			
		||||
      return <span>{renderNumber(balance)}</span>;
 | 
			
		||||
    default:
 | 
			
		||||
      return <span>不支持</span>;
 | 
			
		||||
  }
 | 
			
		||||
@@ -58,8 +63,8 @@ const ChannelsTable = () => {
 | 
			
		||||
      if (startIdx === 0) {
 | 
			
		||||
        setChannels(data);
 | 
			
		||||
      } else {
 | 
			
		||||
        let newChannels = channels;
 | 
			
		||||
        newChannels.push(...data);
 | 
			
		||||
        let newChannels = [...channels];
 | 
			
		||||
        newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
 | 
			
		||||
        setChannels(newChannels);
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
@@ -80,7 +85,7 @@ const ChannelsTable = () => {
 | 
			
		||||
 | 
			
		||||
  const refresh = async () => {
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
    await loadChannels(0);
 | 
			
		||||
    await loadChannels(activePage - 1);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,7 @@ import {
 | 
			
		||||
} from 'semantic-ui-react';
 | 
			
		||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
 | 
			
		||||
import { UserContext } from '../context/User';
 | 
			
		||||
import { API, getLogo, showError, showSuccess } from '../helpers';
 | 
			
		||||
import { API, getLogo, showError, showSuccess, showInfo } from '../helpers';
 | 
			
		||||
 | 
			
		||||
const LoginForm = () => {
 | 
			
		||||
  const [inputs, setInputs] = useState({
 | 
			
		||||
@@ -76,7 +76,7 @@ const LoginForm = () => {
 | 
			
		||||
  async function handleSubmit(e) {
 | 
			
		||||
    setSubmitted(true);
 | 
			
		||||
    if (username && password) {
 | 
			
		||||
      const res = await API.post('/api/user/login', {
 | 
			
		||||
      const res = await API.post(`/api/user/login`, {
 | 
			
		||||
        username,
 | 
			
		||||
        password,
 | 
			
		||||
      });
 | 
			
		||||
 
 | 
			
		||||
@@ -108,7 +108,7 @@ const LogsTable = () => {
 | 
			
		||||
        setLogs(data);
 | 
			
		||||
      } else {
 | 
			
		||||
        let newLogs = [...logs];
 | 
			
		||||
        newLogs.push(...data);
 | 
			
		||||
        newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
 | 
			
		||||
        setLogs(newLogs);
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
 
 | 
			
		||||
@@ -18,7 +18,9 @@ const OperationSetting = () => {
 | 
			
		||||
    ChannelDisableThreshold: 0,
 | 
			
		||||
    LogConsumeEnabled: '',
 | 
			
		||||
    DisplayInCurrencyEnabled: '',
 | 
			
		||||
    DisplayTokenStatEnabled: ''
 | 
			
		||||
    DisplayTokenStatEnabled: '',
 | 
			
		||||
    ApproximateTokenEnabled: '',
 | 
			
		||||
    RetryTimes: 0,
 | 
			
		||||
  });
 | 
			
		||||
  const [originInputs, setOriginInputs] = useState({});
 | 
			
		||||
  let [loading, setLoading] = useState(false);
 | 
			
		||||
@@ -121,6 +123,9 @@ const OperationSetting = () => {
 | 
			
		||||
        if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
 | 
			
		||||
          await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
 | 
			
		||||
        }
 | 
			
		||||
        if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
 | 
			
		||||
          await updateOption('RetryTimes', inputs.RetryTimes);
 | 
			
		||||
        }
 | 
			
		||||
        break;
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
@@ -132,7 +137,7 @@ const OperationSetting = () => {
 | 
			
		||||
          <Header as='h3'>
 | 
			
		||||
            通用设置
 | 
			
		||||
          </Header>
 | 
			
		||||
          <Form.Group widths={3}>
 | 
			
		||||
          <Form.Group widths={4}>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='充值链接'
 | 
			
		||||
              name='TopUpLink'
 | 
			
		||||
@@ -161,6 +166,17 @@ const OperationSetting = () => {
 | 
			
		||||
              step='0.01'
 | 
			
		||||
              placeholder='一单位货币能兑换的额度'
 | 
			
		||||
            />
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='失败重试次数'
 | 
			
		||||
              name='RetryTimes'
 | 
			
		||||
              type={'number'}
 | 
			
		||||
              step='1'
 | 
			
		||||
              min='0'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              value={inputs.RetryTimes}
 | 
			
		||||
              placeholder='失败重试次数'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Group inline>
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
@@ -181,6 +197,12 @@ const OperationSetting = () => {
 | 
			
		||||
              name='DisplayTokenStatEnabled'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.ApproximateTokenEnabled === 'true'}
 | 
			
		||||
              label='使用近似的方式估算 token 数以减少计算量'
 | 
			
		||||
              name='ApproximateTokenEnabled'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Button onClick={() => {
 | 
			
		||||
            submitConfig('general').then();
 | 
			
		||||
 
 | 
			
		||||
@@ -45,8 +45,8 @@ const TokensTable = () => {
 | 
			
		||||
      if (startIdx === 0) {
 | 
			
		||||
        setTokens(data);
 | 
			
		||||
      } else {
 | 
			
		||||
        let newTokens = tokens;
 | 
			
		||||
        newTokens.push(...data);
 | 
			
		||||
        let newTokens = [...tokens];
 | 
			
		||||
        newTokens.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
 | 
			
		||||
        setTokens(newTokens);
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
@@ -67,7 +67,7 @@ const TokensTable = () => {
 | 
			
		||||
 | 
			
		||||
  const refresh = async () => {
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
    await loadTokens(0);
 | 
			
		||||
    await loadTokens(activePage - 1);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
 
 | 
			
		||||
@@ -183,14 +183,6 @@ const UsersTable = () => {
 | 
			
		||||
            >
 | 
			
		||||
              分组
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
                sortUser('email');
 | 
			
		||||
              }}
 | 
			
		||||
            >
 | 
			
		||||
              邮箱地址
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
@@ -233,20 +225,20 @@ const UsersTable = () => {
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <Popup
 | 
			
		||||
                      content={user.email ? user.email : '未绑定邮箱地址'}
 | 
			
		||||
                      key={user.display_name}
 | 
			
		||||
                      key={user.username}
 | 
			
		||||
                      header={user.display_name ? user.display_name : user.username}
 | 
			
		||||
                      trigger={<span>{renderText(user.username, 10)}</span>}
 | 
			
		||||
                      hoverable
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderGroup(user.group)}</Table.Cell>
 | 
			
		||||
                  {/*<Table.Cell>*/}
 | 
			
		||||
                  {/*  {user.email ? <Popup hoverable content={user.email} trigger={<span>{renderText(user.email, 24)}</span>} /> : '无'}*/}
 | 
			
		||||
                  {/*</Table.Cell>*/}
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    {user.email ? <Popup hoverable content={user.email} trigger={<span>{renderText(user.email, 24)}</span>} /> : '无'}
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <Popup content='剩余额度' trigger={<Label>{renderQuota(user.quota)}</Label>} />
 | 
			
		||||
                    <Popup content='已用额度' trigger={<Label>{renderQuota(user.used_quota)}</Label>} />
 | 
			
		||||
                    <Popup content='请求次数' trigger={<Label>{renderNumber(user.request_count)}</Label>} />
 | 
			
		||||
                    <Popup content='剩余额度' trigger={<Label basic>{renderQuota(user.quota)}</Label>} />
 | 
			
		||||
                    <Popup content='已用额度' trigger={<Label basic>{renderQuota(user.used_quota)}</Label>} />
 | 
			
		||||
                    <Popup content='请求次数' trigger={<Label basic>{renderNumber(user.request_count)}</Label>} />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderRole(user.role)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderStatus(user.status)}</Table.Cell>
 | 
			
		||||
@@ -320,7 +312,7 @@ const UsersTable = () => {
 | 
			
		||||
 | 
			
		||||
        <Table.Footer>
 | 
			
		||||
          <Table.Row>
 | 
			
		||||
            <Table.HeaderCell colSpan='8'>
 | 
			
		||||
            <Table.HeaderCell colSpan='7'>
 | 
			
		||||
              <Button size='small' as={Link} to='/user/add' loading={loading}>
 | 
			
		||||
                添加新的用户
 | 
			
		||||
              </Button>
 | 
			
		||||
 
 | 
			
		||||
@@ -1,13 +1,18 @@
 | 
			
		||||
export const CHANNEL_OPTIONS = [
 | 
			
		||||
  { key: 1, text: 'OpenAI', value: 1, color: 'green' },
 | 
			
		||||
  { key: 8, text: '自定义', value: 8, color: 'pink' },
 | 
			
		||||
  { key: 3, text: 'Azure', value: 3, color: 'olive' },
 | 
			
		||||
  { key: 2, text: 'API2D', value: 2, color: 'blue' },
 | 
			
		||||
  { key: 4, text: 'CloseAI', value: 4, color: 'teal' },
 | 
			
		||||
  { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
 | 
			
		||||
  { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
 | 
			
		||||
  { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
 | 
			
		||||
  { key: 9, text: 'AI.LS', value: 9, color: 'yellow' },
 | 
			
		||||
  { key: 10, text: 'AI Proxy', value: 10, color: 'purple' },
 | 
			
		||||
  { key: 12, text: 'API2GPT', value: 12, color: 'blue' }
 | 
			
		||||
  { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
 | 
			
		||||
  { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
 | 
			
		||||
  { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
 | 
			
		||||
  { key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
 | 
			
		||||
  { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
 | 
			
		||||
  { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
 | 
			
		||||
  { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
 | 
			
		||||
  { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
 | 
			
		||||
  { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },
 | 
			
		||||
  { key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' },
 | 
			
		||||
  { key: 4, text: '代理:CloseAI', value: 4, color: 'teal' },
 | 
			
		||||
  { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' },
 | 
			
		||||
  { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' },
 | 
			
		||||
  { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' },
 | 
			
		||||
  { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' }
 | 
			
		||||
];
 | 
			
		||||
@@ -1,5 +1,5 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
 | 
			
		||||
import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
 | 
			
		||||
import { useParams } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
 | 
			
		||||
import { CHANNEL_OPTIONS } from '../../constants';
 | 
			
		||||
@@ -27,10 +27,12 @@ const EditChannel = () => {
 | 
			
		||||
  };
 | 
			
		||||
  const [batch, setBatch] = useState(false);
 | 
			
		||||
  const [inputs, setInputs] = useState(originInputs);
 | 
			
		||||
  const [originModelOptions, setOriginModelOptions] = useState([]);
 | 
			
		||||
  const [modelOptions, setModelOptions] = useState([]);
 | 
			
		||||
  const [groupOptions, setGroupOptions] = useState([]);
 | 
			
		||||
  const [basicModels, setBasicModels] = useState([]);
 | 
			
		||||
  const [fullModels, setFullModels] = useState([]);
 | 
			
		||||
  const [customModel, setCustomModel] = useState('');
 | 
			
		||||
  const handleInputChange = (e, { name, value }) => {
 | 
			
		||||
    setInputs((inputs) => ({ ...inputs, [name]: value }));
 | 
			
		||||
  };
 | 
			
		||||
@@ -62,13 +64,16 @@ const EditChannel = () => {
 | 
			
		||||
  const fetchModels = async () => {
 | 
			
		||||
    try {
 | 
			
		||||
      let res = await API.get(`/api/channel/models`);
 | 
			
		||||
      setModelOptions(res.data.data.map((model) => ({
 | 
			
		||||
      let localModelOptions = res.data.data.map((model) => ({
 | 
			
		||||
        key: model.id,
 | 
			
		||||
        text: model.id,
 | 
			
		||||
        value: model.id
 | 
			
		||||
      })));
 | 
			
		||||
      }));
 | 
			
		||||
      setOriginModelOptions(localModelOptions);
 | 
			
		||||
      setFullModels(res.data.data.map((model) => model.id));
 | 
			
		||||
      setBasicModels(res.data.data.filter((model) => !model.id.startsWith('gpt-4')).map((model) => model.id));
 | 
			
		||||
      setBasicModels(res.data.data.filter((model) => {
 | 
			
		||||
        return model.id.startsWith('gpt-3') || model.id.startsWith('text-');
 | 
			
		||||
      }).map((model) => model.id));
 | 
			
		||||
    } catch (error) {
 | 
			
		||||
      showError(error.message);
 | 
			
		||||
    }
 | 
			
		||||
@@ -87,6 +92,20 @@ const EditChannel = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    let localModelOptions = [...originModelOptions];
 | 
			
		||||
    inputs.models.forEach((model) => {
 | 
			
		||||
      if (!localModelOptions.find((option) => option.key === model)) {
 | 
			
		||||
        localModelOptions.push({
 | 
			
		||||
          key: model,
 | 
			
		||||
          text: model,
 | 
			
		||||
          value: model
 | 
			
		||||
        });
 | 
			
		||||
      }
 | 
			
		||||
    });
 | 
			
		||||
    setModelOptions(localModelOptions);
 | 
			
		||||
  }, [originModelOptions, inputs.models]);
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    if (isEdit) {
 | 
			
		||||
      loadChannel().then();
 | 
			
		||||
@@ -145,6 +164,7 @@ const EditChannel = () => {
 | 
			
		||||
            <Form.Select
 | 
			
		||||
              label='类型'
 | 
			
		||||
              name='type'
 | 
			
		||||
              required
 | 
			
		||||
              options={CHANNEL_OPTIONS}
 | 
			
		||||
              value={inputs.type}
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
@@ -201,7 +221,7 @@ const EditChannel = () => {
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='镜像'
 | 
			
		||||
                  name='base_url'
 | 
			
		||||
                  placeholder={'请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值'}
 | 
			
		||||
                  placeholder={'此项可选,输入镜像站地址,格式为:https://domain.com'}
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
                  value={inputs.base_url}
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
@@ -212,6 +232,7 @@ const EditChannel = () => {
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='名称'
 | 
			
		||||
              required
 | 
			
		||||
              name='name'
 | 
			
		||||
              placeholder={'请输入名称'}
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
@@ -224,6 +245,7 @@ const EditChannel = () => {
 | 
			
		||||
              label='分组'
 | 
			
		||||
              placeholder={'请选择分组'}
 | 
			
		||||
              name='groups'
 | 
			
		||||
              required
 | 
			
		||||
              fluid
 | 
			
		||||
              multiple
 | 
			
		||||
              selection
 | 
			
		||||
@@ -240,6 +262,7 @@ const EditChannel = () => {
 | 
			
		||||
              label='模型'
 | 
			
		||||
              placeholder={'请选择该通道所支持的模型'}
 | 
			
		||||
              name='models'
 | 
			
		||||
              required
 | 
			
		||||
              fluid
 | 
			
		||||
              multiple
 | 
			
		||||
              selection
 | 
			
		||||
@@ -259,11 +282,37 @@ const EditChannel = () => {
 | 
			
		||||
            <Button type={'button'} onClick={() => {
 | 
			
		||||
              handleInputChange(null, { name: 'models', value: [] });
 | 
			
		||||
            }}>清除所有模型</Button>
 | 
			
		||||
            <Input
 | 
			
		||||
              action={
 | 
			
		||||
                <Button type={'button'} onClick={()=>{
 | 
			
		||||
                  if (customModel.trim() === "") return;
 | 
			
		||||
                  if (inputs.models.includes(customModel)) return;
 | 
			
		||||
                  let localModels = [...inputs.models];
 | 
			
		||||
                  localModels.push(customModel);
 | 
			
		||||
                  let localModelOptions = [];
 | 
			
		||||
                  localModelOptions.push({
 | 
			
		||||
                    key: customModel,
 | 
			
		||||
                    text: customModel,
 | 
			
		||||
                    value: customModel,
 | 
			
		||||
                  });
 | 
			
		||||
                  setModelOptions(modelOptions=>{
 | 
			
		||||
                    return [...modelOptions, ...localModelOptions];
 | 
			
		||||
                  });
 | 
			
		||||
                  setCustomModel('');
 | 
			
		||||
                  handleInputChange(null, { name: 'models', value: localModels });
 | 
			
		||||
                }}>填入</Button>
 | 
			
		||||
              }
 | 
			
		||||
              placeholder='输入自定义模型名称'
 | 
			
		||||
              value={customModel}
 | 
			
		||||
              onChange={(e, { value }) => {
 | 
			
		||||
                setCustomModel(value);
 | 
			
		||||
              }}
 | 
			
		||||
            />
 | 
			
		||||
          </div>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.TextArea
 | 
			
		||||
              label='模型映射'
 | 
			
		||||
              placeholder={`为一个 JSON 文本,键为用户请求的模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
 | 
			
		||||
              placeholder={`此项可选,为一个 JSON 文本,键为用户请求的模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
 | 
			
		||||
              name='model_mapping'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              value={inputs.model_mapping}
 | 
			
		||||
@@ -276,6 +325,7 @@ const EditChannel = () => {
 | 
			
		||||
              <Form.TextArea
 | 
			
		||||
                label='密钥'
 | 
			
		||||
                name='key'
 | 
			
		||||
                required
 | 
			
		||||
                placeholder={'请输入密钥,一行一个'}
 | 
			
		||||
                onChange={handleInputChange}
 | 
			
		||||
                value={inputs.key}
 | 
			
		||||
@@ -286,7 +336,8 @@ const EditChannel = () => {
 | 
			
		||||
              <Form.Input
 | 
			
		||||
                label='密钥'
 | 
			
		||||
                name='key'
 | 
			
		||||
                placeholder={'请输入密钥'}
 | 
			
		||||
                required
 | 
			
		||||
                placeholder={inputs.type === 15 ? "请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次" : '请输入密钥'}
 | 
			
		||||
                onChange={handleInputChange}
 | 
			
		||||
                value={inputs.key}
 | 
			
		||||
                autoComplete='new-password'
 | 
			
		||||
@@ -303,7 +354,7 @@ const EditChannel = () => {
 | 
			
		||||
              />
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          <Button positive onClick={submit}>提交</Button>
 | 
			
		||||
          <Button type={isEdit ? "button" : "submit"} positive onClick={submit}>提交</Button>
 | 
			
		||||
        </Form>
 | 
			
		||||
      </Segment>
 | 
			
		||||
    </>
 | 
			
		||||
 
 | 
			
		||||
@@ -11,7 +11,7 @@ const EditToken = () => {
 | 
			
		||||
  const [loading, setLoading] = useState(isEdit);
 | 
			
		||||
  const originInputs = {
 | 
			
		||||
    name: '',
 | 
			
		||||
    remain_quota: 0,
 | 
			
		||||
    remain_quota: isEdit ? 0 : 500000,
 | 
			
		||||
    expired_time: -1,
 | 
			
		||||
    unlimited_quota: false
 | 
			
		||||
  };
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user