mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	Compare commits
	
		
			26 Commits
		
	
	
		
			v0.3.0-alp
			...
			v0.3.2
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					3711f4a741 | ||
| 
						 | 
					7c6bf3e97b | ||
| 
						 | 
					481ba41fbd | ||
| 
						 | 
					2779d6629c | ||
| 
						 | 
					e509899daf | ||
| 
						 | 
					b53cdbaf05 | ||
| 
						 | 
					ced89398a5 | ||
| 
						 | 
					09c2e3bcec | ||
| 
						 | 
					5cba800fa6 | ||
| 
						 | 
					2d39a135f2 | ||
| 
						 | 
					3c6834a79c | ||
| 
						 | 
					6da3410823 | ||
| 
						 | 
					ceb289cb4d | ||
| 
						 | 
					6f8cc712b0 | ||
| 
						 | 
					ad01e1f3b3 | ||
| 
						 | 
					cc1ef2ffd5 | ||
| 
						 | 
					7201bd1c97 | ||
| 
						 | 
					73d5e0f283 | ||
| 
						 | 
					efc744ca35 | ||
| 
						 | 
					e8da98139f | ||
| 
						 | 
					519cb030f7 | ||
| 
						 | 
					58fe923c85 | ||
| 
						 | 
					c9ac5e391f | ||
| 
						 | 
					69cf1de7bd | ||
| 
						 | 
					4d6172a242 | ||
| 
						 | 
					8afdc56b11 | 
							
								
								
									
										48
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										48
									
								
								README.md
									
									
									
									
									
								
							@@ -38,6 +38,8 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 | 
				
			|||||||
  <a href="https://github.com/songquanpeng/one-api#截图展示">截图展示</a>
 | 
					  <a href="https://github.com/songquanpeng/one-api#截图展示">截图展示</a>
 | 
				
			||||||
  ·
 | 
					  ·
 | 
				
			||||||
  <a href="https://openai.justsong.cn/">在线演示</a>
 | 
					  <a href="https://openai.justsong.cn/">在线演示</a>
 | 
				
			||||||
 | 
					  ·
 | 
				
			||||||
 | 
					  <a href="https://github.com/songquanpeng/one-api#常见问题">常见问题</a>
 | 
				
			||||||
</p>
 | 
					</p>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
> **Warning**:从 `v0.2` 版本升级到 `v0.3` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.2-v0.3.sql)。
 | 
					> **Warning**:从 `v0.2` 版本升级到 `v0.3` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.2-v0.3.sql)。
 | 
				
			||||||
@@ -48,26 +50,28 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 | 
				
			|||||||
   + [x] OpenAI 官方通道
 | 
					   + [x] OpenAI 官方通道
 | 
				
			||||||
   + [x] **Azure OpenAI API**
 | 
					   + [x] **Azure OpenAI API**
 | 
				
			||||||
   + [x] [API2D](https://api2d.com/r/197971)
 | 
					   + [x] [API2D](https://api2d.com/r/197971)
 | 
				
			||||||
   + [x] [CloseAI](https://console.openai-asia.com)
 | 
					   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
 | 
				
			||||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
					   + [x] [AI.LS](https://ai.ls)
 | 
				
			||||||
   + [x] [OpenAI Max](https://openaimax.com)
 | 
					   + [x] [OpenAI Max](https://openaimax.com)
 | 
				
			||||||
   + [x] [OhMyGPT](https://www.ohmygpt.com)
 | 
					   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
				
			||||||
 | 
					   + [x] [CloseAI](https://console.openai-asia.com)
 | 
				
			||||||
   + [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理
 | 
					   + [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理
 | 
				
			||||||
2. 支持通过**负载均衡**的方式访问多个渠道。
 | 
					2. 支持通过**负载均衡**的方式访问多个渠道。
 | 
				
			||||||
3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
					3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
				
			||||||
4. 支持**令牌管理**,设置令牌的过期时间和使用次数。
 | 
					4. 支持**多机部署**,[详见此处](#多机部署)。
 | 
				
			||||||
5. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为令牌进行充值。
 | 
					5. 支持**令牌管理**,设置令牌的过期时间和使用次数。
 | 
				
			||||||
6. 支持**通道管理**,批量创建通道。
 | 
					6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为令牌进行充值。
 | 
				
			||||||
7. 支持发布公告,设置充值链接,设置新用户初始额度。
 | 
					7. 支持**通道管理**,批量创建通道。
 | 
				
			||||||
8. 支持丰富的**自定义**设置,
 | 
					8. 支持发布公告,设置充值链接,设置新用户初始额度。
 | 
				
			||||||
 | 
					9. 支持丰富的**自定义**设置,
 | 
				
			||||||
   1. 支持自定义系统名称,logo 以及页脚。
 | 
					   1. 支持自定义系统名称,logo 以及页脚。
 | 
				
			||||||
   2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
 | 
					   2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
 | 
				
			||||||
9. 支持通过系统访问令牌访问管理 API。
 | 
					10. 支持通过系统访问令牌访问管理 API。
 | 
				
			||||||
10. 支持用户管理,支持**多种用户登录注册方式**:
 | 
					11. 支持用户管理,支持**多种用户登录注册方式**:
 | 
				
			||||||
    + 邮箱登录注册以及通过邮箱进行密码重置。
 | 
					    + 邮箱登录注册以及通过邮箱进行密码重置。
 | 
				
			||||||
    + [GitHub 开放授权](https://github.com/settings/applications/new)。
 | 
					    + [GitHub 开放授权](https://github.com/settings/applications/new)。
 | 
				
			||||||
    + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
 | 
					    + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
 | 
				
			||||||
11. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
 | 
					12. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## 部署
 | 
					## 部署
 | 
				
			||||||
### 基于 Docker 进行部署
 | 
					### 基于 Docker 进行部署
 | 
				
			||||||
@@ -90,13 +94,10 @@ server{
 | 
				
			|||||||
          proxy_set_header X-Forwarded-For $remote_addr;
 | 
					          proxy_set_header X-Forwarded-For $remote_addr;
 | 
				
			||||||
          proxy_cache_bypass $http_upgrade;
 | 
					          proxy_cache_bypass $http_upgrade;
 | 
				
			||||||
          proxy_set_header Accept-Encoding gzip;
 | 
					          proxy_set_header Accept-Encoding gzip;
 | 
				
			||||||
          proxy_buffering off;  # 重要:关闭代理缓冲
 | 
					 | 
				
			||||||
   }
 | 
					   }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
注意,为了 SSE 正常工作,需要关闭 Nginx 的代理缓冲。
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
之后使用 Let's Encrypt 的 certbot 配置 HTTPS:
 | 
					之后使用 Let's Encrypt 的 certbot 配置 HTTPS:
 | 
				
			||||||
```bash
 | 
					```bash
 | 
				
			||||||
# Ubuntu 安装 certbot:
 | 
					# Ubuntu 安装 certbot:
 | 
				
			||||||
@@ -133,6 +134,14 @@ sudo service nginx restart
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。
 | 
					更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 多机部署
 | 
				
			||||||
 | 
					1. 所有服务器 `SESSION_SECRET` 设置一样的值。
 | 
				
			||||||
 | 
					2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite,请自行配置主备数据库同步。
 | 
				
			||||||
 | 
					3. 所有从服务器必须设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。
 | 
				
			||||||
 | 
					4. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					环境变量的具体使用方法详见[此处](#环境变量)。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## 配置
 | 
					## 配置
 | 
				
			||||||
系统本身开箱即用。
 | 
					系统本身开箱即用。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -157,6 +166,10 @@ sudo service nginx restart
 | 
				
			|||||||
   + 例子:`SESSION_SECRET=random_string`
 | 
					   + 例子:`SESSION_SECRET=random_string`
 | 
				
			||||||
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite。
 | 
					3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite。
 | 
				
			||||||
   + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/one-api`
 | 
					   + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/one-api`
 | 
				
			||||||
 | 
					4. `FRONTEND_BASE_URL`:设置之后将使用指定的前端地址,而非后端地址。
 | 
				
			||||||
 | 
					   + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
 | 
				
			||||||
 | 
					5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。
 | 
				
			||||||
 | 
					   + 例子:`SYNC_FREQUENCY=60`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### 命令行参数
 | 
					### 命令行参数
 | 
				
			||||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 | 
					1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 | 
				
			||||||
@@ -174,3 +187,10 @@ https://openai.justsong.cn
 | 
				
			|||||||
### 截图展示
 | 
					### 截图展示
 | 
				
			||||||

 | 
					
 | 
				
			||||||

 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 常见问题
 | 
				
			||||||
 | 
					1. 账户额度足够为什么提示额度不足?
 | 
				
			||||||
 | 
					   + 请检查你的令牌额度是否足够,这个和账户额度是分开的。
 | 
				
			||||||
 | 
					   + 令牌额度仅供用户设置最大使用量,用户可自由设置。
 | 
				
			||||||
 | 
					2. 宝塔部署后访问出现空白页面?
 | 
				
			||||||
 | 
					   + 自动配置的问题,详见[#97](https://github.com/songquanpeng/one-api/issues/97)。
 | 
				
			||||||
@@ -54,6 +54,7 @@ var QuotaForNewUser = 0
 | 
				
			|||||||
var ChannelDisableThreshold = 5.0
 | 
					var ChannelDisableThreshold = 5.0
 | 
				
			||||||
var AutomaticDisableChannelEnabled = false
 | 
					var AutomaticDisableChannelEnabled = false
 | 
				
			||||||
var QuotaRemindThreshold = 1000
 | 
					var QuotaRemindThreshold = 1000
 | 
				
			||||||
 | 
					var PreConsumedQuota = 500
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var RootUserEmail = ""
 | 
					var RootUserEmail = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -126,16 +127,18 @@ const (
 | 
				
			|||||||
	ChannelTypeOpenAIMax = 6
 | 
						ChannelTypeOpenAIMax = 6
 | 
				
			||||||
	ChannelTypeOhMyGPT   = 7
 | 
						ChannelTypeOhMyGPT   = 7
 | 
				
			||||||
	ChannelTypeCustom    = 8
 | 
						ChannelTypeCustom    = 8
 | 
				
			||||||
 | 
						ChannelTypeAILS      = 9
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var ChannelBaseURLs = []string{
 | 
					var ChannelBaseURLs = []string{
 | 
				
			||||||
	"",                            // 0
 | 
						"",                            // 0
 | 
				
			||||||
	"https://api.openai.com",      // 1
 | 
						"https://api.openai.com",      // 1
 | 
				
			||||||
	"https://openai.api2d.net",    // 2
 | 
						"https://oa.api2d.net",        // 2
 | 
				
			||||||
	"",                            // 3
 | 
						"",                            // 3
 | 
				
			||||||
	"https://api.openai-asia.com", // 4
 | 
						"https://api.openai-asia.com", // 4
 | 
				
			||||||
	"https://api.openai-sb.com",   // 5
 | 
						"https://api.openai-sb.com",   // 5
 | 
				
			||||||
	"https://api.openaimax.com",   // 6
 | 
						"https://api.openaimax.com",   // 6
 | 
				
			||||||
	"https://api.ohmygpt.com",     // 7
 | 
						"https://api.ohmygpt.com",     // 7
 | 
				
			||||||
	"",                            // 8
 | 
						"",                            // 8
 | 
				
			||||||
 | 
						"https://api.caipacity.com",   // 9
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -201,7 +201,7 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if response.Error.Type != "" {
 | 
						if response.Error.Message != "" {
 | 
				
			||||||
		return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
 | 
							return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
@@ -210,11 +210,12 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
 | 
				
			|||||||
func buildTestRequest(c *gin.Context) *ChatRequest {
 | 
					func buildTestRequest(c *gin.Context) *ChatRequest {
 | 
				
			||||||
	model_ := c.Query("model")
 | 
						model_ := c.Query("model")
 | 
				
			||||||
	testRequest := &ChatRequest{
 | 
						testRequest := &ChatRequest{
 | 
				
			||||||
		Model: model_,
 | 
							Model:     model_,
 | 
				
			||||||
 | 
							MaxTokens: 1,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	testMessage := Message{
 | 
						testMessage := Message{
 | 
				
			||||||
		Role:    "user",
 | 
							Role:    "user",
 | 
				
			||||||
		Content: "echo hi",
 | 
							Content: "hi",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	testRequest.Messages = append(testRequest.Messages, testMessage)
 | 
						testRequest.Messages = append(testRequest.Messages, testMessage)
 | 
				
			||||||
	return testRequest
 | 
						return testRequest
 | 
				
			||||||
@@ -264,14 +265,14 @@ var testAllChannelsLock sync.Mutex
 | 
				
			|||||||
var testAllChannelsRunning bool = false
 | 
					var testAllChannelsRunning bool = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// disable & notify
 | 
					// disable & notify
 | 
				
			||||||
func disableChannel(channelId int, channelName string, err error) {
 | 
					func disableChannel(channelId int, channelName string, reason string) {
 | 
				
			||||||
	if common.RootUserEmail == "" {
 | 
						if common.RootUserEmail == "" {
 | 
				
			||||||
		common.RootUserEmail = model.GetRootUserEmail()
 | 
							common.RootUserEmail = model.GetRootUserEmail()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
 | 
						model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
 | 
				
			||||||
	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
 | 
						subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
 | 
				
			||||||
	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, err.Error())
 | 
						content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
 | 
				
			||||||
	err = common.SendEmail(subject, common.RootUserEmail, content)
 | 
						err := common.SendEmail(subject, common.RootUserEmail, content)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
 | 
							common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -311,7 +312,7 @@ func testAllChannels(c *gin.Context) error {
 | 
				
			|||||||
				if milliseconds > disableThreshold {
 | 
									if milliseconds > disableThreshold {
 | 
				
			||||||
					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
										err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				disableChannel(channel.Id, channel.Name, err)
 | 
									disableChannel(channel.Id, channel.Name, err.Error())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			channel.UpdateResponseTime(milliseconds)
 | 
								channel.UpdateResponseTime(milliseconds)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										153
									
								
								controller/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										153
									
								
								controller/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,153 @@
 | 
				
			|||||||
 | 
					package controller
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// https://platform.openai.com/docs/api-reference/models/list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OpenAIModelPermission struct {
 | 
				
			||||||
 | 
						Id                 string  `json:"id"`
 | 
				
			||||||
 | 
						Object             string  `json:"object"`
 | 
				
			||||||
 | 
						Created            int     `json:"created"`
 | 
				
			||||||
 | 
						AllowCreateEngine  bool    `json:"allow_create_engine"`
 | 
				
			||||||
 | 
						AllowSampling      bool    `json:"allow_sampling"`
 | 
				
			||||||
 | 
						AllowLogprobs      bool    `json:"allow_logprobs"`
 | 
				
			||||||
 | 
						AllowSearchIndices bool    `json:"allow_search_indices"`
 | 
				
			||||||
 | 
						AllowView          bool    `json:"allow_view"`
 | 
				
			||||||
 | 
						AllowFineTuning    bool    `json:"allow_fine_tuning"`
 | 
				
			||||||
 | 
						Organization       string  `json:"organization"`
 | 
				
			||||||
 | 
						Group              *string `json:"group"`
 | 
				
			||||||
 | 
						IsBlocking         bool    `json:"is_blocking"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OpenAIModels struct {
 | 
				
			||||||
 | 
						Id         string                `json:"id"`
 | 
				
			||||||
 | 
						Object     string                `json:"object"`
 | 
				
			||||||
 | 
						Created    int                   `json:"created"`
 | 
				
			||||||
 | 
						OwnedBy    string                `json:"owned_by"`
 | 
				
			||||||
 | 
						Permission OpenAIModelPermission `json:"permission"`
 | 
				
			||||||
 | 
						Root       string                `json:"root"`
 | 
				
			||||||
 | 
						Parent     *string               `json:"parent"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var openAIModels []OpenAIModels
 | 
				
			||||||
 | 
					var openAIModelsMap map[string]OpenAIModels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						permission := OpenAIModelPermission{
 | 
				
			||||||
 | 
							Id:                 "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
 | 
				
			||||||
 | 
							Object:             "model_permission",
 | 
				
			||||||
 | 
							Created:            1626777600,
 | 
				
			||||||
 | 
							AllowCreateEngine:  true,
 | 
				
			||||||
 | 
							AllowSampling:      true,
 | 
				
			||||||
 | 
							AllowLogprobs:      true,
 | 
				
			||||||
 | 
							AllowSearchIndices: false,
 | 
				
			||||||
 | 
							AllowView:          true,
 | 
				
			||||||
 | 
							AllowFineTuning:    false,
 | 
				
			||||||
 | 
							Organization:       "*",
 | 
				
			||||||
 | 
							Group:              nil,
 | 
				
			||||||
 | 
							IsBlocking:         false,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// https://platform.openai.com/docs/models/model-endpoint-compatibility
 | 
				
			||||||
 | 
						openAIModels = []OpenAIModels{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "gpt-3.5-turbo",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "openai",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "gpt-3.5-turbo",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "gpt-3.5-turbo-0301",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "openai",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "gpt-3.5-turbo-0301",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "gpt-4",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "openai",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "gpt-4",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "gpt-4-0314",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "openai",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "gpt-4-0314",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "gpt-4-32k",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "openai",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "gpt-4-32k",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "gpt-4-32k-0314",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "openai",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "gpt-4-32k-0314",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "gpt-3.5-turbo",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "openai",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "gpt-3.5-turbo",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Id:         "text-embedding-ada-002",
 | 
				
			||||||
 | 
								Object:     "model",
 | 
				
			||||||
 | 
								Created:    1677649963,
 | 
				
			||||||
 | 
								OwnedBy:    "openai",
 | 
				
			||||||
 | 
								Permission: permission,
 | 
				
			||||||
 | 
								Root:       "text-embedding-ada-002",
 | 
				
			||||||
 | 
								Parent:     nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						openAIModelsMap = make(map[string]OpenAIModels)
 | 
				
			||||||
 | 
						for _, model := range openAIModels {
 | 
				
			||||||
 | 
							openAIModelsMap[model.Id] = model
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ListModels(c *gin.Context) {
 | 
				
			||||||
 | 
						c.JSON(200, openAIModels)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func RetrieveModel(c *gin.Context) {
 | 
				
			||||||
 | 
						modelId := c.Param("model")
 | 
				
			||||||
 | 
						if model, ok := openAIModelsMap[modelId]; ok {
 | 
				
			||||||
 | 
							c.JSON(200, model)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							openAIError := OpenAIError{
 | 
				
			||||||
 | 
								Message: fmt.Sprintf("The model '%s' does not exist", modelId),
 | 
				
			||||||
 | 
								Type:    "invalid_request_error",
 | 
				
			||||||
 | 
								Param:   "model",
 | 
				
			||||||
 | 
								Code:    "model_not_found",
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							c.JSON(200, gin.H{
 | 
				
			||||||
 | 
								"error": openAIError,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										61
									
								
								controller/relay-utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								controller/relay-utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,61 @@
 | 
				
			|||||||
 | 
					package controller
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/pkoukk/tiktoken-go"
 | 
				
			||||||
 | 
						"one-api/common"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getTokenEncoder(model string) *tiktoken.Tiktoken {
 | 
				
			||||||
 | 
						if tokenEncoder, ok := tokenEncoderMap[model]; ok {
 | 
				
			||||||
 | 
							return tokenEncoder
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tokenEncoder, err := tiktoken.EncodingForModel(model)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							common.FatalLog(fmt.Sprintf("failed to get token encoder for model %s: %s", model, err.Error()))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tokenEncoderMap[model] = tokenEncoder
 | 
				
			||||||
 | 
						return tokenEncoder
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func countTokenMessages(messages []Message, model string) int {
 | 
				
			||||||
 | 
						tokenEncoder := getTokenEncoder(model)
 | 
				
			||||||
 | 
						// Reference:
 | 
				
			||||||
 | 
						// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 | 
				
			||||||
 | 
						// https://github.com/pkoukk/tiktoken-go/issues/6
 | 
				
			||||||
 | 
						//
 | 
				
			||||||
 | 
						// Every message follows <|start|>{role/name}\n{content}<|end|>\n
 | 
				
			||||||
 | 
						var tokensPerMessage int
 | 
				
			||||||
 | 
						var tokensPerName int
 | 
				
			||||||
 | 
						if strings.HasPrefix(model, "gpt-3.5") {
 | 
				
			||||||
 | 
							tokensPerMessage = 4
 | 
				
			||||||
 | 
							tokensPerName = -1 // If there's a name, the role is omitted
 | 
				
			||||||
 | 
						} else if strings.HasPrefix(model, "gpt-4") {
 | 
				
			||||||
 | 
							tokensPerMessage = 3
 | 
				
			||||||
 | 
							tokensPerName = 1
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							tokensPerMessage = 3
 | 
				
			||||||
 | 
							tokensPerName = 1
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tokenNum := 0
 | 
				
			||||||
 | 
						for _, message := range messages {
 | 
				
			||||||
 | 
							tokenNum += tokensPerMessage
 | 
				
			||||||
 | 
							tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
 | 
				
			||||||
 | 
							tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
 | 
				
			||||||
 | 
							if message.Name != nil {
 | 
				
			||||||
 | 
								tokenNum += tokensPerName
 | 
				
			||||||
 | 
								tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
 | 
				
			||||||
 | 
						return tokenNum
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func countTokenText(text string, model string) int {
 | 
				
			||||||
 | 
						tokenEncoder := getTokenEncoder(model)
 | 
				
			||||||
 | 
						token := tokenEncoder.Encode(text, nil, nil)
 | 
				
			||||||
 | 
						return len(token)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -4,10 +4,8 @@ import (
 | 
				
			|||||||
	"bufio"
 | 
						"bufio"
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/pkoukk/tiktoken-go"
 | 
					 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
@@ -16,19 +14,22 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Message struct {
 | 
					type Message struct {
 | 
				
			||||||
	Role    string `json:"role"`
 | 
						Role    string  `json:"role"`
 | 
				
			||||||
	Content string `json:"content"`
 | 
						Content string  `json:"content"`
 | 
				
			||||||
 | 
						Name    *string `json:"name,omitempty"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ChatRequest struct {
 | 
					type ChatRequest struct {
 | 
				
			||||||
	Model    string    `json:"model"`
 | 
						Model     string    `json:"model"`
 | 
				
			||||||
	Messages []Message `json:"messages"`
 | 
						Messages  []Message `json:"messages"`
 | 
				
			||||||
 | 
						MaxTokens int       `json:"max_tokens"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type TextRequest struct {
 | 
					type TextRequest struct {
 | 
				
			||||||
	Model    string    `json:"model"`
 | 
						Model     string    `json:"model"`
 | 
				
			||||||
	Messages []Message `json:"messages"`
 | 
						Messages  []Message `json:"messages"`
 | 
				
			||||||
	Prompt   string    `json:"prompt"`
 | 
						Prompt    string    `json:"prompt"`
 | 
				
			||||||
 | 
						MaxTokens int       `json:"max_tokens"`
 | 
				
			||||||
	//Stream   bool      `json:"stream"`
 | 
						//Stream   bool      `json:"stream"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -45,6 +46,11 @@ type OpenAIError struct {
 | 
				
			|||||||
	Code    string `json:"code"`
 | 
						Code    string `json:"code"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OpenAIErrorWithStatusCode struct {
 | 
				
			||||||
 | 
						OpenAIError
 | 
				
			||||||
 | 
						StatusCode int `json:"status_code"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type TextResponse struct {
 | 
					type TextResponse struct {
 | 
				
			||||||
	Usage `json:"usage"`
 | 
						Usage `json:"usage"`
 | 
				
			||||||
	Error OpenAIError `json:"error"`
 | 
						Error OpenAIError `json:"error"`
 | 
				
			||||||
@@ -59,31 +65,39 @@ type StreamResponse struct {
 | 
				
			|||||||
	} `json:"choices"`
 | 
						} `json:"choices"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func countToken(text string) int {
 | 
					 | 
				
			||||||
	token := tokenEncoder.Encode(text, nil, nil)
 | 
					 | 
				
			||||||
	return len(token)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func Relay(c *gin.Context) {
 | 
					func Relay(c *gin.Context) {
 | 
				
			||||||
	err := relayHelper(c)
 | 
						err := relayHelper(c)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		c.JSON(http.StatusOK, gin.H{
 | 
							if err.StatusCode == http.StatusTooManyRequests {
 | 
				
			||||||
			"error": gin.H{
 | 
								err.OpenAIError.Message = "负载已满,请稍后再试,或升级账户以提升服务质量。"
 | 
				
			||||||
				"message": err.Error(),
 | 
							}
 | 
				
			||||||
				"type":    "one_api_error",
 | 
							c.JSON(err.StatusCode, gin.H{
 | 
				
			||||||
			},
 | 
								"error": err.OpenAIError,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
		if common.AutomaticDisableChannelEnabled {
 | 
							channelId := c.GetInt("channel_id")
 | 
				
			||||||
 | 
							common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message))
 | 
				
			||||||
 | 
							if err.Type != "invalid_request_error" && err.StatusCode != http.StatusTooManyRequests &&
 | 
				
			||||||
 | 
								common.AutomaticDisableChannelEnabled {
 | 
				
			||||||
			channelId := c.GetInt("channel_id")
 | 
								channelId := c.GetInt("channel_id")
 | 
				
			||||||
			channelName := c.GetString("channel_name")
 | 
								channelName := c.GetString("channel_name")
 | 
				
			||||||
			disableChannel(channelId, channelName, err)
 | 
								disableChannel(channelId, channelName, err.Message)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func relayHelper(c *gin.Context) error {
 | 
					func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
 | 
				
			||||||
 | 
						openAIError := OpenAIError{
 | 
				
			||||||
 | 
							Message: err.Error(),
 | 
				
			||||||
 | 
							Type:    "one_api_error",
 | 
				
			||||||
 | 
							Code:    code,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &OpenAIErrorWithStatusCode{
 | 
				
			||||||
 | 
							OpenAIError: openAIError,
 | 
				
			||||||
 | 
							StatusCode:  statusCode,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 | 
				
			||||||
	channelType := c.GetInt("channel")
 | 
						channelType := c.GetInt("channel")
 | 
				
			||||||
	tokenId := c.GetInt("token_id")
 | 
						tokenId := c.GetInt("token_id")
 | 
				
			||||||
	consumeQuota := c.GetBool("consume_quota")
 | 
						consumeQuota := c.GetBool("consume_quota")
 | 
				
			||||||
@@ -91,15 +105,15 @@ func relayHelper(c *gin.Context) error {
 | 
				
			|||||||
	if consumeQuota || channelType == common.ChannelTypeAzure {
 | 
						if consumeQuota || channelType == common.ChannelTypeAzure {
 | 
				
			||||||
		requestBody, err := io.ReadAll(c.Request.Body)
 | 
							requestBody, err := io.ReadAll(c.Request.Body)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return err
 | 
								return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = c.Request.Body.Close()
 | 
							err = c.Request.Body.Close()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return err
 | 
								return errorWrapper(err, "close_request_body_failed", http.StatusBadRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = json.Unmarshal(requestBody, &textRequest)
 | 
							err = json.Unmarshal(requestBody, &textRequest)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return err
 | 
								return errorWrapper(err, "unmarshal_request_body_failed", http.StatusBadRequest)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		// Reset request body
 | 
							// Reset request body
 | 
				
			||||||
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
							c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
				
			||||||
@@ -128,9 +142,23 @@ func relayHelper(c *gin.Context) error {
 | 
				
			|||||||
		model_ = strings.TrimSuffix(model_, "-0314")
 | 
							model_ = strings.TrimSuffix(model_, "-0314")
 | 
				
			||||||
		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 | 
							fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
 | 
				
			||||||
 | 
						preConsumedTokens := common.PreConsumedQuota
 | 
				
			||||||
 | 
						if textRequest.MaxTokens != 0 {
 | 
				
			||||||
 | 
							preConsumedTokens = promptTokens + textRequest.MaxTokens
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ratio := common.GetModelRatio(textRequest.Model)
 | 
				
			||||||
 | 
						preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 | 
				
			||||||
 | 
						if consumeQuota {
 | 
				
			||||||
 | 
							err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
 | 
						req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return errorWrapper(err, "new_request_failed", http.StatusOK)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if channelType == common.ChannelTypeAzure {
 | 
						if channelType == common.ChannelTypeAzure {
 | 
				
			||||||
		key := c.Request.Header.Get("Authorization")
 | 
							key := c.Request.Header.Get("Authorization")
 | 
				
			||||||
@@ -145,18 +173,18 @@ func relayHelper(c *gin.Context) error {
 | 
				
			|||||||
	client := &http.Client{}
 | 
						client := &http.Client{}
 | 
				
			||||||
	resp, err := client.Do(req)
 | 
						resp, err := client.Do(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return errorWrapper(err, "do_request_failed", http.StatusOK)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = req.Body.Close()
 | 
						err = req.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return errorWrapper(err, "close_request_body_failed", http.StatusOK)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = c.Request.Body.Close()
 | 
						err = c.Request.Body.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return errorWrapper(err, "close_request_body_failed", http.StatusOK)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var textResponse TextResponse
 | 
						var textResponse TextResponse
 | 
				
			||||||
	isStream := resp.Header.Get("Content-Type") == "text/event-stream"
 | 
						isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
				
			||||||
	var streamResponseText string
 | 
						var streamResponseText string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	defer func() {
 | 
						defer func() {
 | 
				
			||||||
@@ -168,18 +196,14 @@ func relayHelper(c *gin.Context) error {
 | 
				
			|||||||
				completionRatio = 2
 | 
									completionRatio = 2
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if isStream {
 | 
								if isStream {
 | 
				
			||||||
				var promptText string
 | 
									responseTokens := countTokenText(streamResponseText, textRequest.Model)
 | 
				
			||||||
				for _, message := range textRequest.Messages {
 | 
									quota = promptTokens + responseTokens*completionRatio
 | 
				
			||||||
					promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
 | 
					 | 
				
			||||||
				quota = countToken(promptText) + countToken(completionText)*completionRatio + 3
 | 
					 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 | 
									quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			ratio := common.GetModelRatio(textRequest.Model)
 | 
					 | 
				
			||||||
			quota = int(float64(quota) * ratio)
 | 
								quota = int(float64(quota) * ratio)
 | 
				
			||||||
			err := model.DecreaseTokenQuota(tokenId, quota)
 | 
								quotaDelta := quota - preConsumedQuota
 | 
				
			||||||
 | 
								err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				common.SysError("Error consuming token remain quota: " + err.Error())
 | 
									common.SysError("Error consuming token remain quota: " + err.Error())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -208,6 +232,10 @@ func relayHelper(c *gin.Context) error {
 | 
				
			|||||||
		go func() {
 | 
							go func() {
 | 
				
			||||||
			for scanner.Scan() {
 | 
								for scanner.Scan() {
 | 
				
			||||||
				data := scanner.Text()
 | 
									data := scanner.Text()
 | 
				
			||||||
 | 
									if len(data) < 6 { // must be something wrong!
 | 
				
			||||||
 | 
										common.SysError("Invalid stream response: " + data)
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
				dataChan <- data
 | 
									dataChan <- data
 | 
				
			||||||
				data = data[6:]
 | 
									data = data[6:]
 | 
				
			||||||
				if !strings.HasPrefix(data, "[DONE]") {
 | 
									if !strings.HasPrefix(data, "[DONE]") {
 | 
				
			||||||
@@ -228,6 +256,7 @@ func relayHelper(c *gin.Context) error {
 | 
				
			|||||||
		c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
							c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
				
			||||||
		c.Writer.Header().Set("Connection", "keep-alive")
 | 
							c.Writer.Header().Set("Connection", "keep-alive")
 | 
				
			||||||
		c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
							c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
				
			||||||
 | 
							c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
				
			||||||
		c.Stream(func(w io.Writer) bool {
 | 
							c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
			select {
 | 
								select {
 | 
				
			||||||
			case data := <-dataChan:
 | 
								case data := <-dataChan:
 | 
				
			||||||
@@ -242,50 +271,60 @@ func relayHelper(c *gin.Context) error {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
		err = resp.Body.Close()
 | 
							err = resp.Body.Close()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return err
 | 
								return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		for k, v := range resp.Header {
 | 
					 | 
				
			||||||
			c.Writer.Header().Set(k, v[0])
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if consumeQuota {
 | 
							if consumeQuota {
 | 
				
			||||||
			responseBody, err := io.ReadAll(resp.Body)
 | 
								responseBody, err := io.ReadAll(resp.Body)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return errorWrapper(err, "read_response_body_failed", http.StatusOK)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			err = resp.Body.Close()
 | 
								err = resp.Body.Close()
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			err = json.Unmarshal(responseBody, &textResponse)
 | 
								err = json.Unmarshal(responseBody, &textResponse)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if textResponse.Error.Type != "" {
 | 
								if textResponse.Error.Type != "" {
 | 
				
			||||||
				return errors.New(fmt.Sprintf("type %s, code %s, message %s",
 | 
									return &OpenAIErrorWithStatusCode{
 | 
				
			||||||
					textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message))
 | 
										OpenAIError: textResponse.Error,
 | 
				
			||||||
 | 
										StatusCode:  resp.StatusCode,
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			// Reset response body
 | 
								// Reset response body
 | 
				
			||||||
			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
								resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
				
			||||||
 | 
							// And then we will have to send an error response, but in this case, the header has already been set.
 | 
				
			||||||
 | 
							// So the client will be confused by the response.
 | 
				
			||||||
 | 
							// For example, Postman will report error, and we cannot check the response at all.
 | 
				
			||||||
 | 
							for k, v := range resp.Header {
 | 
				
			||||||
 | 
								c.Writer.Header().Set(k, v[0])
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							c.Writer.WriteHeader(resp.StatusCode)
 | 
				
			||||||
		_, err = io.Copy(c.Writer, resp.Body)
 | 
							_, err = io.Copy(c.Writer, resp.Body)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return err
 | 
								return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = resp.Body.Close()
 | 
							err = resp.Body.Close()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return err
 | 
								return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func RelayNotImplemented(c *gin.Context) {
 | 
					func RelayNotImplemented(c *gin.Context) {
 | 
				
			||||||
 | 
						err := OpenAIError{
 | 
				
			||||||
 | 
							Message: "API not implemented",
 | 
				
			||||||
 | 
							Type:    "one_api_error",
 | 
				
			||||||
 | 
							Param:   "",
 | 
				
			||||||
 | 
							Code:    "api_not_implemented",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	c.JSON(http.StatusOK, gin.H{
 | 
						c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
		"error": gin.H{
 | 
							"error": err,
 | 
				
			||||||
			"message": "Not Implemented",
 | 
					 | 
				
			||||||
			"type":    "one_api_error",
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -467,6 +467,13 @@ func CreateUser(c *gin.Context) {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if err := common.Validate.Struct(&user); err != nil {
 | 
				
			||||||
 | 
							c.JSON(http.StatusOK, gin.H{
 | 
				
			||||||
 | 
								"success": false,
 | 
				
			||||||
 | 
								"message": "输入不合法 " + err.Error(),
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	if user.DisplayName == "" {
 | 
						if user.DisplayName == "" {
 | 
				
			||||||
		user.DisplayName = user.Username
 | 
							user.DisplayName = user.Username
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										7
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								main.go
									
									
									
									
									
								
							@@ -47,6 +47,13 @@ func main() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Initialize options
 | 
						// Initialize options
 | 
				
			||||||
	model.InitOptionMap()
 | 
						model.InitOptionMap()
 | 
				
			||||||
 | 
						if os.Getenv("SYNC_FREQUENCY") != "" {
 | 
				
			||||||
 | 
							frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								common.FatalLog(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							go model.SyncOptions(frequency)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Initialize HTTP server
 | 
						// Initialize HTTP server
 | 
				
			||||||
	server := gin.Default()
 | 
						server := gin.Default()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -85,6 +85,8 @@ func RootAuth() func(c *gin.Context) {
 | 
				
			|||||||
func TokenAuth() func(c *gin.Context) {
 | 
					func TokenAuth() func(c *gin.Context) {
 | 
				
			||||||
	return func(c *gin.Context) {
 | 
						return func(c *gin.Context) {
 | 
				
			||||||
		key := c.Request.Header.Get("Authorization")
 | 
							key := c.Request.Header.Get("Authorization")
 | 
				
			||||||
 | 
							key = strings.TrimPrefix(key, "Bearer ")
 | 
				
			||||||
 | 
							key = strings.TrimPrefix(key, "sk-")
 | 
				
			||||||
		parts := strings.Split(key, "-")
 | 
							parts := strings.Split(key, "-")
 | 
				
			||||||
		key = parts[0]
 | 
							key = parts[0]
 | 
				
			||||||
		token, err := model.ValidateUserToken(key)
 | 
							token, err := model.ValidateUserToken(key)
 | 
				
			||||||
@@ -111,7 +113,7 @@ func TokenAuth() func(c *gin.Context) {
 | 
				
			|||||||
		c.Set("id", token.UserId)
 | 
							c.Set("id", token.UserId)
 | 
				
			||||||
		c.Set("token_id", token.Id)
 | 
							c.Set("token_id", token.Id)
 | 
				
			||||||
		requestURL := c.Request.URL.String()
 | 
							requestURL := c.Request.URL.String()
 | 
				
			||||||
		consumeQuota := !token.UnlimitedQuota
 | 
							consumeQuota := true
 | 
				
			||||||
		if strings.HasPrefix(requestURL, "/v1/models") {
 | 
							if strings.HasPrefix(requestURL, "/v1/models") {
 | 
				
			||||||
			consumeQuota = false
 | 
								consumeQuota = false
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,6 +4,7 @@ import (
 | 
				
			|||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Option struct {
 | 
					type Option struct {
 | 
				
			||||||
@@ -55,9 +56,14 @@ func InitOptionMap() {
 | 
				
			|||||||
	common.OptionMap["TurnstileSecretKey"] = ""
 | 
						common.OptionMap["TurnstileSecretKey"] = ""
 | 
				
			||||||
	common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
 | 
						common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
 | 
				
			||||||
	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
 | 
						common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
 | 
				
			||||||
 | 
						common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
 | 
				
			||||||
	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
 | 
						common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
 | 
				
			||||||
	common.OptionMap["TopUpLink"] = common.TopUpLink
 | 
						common.OptionMap["TopUpLink"] = common.TopUpLink
 | 
				
			||||||
	common.OptionMapRWMutex.Unlock()
 | 
						common.OptionMapRWMutex.Unlock()
 | 
				
			||||||
 | 
						loadOptionsFromDatabase()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func loadOptionsFromDatabase() {
 | 
				
			||||||
	options, _ := AllOption()
 | 
						options, _ := AllOption()
 | 
				
			||||||
	for _, option := range options {
 | 
						for _, option := range options {
 | 
				
			||||||
		err := updateOptionMap(option.Key, option.Value)
 | 
							err := updateOptionMap(option.Key, option.Value)
 | 
				
			||||||
@@ -67,6 +73,14 @@ func InitOptionMap() {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func SyncOptions(frequency int) {
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							time.Sleep(time.Duration(frequency) * time.Second)
 | 
				
			||||||
 | 
							common.SysLog("Syncing options from database")
 | 
				
			||||||
 | 
							loadOptionsFromDatabase()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func UpdateOption(key string, value string) error {
 | 
					func UpdateOption(key string, value string) error {
 | 
				
			||||||
	// Save to database first
 | 
						// Save to database first
 | 
				
			||||||
	option := Option{
 | 
						option := Option{
 | 
				
			||||||
@@ -159,6 +173,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
				
			|||||||
		common.QuotaForNewUser, _ = strconv.Atoi(value)
 | 
							common.QuotaForNewUser, _ = strconv.Atoi(value)
 | 
				
			||||||
	case "QuotaRemindThreshold":
 | 
						case "QuotaRemindThreshold":
 | 
				
			||||||
		common.QuotaRemindThreshold, _ = strconv.Atoi(value)
 | 
							common.QuotaRemindThreshold, _ = strconv.Atoi(value)
 | 
				
			||||||
 | 
						case "PreConsumedQuota":
 | 
				
			||||||
 | 
							common.PreConsumedQuota, _ = strconv.Atoi(value)
 | 
				
			||||||
	case "ModelRatio":
 | 
						case "ModelRatio":
 | 
				
			||||||
		err = common.UpdateModelRatioByJSONString(value)
 | 
							err = common.UpdateModelRatioByJSONString(value)
 | 
				
			||||||
	case "TopUpLink":
 | 
						case "TopUpLink":
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,7 +6,6 @@ import (
 | 
				
			|||||||
	_ "gorm.io/driver/sqlite"
 | 
						_ "gorm.io/driver/sqlite"
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"one-api/common"
 | 
						"one-api/common"
 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Token struct {
 | 
					type Token struct {
 | 
				
			||||||
@@ -38,7 +37,6 @@ func ValidateUserToken(key string) (token *Token, err error) {
 | 
				
			|||||||
	if key == "" {
 | 
						if key == "" {
 | 
				
			||||||
		return nil, errors.New("未提供 token")
 | 
							return nil, errors.New("未提供 token")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	key = strings.Replace(key, "Bearer ", "", 1)
 | 
					 | 
				
			||||||
	token = &Token{}
 | 
						token = &Token{}
 | 
				
			||||||
	err = DB.Where("`key` = ?", key).First(token).Error
 | 
						err = DB.Where("`key` = ?", key).First(token).Error
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
@@ -130,7 +128,23 @@ func DeleteTokenById(id int, userId int) (err error) {
 | 
				
			|||||||
	return token.Delete()
 | 
						return token.Delete()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func DecreaseTokenQuota(tokenId int, quota int) (err error) {
 | 
					func IncreaseTokenQuota(id int, quota int) (err error) {
 | 
				
			||||||
 | 
						if quota < 0 {
 | 
				
			||||||
 | 
							return errors.New("quota 不能为负数!")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func DecreaseTokenQuota(id int, quota int) (err error) {
 | 
				
			||||||
 | 
						if quota < 0 {
 | 
				
			||||||
 | 
							return errors.New("quota 不能为负数!")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
 | 
				
			||||||
	if quota < 0 {
 | 
						if quota < 0 {
 | 
				
			||||||
		return errors.New("quota 不能为负数!")
 | 
							return errors.New("quota 不能为负数!")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -138,7 +152,7 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if token.RemainQuota < quota {
 | 
						if !token.UnlimitedQuota && token.RemainQuota < quota {
 | 
				
			||||||
		return errors.New("令牌额度不足")
 | 
							return errors.New("令牌额度不足")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	userQuota, err := GetUserQuota(token.UserId)
 | 
						userQuota, err := GetUserQuota(token.UserId)
 | 
				
			||||||
@@ -163,17 +177,42 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
 | 
				
			|||||||
			if email != "" {
 | 
								if email != "" {
 | 
				
			||||||
				topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
 | 
									topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
 | 
				
			||||||
				err = common.SendEmail(prompt, email,
 | 
									err = common.SendEmail(prompt, email,
 | 
				
			||||||
					fmt.Sprintf("%s,剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota-quota, topUpLink, topUpLink))
 | 
										fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					common.SysError("发送邮件失败:" + err.Error())
 | 
										common.SysError("发送邮件失败:" + err.Error())
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = DB.Model(&Token{}).Where("id = ?", tokenId).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
 | 
						if !token.UnlimitedQuota {
 | 
				
			||||||
	if err != nil {
 | 
							err = DecreaseTokenQuota(tokenId, quota)
 | 
				
			||||||
		return err
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = DecreaseUserQuota(token.UserId, quota)
 | 
						err = DecreaseUserQuota(token.UserId, quota)
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
 | 
				
			||||||
 | 
						token, err := GetTokenById(tokenId)
 | 
				
			||||||
 | 
						if quota > 0 {
 | 
				
			||||||
 | 
							err = DecreaseUserQuota(token.UserId, quota)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							err = IncreaseUserQuota(token.UserId, -quota)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !token.UnlimitedQuota {
 | 
				
			||||||
 | 
							if quota > 0 {
 | 
				
			||||||
 | 
								err = DecreaseTokenQuota(tokenId, quota)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								err = IncreaseTokenQuota(tokenId, -quota)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,12 +2,24 @@ package router
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"embed"
 | 
						"embed"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 | 
					func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 | 
				
			||||||
	SetApiRouter(router)
 | 
						SetApiRouter(router)
 | 
				
			||||||
	SetDashboardRouter(router)
 | 
						SetDashboardRouter(router)
 | 
				
			||||||
	SetRelayRouter(router)
 | 
						SetRelayRouter(router)
 | 
				
			||||||
	setWebRouter(router, buildFS, indexPage)
 | 
						frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
 | 
				
			||||||
 | 
						if frontendBaseUrl == "" {
 | 
				
			||||||
 | 
							SetWebRouter(router, buildFS, indexPage)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/")
 | 
				
			||||||
 | 
							router.NoRoute(func(c *gin.Context) {
 | 
				
			||||||
 | 
								c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI))
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,8 +11,8 @@ func SetRelayRouter(router *gin.Engine) {
 | 
				
			|||||||
	relayV1Router := router.Group("/v1")
 | 
						relayV1Router := router.Group("/v1")
 | 
				
			||||||
	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
 | 
						relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		relayV1Router.GET("/models", controller.Relay)
 | 
							relayV1Router.GET("/models", controller.ListModels)
 | 
				
			||||||
		relayV1Router.GET("/models/:model", controller.Relay)
 | 
							relayV1Router.GET("/models/:model", controller.RetrieveModel)
 | 
				
			||||||
		relayV1Router.POST("/completions", controller.RelayNotImplemented)
 | 
							relayV1Router.POST("/completions", controller.RelayNotImplemented)
 | 
				
			||||||
		relayV1Router.POST("/chat/completions", controller.Relay)
 | 
							relayV1Router.POST("/chat/completions", controller.Relay)
 | 
				
			||||||
		relayV1Router.POST("/edits", controller.RelayNotImplemented)
 | 
							relayV1Router.POST("/edits", controller.RelayNotImplemented)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,7 +10,7 @@ import (
 | 
				
			|||||||
	"one-api/middleware"
 | 
						"one-api/middleware"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func setWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 | 
					func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 | 
				
			||||||
	router.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
						router.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
				
			||||||
	router.Use(middleware.GlobalWebRateLimit())
 | 
						router.Use(middleware.GlobalWebRateLimit())
 | 
				
			||||||
	router.Use(middleware.Cache())
 | 
						router.Use(middleware.Cache())
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -28,6 +28,7 @@ const SystemSetting = () => {
 | 
				
			|||||||
    RegisterEnabled: '',
 | 
					    RegisterEnabled: '',
 | 
				
			||||||
    QuotaForNewUser: 0,
 | 
					    QuotaForNewUser: 0,
 | 
				
			||||||
    QuotaRemindThreshold: 0,
 | 
					    QuotaRemindThreshold: 0,
 | 
				
			||||||
 | 
					    PreConsumedQuota: 0,
 | 
				
			||||||
    ModelRatio: '',
 | 
					    ModelRatio: '',
 | 
				
			||||||
    TopUpLink: '',
 | 
					    TopUpLink: '',
 | 
				
			||||||
    AutomaticDisableChannelEnabled: '',
 | 
					    AutomaticDisableChannelEnabled: '',
 | 
				
			||||||
@@ -98,6 +99,7 @@ const SystemSetting = () => {
 | 
				
			|||||||
      name === 'TurnstileSecretKey' ||
 | 
					      name === 'TurnstileSecretKey' ||
 | 
				
			||||||
      name === 'QuotaForNewUser' ||
 | 
					      name === 'QuotaForNewUser' ||
 | 
				
			||||||
      name === 'QuotaRemindThreshold' ||
 | 
					      name === 'QuotaRemindThreshold' ||
 | 
				
			||||||
 | 
					      name === 'PreConsumedQuota' ||
 | 
				
			||||||
      name === 'ModelRatio' ||
 | 
					      name === 'ModelRatio' ||
 | 
				
			||||||
      name === 'TopUpLink'
 | 
					      name === 'TopUpLink'
 | 
				
			||||||
    ) {
 | 
					    ) {
 | 
				
			||||||
@@ -119,6 +121,9 @@ const SystemSetting = () => {
 | 
				
			|||||||
    if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
 | 
					    if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
 | 
				
			||||||
      await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
 | 
					      await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
 | 
				
			||||||
 | 
					      await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
    if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
 | 
					    if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
 | 
				
			||||||
      if (!verifyJSON(inputs.ModelRatio)) {
 | 
					      if (!verifyJSON(inputs.ModelRatio)) {
 | 
				
			||||||
        showError('模型倍率不是合法的 JSON 字符串');
 | 
					        showError('模型倍率不是合法的 JSON 字符串');
 | 
				
			||||||
@@ -272,7 +277,7 @@ const SystemSetting = () => {
 | 
				
			|||||||
          <Header as='h3'>
 | 
					          <Header as='h3'>
 | 
				
			||||||
            运营设置
 | 
					            运营设置
 | 
				
			||||||
          </Header>
 | 
					          </Header>
 | 
				
			||||||
          <Form.Group widths={3}>
 | 
					          <Form.Group widths={4}>
 | 
				
			||||||
            <Form.Input
 | 
					            <Form.Input
 | 
				
			||||||
              label='新用户初始配额'
 | 
					              label='新用户初始配额'
 | 
				
			||||||
              name='QuotaForNewUser'
 | 
					              name='QuotaForNewUser'
 | 
				
			||||||
@@ -302,6 +307,16 @@ const SystemSetting = () => {
 | 
				
			|||||||
              min='0'
 | 
					              min='0'
 | 
				
			||||||
              placeholder='低于此额度时将发送邮件提醒用户'
 | 
					              placeholder='低于此额度时将发送邮件提醒用户'
 | 
				
			||||||
            />
 | 
					            />
 | 
				
			||||||
 | 
					            <Form.Input
 | 
				
			||||||
 | 
					              label='请求预扣费额度'
 | 
				
			||||||
 | 
					              name='PreConsumedQuota'
 | 
				
			||||||
 | 
					              onChange={handleInputChange}
 | 
				
			||||||
 | 
					              autoComplete='new-password'
 | 
				
			||||||
 | 
					              value={inputs.PreConsumedQuota}
 | 
				
			||||||
 | 
					              type='number'
 | 
				
			||||||
 | 
					              min='0'
 | 
				
			||||||
 | 
					              placeholder='请求结束后多退少补'
 | 
				
			||||||
 | 
					            />
 | 
				
			||||||
          </Form.Group>
 | 
					          </Form.Group>
 | 
				
			||||||
          <Form.Group widths='equal'>
 | 
					          <Form.Group widths='equal'>
 | 
				
			||||||
            <Form.TextArea
 | 
					            <Form.TextArea
 | 
				
			||||||
@@ -321,7 +336,7 @@ const SystemSetting = () => {
 | 
				
			|||||||
          </Header>
 | 
					          </Header>
 | 
				
			||||||
          <Form.Group widths={3}>
 | 
					          <Form.Group widths={3}>
 | 
				
			||||||
            <Form.Input
 | 
					            <Form.Input
 | 
				
			||||||
              label='最长回应时间'
 | 
					              label='最长响应时间'
 | 
				
			||||||
              name='ChannelDisableThreshold'
 | 
					              name='ChannelDisableThreshold'
 | 
				
			||||||
              onChange={handleInputChange}
 | 
					              onChange={handleInputChange}
 | 
				
			||||||
              autoComplete='new-password'
 | 
					              autoComplete='new-password'
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,6 +4,7 @@ import { Link } from 'react-router-dom';
 | 
				
			|||||||
import { API, showError, showSuccess } from '../helpers';
 | 
					import { API, showError, showSuccess } from '../helpers';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import { ITEMS_PER_PAGE } from '../constants';
 | 
					import { ITEMS_PER_PAGE } from '../constants';
 | 
				
			||||||
 | 
					import { renderText } from '../helpers/render';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
function renderRole(role) {
 | 
					function renderRole(role) {
 | 
				
			||||||
  switch (role) {
 | 
					  switch (role) {
 | 
				
			||||||
@@ -64,7 +65,7 @@ const UsersTable = () => {
 | 
				
			|||||||
    (async () => {
 | 
					    (async () => {
 | 
				
			||||||
      const res = await API.post('/api/user/manage', {
 | 
					      const res = await API.post('/api/user/manage', {
 | 
				
			||||||
        username,
 | 
					        username,
 | 
				
			||||||
        action,
 | 
					        action
 | 
				
			||||||
      });
 | 
					      });
 | 
				
			||||||
      const { success, message } = res.data;
 | 
					      const { success, message } = res.data;
 | 
				
			||||||
      if (success) {
 | 
					      if (success) {
 | 
				
			||||||
@@ -161,18 +162,18 @@ const UsersTable = () => {
 | 
				
			|||||||
            <Table.HeaderCell
 | 
					            <Table.HeaderCell
 | 
				
			||||||
              style={{ cursor: 'pointer' }}
 | 
					              style={{ cursor: 'pointer' }}
 | 
				
			||||||
              onClick={() => {
 | 
					              onClick={() => {
 | 
				
			||||||
                sortUser('username');
 | 
					                sortUser('id');
 | 
				
			||||||
              }}
 | 
					              }}
 | 
				
			||||||
            >
 | 
					            >
 | 
				
			||||||
              用户名
 | 
					              ID
 | 
				
			||||||
            </Table.HeaderCell>
 | 
					            </Table.HeaderCell>
 | 
				
			||||||
            <Table.HeaderCell
 | 
					            <Table.HeaderCell
 | 
				
			||||||
              style={{ cursor: 'pointer' }}
 | 
					              style={{ cursor: 'pointer' }}
 | 
				
			||||||
              onClick={() => {
 | 
					              onClick={() => {
 | 
				
			||||||
                sortUser('display_name');
 | 
					                sortUser('username');
 | 
				
			||||||
              }}
 | 
					              }}
 | 
				
			||||||
            >
 | 
					            >
 | 
				
			||||||
              显示名称
 | 
					              用户名
 | 
				
			||||||
            </Table.HeaderCell>
 | 
					            </Table.HeaderCell>
 | 
				
			||||||
            <Table.HeaderCell
 | 
					            <Table.HeaderCell
 | 
				
			||||||
              style={{ cursor: 'pointer' }}
 | 
					              style={{ cursor: 'pointer' }}
 | 
				
			||||||
@@ -220,9 +221,17 @@ const UsersTable = () => {
 | 
				
			|||||||
              if (user.deleted) return <></>;
 | 
					              if (user.deleted) return <></>;
 | 
				
			||||||
              return (
 | 
					              return (
 | 
				
			||||||
                <Table.Row key={user.id}>
 | 
					                <Table.Row key={user.id}>
 | 
				
			||||||
                  <Table.Cell>{user.username}</Table.Cell>
 | 
					                  <Table.Cell>{user.id}</Table.Cell>
 | 
				
			||||||
                  <Table.Cell>{user.display_name}</Table.Cell>
 | 
					                  <Table.Cell>
 | 
				
			||||||
                  <Table.Cell>{user.email ? user.email : '无'}</Table.Cell>
 | 
					                    <Popup
 | 
				
			||||||
 | 
					                      content={user.email ? user.email : '未绑定邮箱地址'}
 | 
				
			||||||
 | 
					                      key={user.display_name}
 | 
				
			||||||
 | 
					                      header={user.display_name ? user.display_name : user.username}
 | 
				
			||||||
 | 
					                      trigger={<span>{renderText(user.username, 10)}</span>}
 | 
				
			||||||
 | 
					                      hoverable
 | 
				
			||||||
 | 
					                    />
 | 
				
			||||||
 | 
					                  </Table.Cell>
 | 
				
			||||||
 | 
					                  <Table.Cell>{user.email ? renderText(user.email, 30) : '无'}</Table.Cell>
 | 
				
			||||||
                  <Table.Cell>{user.quota}</Table.Cell>
 | 
					                  <Table.Cell>{user.quota}</Table.Cell>
 | 
				
			||||||
                  <Table.Cell>{renderRole(user.role)}</Table.Cell>
 | 
					                  <Table.Cell>{renderRole(user.role)}</Table.Cell>
 | 
				
			||||||
                  <Table.Cell>{renderStatus(user.status)}</Table.Cell>
 | 
					                  <Table.Cell>{renderStatus(user.status)}</Table.Cell>
 | 
				
			||||||
@@ -234,6 +243,7 @@ const UsersTable = () => {
 | 
				
			|||||||
                        onClick={() => {
 | 
					                        onClick={() => {
 | 
				
			||||||
                          manageUser(user.username, 'promote', idx);
 | 
					                          manageUser(user.username, 'promote', idx);
 | 
				
			||||||
                        }}
 | 
					                        }}
 | 
				
			||||||
 | 
					                        disabled={user.role === 100}
 | 
				
			||||||
                      >
 | 
					                      >
 | 
				
			||||||
                        提升
 | 
					                        提升
 | 
				
			||||||
                      </Button>
 | 
					                      </Button>
 | 
				
			||||||
@@ -243,12 +253,13 @@ const UsersTable = () => {
 | 
				
			|||||||
                        onClick={() => {
 | 
					                        onClick={() => {
 | 
				
			||||||
                          manageUser(user.username, 'demote', idx);
 | 
					                          manageUser(user.username, 'demote', idx);
 | 
				
			||||||
                        }}
 | 
					                        }}
 | 
				
			||||||
 | 
					                        disabled={user.role === 100}
 | 
				
			||||||
                      >
 | 
					                      >
 | 
				
			||||||
                        降级
 | 
					                        降级
 | 
				
			||||||
                      </Button>
 | 
					                      </Button>
 | 
				
			||||||
                      <Popup
 | 
					                      <Popup
 | 
				
			||||||
                        trigger={
 | 
					                        trigger={
 | 
				
			||||||
                          <Button size='small' negative>
 | 
					                          <Button size='small' negative disabled={user.role === 100}>
 | 
				
			||||||
                            删除
 | 
					                            删除
 | 
				
			||||||
                          </Button>
 | 
					                          </Button>
 | 
				
			||||||
                        }
 | 
					                        }
 | 
				
			||||||
@@ -274,6 +285,7 @@ const UsersTable = () => {
 | 
				
			|||||||
                            idx
 | 
					                            idx
 | 
				
			||||||
                          );
 | 
					                          );
 | 
				
			||||||
                        }}
 | 
					                        }}
 | 
				
			||||||
 | 
					                        disabled={user.role === 100}
 | 
				
			||||||
                      >
 | 
					                      >
 | 
				
			||||||
                        {user.status === 1 ? '禁用' : '启用'}
 | 
					                        {user.status === 1 ? '禁用' : '启用'}
 | 
				
			||||||
                      </Button>
 | 
					                      </Button>
 | 
				
			||||||
@@ -281,6 +293,7 @@ const UsersTable = () => {
 | 
				
			|||||||
                        size={'small'}
 | 
					                        size={'small'}
 | 
				
			||||||
                        as={Link}
 | 
					                        as={Link}
 | 
				
			||||||
                        to={'/user/edit/' + user.id}
 | 
					                        to={'/user/edit/' + user.id}
 | 
				
			||||||
 | 
					                        disabled={user.role === 100}
 | 
				
			||||||
                      >
 | 
					                      >
 | 
				
			||||||
                        编辑
 | 
					                        编辑
 | 
				
			||||||
                      </Button>
 | 
					                      </Button>
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,5 +6,6 @@ export const CHANNEL_OPTIONS = [
 | 
				
			|||||||
  { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
 | 
					  { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
 | 
				
			||||||
  { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
 | 
					  { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
 | 
				
			||||||
  { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
 | 
					  { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
 | 
				
			||||||
 | 
					  { key: 9, text: 'AI.LS', value: 9, color: 'yellow' },
 | 
				
			||||||
  { key: 8, text: '自定义', value: 8, color: 'pink' }
 | 
					  { key: 8, text: '自定义', value: 8, color: 'pink' }
 | 
				
			||||||
];
 | 
					];
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										6
									
								
								web/src/helpers/render.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								web/src/helpers/render.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					export function renderText(text, limit) {
 | 
				
			||||||
 | 
					  if (text.length > limit) {
 | 
				
			||||||
 | 
					    return text.slice(0, limit - 3) + '...';
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return text;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user