mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	Compare commits
	
		
			48 Commits
		
	
	
		
			v0.3.0-alp
			...
			v0.3.3-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					d4794fc051 | ||
| 
						 | 
					8b43e0dd3f | ||
| 
						 | 
					92c88fa273 | ||
| 
						 | 
					38191d55be | ||
| 
						 | 
					d9e39f5906 | ||
| 
						 | 
					17b7646c12 | ||
| 
						 | 
					171b818504 | ||
| 
						 | 
					bcca0cc0bc | ||
| 
						 | 
					b92ec5e54c | ||
| 
						 | 
					fa79e8b7a3 | ||
| 
						 | 
					1cc7c20183 | ||
| 
						 | 
					2eee97e9b6 | ||
| 
						 | 
					a3a1b612b0 | ||
| 
						 | 
					61e682ca47 | ||
| 
						 | 
					b383983106 | ||
| 
						 | 
					cfd587117e | ||
| 
						 | 
					ef9dca28f5 | ||
| 
						 | 
					741c0b9c18 | ||
| 
						 | 
					3711f4a741 | ||
| 
						 | 
					7c6bf3e97b | ||
| 
						 | 
					481ba41fbd | ||
| 
						 | 
					2779d6629c | ||
| 
						 | 
					e509899daf | ||
| 
						 | 
					b53cdbaf05 | ||
| 
						 | 
					ced89398a5 | ||
| 
						 | 
					09c2e3bcec | ||
| 
						 | 
					5cba800fa6 | ||
| 
						 | 
					2d39a135f2 | ||
| 
						 | 
					3c6834a79c | ||
| 
						 | 
					6da3410823 | ||
| 
						 | 
					ceb289cb4d | ||
| 
						 | 
					6f8cc712b0 | ||
| 
						 | 
					ad01e1f3b3 | ||
| 
						 | 
					cc1ef2ffd5 | ||
| 
						 | 
					7201bd1c97 | ||
| 
						 | 
					73d5e0f283 | ||
| 
						 | 
					efc744ca35 | ||
| 
						 | 
					e8da98139f | ||
| 
						 | 
					519cb030f7 | ||
| 
						 | 
					58fe923c85 | ||
| 
						 | 
					c9ac5e391f | ||
| 
						 | 
					69cf1de7bd | ||
| 
						 | 
					4d6172a242 | ||
| 
						 | 
					8afdc56b11 | ||
| 
						 | 
					a9ea1d9d10 | ||
| 
						 | 
					ea8e7c517b | ||
| 
						 | 
					d1e9b86f05 | ||
| 
						 | 
					6d1e5cb5dc | 
							
								
								
									
										52
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										52
									
								
								README.md
									
									
									
									
									
								
							@@ -38,33 +38,41 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 | 
			
		||||
  <a href="https://github.com/songquanpeng/one-api#截图展示">截图展示</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="https://openai.justsong.cn/">在线演示</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="https://github.com/songquanpeng/one-api#常见问题">常见问题</a>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
> **Warning**:从 `v0.2` 版本升级到 `v0.3` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.2-v0.3.sql)。
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## 功能
 | 
			
		||||
1. 支持多种 API 访问渠道,欢迎 PR 或提 issue 添加更多渠道:
 | 
			
		||||
   + [x] OpenAI 官方通道
 | 
			
		||||
   + [x] **Azure OpenAI API**
 | 
			
		||||
   + [x] [API2D](https://api2d.com/r/197971)
 | 
			
		||||
   + [x] [CloseAI](https://console.openai-asia.com)
 | 
			
		||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
 | 
			
		||||
   + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
 | 
			
		||||
   + [x] [AI.LS](https://ai.ls)
 | 
			
		||||
   + [x] [OpenAI Max](https://openaimax.com)
 | 
			
		||||
   + [x] [OhMyGPT](https://www.ohmygpt.com)
 | 
			
		||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
   + [x] [CloseAI](https://console.openai-asia.com/r/2412)
 | 
			
		||||
   + [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理
 | 
			
		||||
2. 支持通过**负载均衡**的方式访问多个渠道。
 | 
			
		||||
3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
			
		||||
4. 支持**令牌管理**,设置令牌的过期时间和使用次数。
 | 
			
		||||
5. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为令牌进行充值。
 | 
			
		||||
6. 支持**通道管理**,批量创建通道。
 | 
			
		||||
7. 支持发布公告,设置充值链接,设置新用户初始额度。
 | 
			
		||||
8. 支持丰富的**自定义**设置,
 | 
			
		||||
4. 支持**多机部署**,[详见此处](#多机部署)。
 | 
			
		||||
5. 支持**令牌管理**,设置令牌的过期时间和使用次数。
 | 
			
		||||
6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
 | 
			
		||||
7. 支持**通道管理**,批量创建通道。
 | 
			
		||||
8. 支持发布公告,设置充值链接,设置新用户初始额度。
 | 
			
		||||
9. 支持丰富的**自定义**设置,
 | 
			
		||||
   1. 支持自定义系统名称,logo 以及页脚。
 | 
			
		||||
   2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
 | 
			
		||||
9. 支持通过系统访问令牌访问管理 API。
 | 
			
		||||
10. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
10. 支持通过系统访问令牌访问管理 API。
 | 
			
		||||
11. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
    + 邮箱登录注册以及通过邮箱进行密码重置。
 | 
			
		||||
    + [GitHub 开放授权](https://github.com/settings/applications/new)。
 | 
			
		||||
    + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
 | 
			
		||||
11. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
 | 
			
		||||
12. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
 | 
			
		||||
 | 
			
		||||
## 部署
 | 
			
		||||
### 基于 Docker 进行部署
 | 
			
		||||
@@ -87,13 +95,10 @@ server{
 | 
			
		||||
          proxy_set_header X-Forwarded-For $remote_addr;
 | 
			
		||||
          proxy_cache_bypass $http_upgrade;
 | 
			
		||||
          proxy_set_header Accept-Encoding gzip;
 | 
			
		||||
          proxy_buffering off;  # 重要:关闭代理缓冲
 | 
			
		||||
   }
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
注意,为了 SSE 正常工作,需要关闭 Nginx 的代理缓冲。
 | 
			
		||||
 | 
			
		||||
之后使用 Let's Encrypt 的 certbot 配置 HTTPS:
 | 
			
		||||
```bash
 | 
			
		||||
# Ubuntu 安装 certbot:
 | 
			
		||||
@@ -130,6 +135,14 @@ sudo service nginx restart
 | 
			
		||||
 | 
			
		||||
更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。
 | 
			
		||||
 | 
			
		||||
### 多机部署
 | 
			
		||||
1. 所有服务器 `SESSION_SECRET` 设置一样的值。
 | 
			
		||||
2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite,请自行配置主备数据库同步。
 | 
			
		||||
3. 所有从服务器必须设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。
 | 
			
		||||
4. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。
 | 
			
		||||
 | 
			
		||||
环境变量的具体使用方法详见[此处](#环境变量)。
 | 
			
		||||
 | 
			
		||||
## 配置
 | 
			
		||||
系统本身开箱即用。
 | 
			
		||||
 | 
			
		||||
@@ -154,6 +167,10 @@ sudo service nginx restart
 | 
			
		||||
   + 例子:`SESSION_SECRET=random_string`
 | 
			
		||||
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite。
 | 
			
		||||
   + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/one-api`
 | 
			
		||||
4. `FRONTEND_BASE_URL`:设置之后将使用指定的前端地址,而非后端地址。
 | 
			
		||||
   + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
 | 
			
		||||
5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。
 | 
			
		||||
   + 例子:`SYNC_FREQUENCY=60`
 | 
			
		||||
 | 
			
		||||
### 命令行参数
 | 
			
		||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 | 
			
		||||
@@ -171,3 +188,10 @@ https://openai.justsong.cn
 | 
			
		||||
### 截图展示
 | 
			
		||||

 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
## 常见问题
 | 
			
		||||
1. 账户额度足够为什么提示额度不足?
 | 
			
		||||
   + 请检查你的令牌额度是否足够,这个和账户额度是分开的。
 | 
			
		||||
   + 令牌额度仅供用户设置最大使用量,用户可自由设置。
 | 
			
		||||
2. 宝塔部署后访问出现空白页面?
 | 
			
		||||
   + 自动配置的问题,详见[#97](https://github.com/songquanpeng/one-api/issues/97)。
 | 
			
		||||
							
								
								
									
										6
									
								
								bin/migration_v0.2-v0.3.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								bin/migration_v0.2-v0.3.sql
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
			
		||||
UPDATE users
 | 
			
		||||
SET quota = quota + (
 | 
			
		||||
    SELECT SUM(remain_quota)
 | 
			
		||||
    FROM tokens
 | 
			
		||||
    WHERE tokens.user_id = users.id
 | 
			
		||||
)
 | 
			
		||||
@@ -54,6 +54,7 @@ var QuotaForNewUser = 0
 | 
			
		||||
var ChannelDisableThreshold = 5.0
 | 
			
		||||
var AutomaticDisableChannelEnabled = false
 | 
			
		||||
var QuotaRemindThreshold = 1000
 | 
			
		||||
var PreConsumedQuota = 500
 | 
			
		||||
 | 
			
		||||
var RootUserEmail = ""
 | 
			
		||||
 | 
			
		||||
@@ -126,16 +127,22 @@ const (
 | 
			
		||||
	ChannelTypeOpenAIMax = 6
 | 
			
		||||
	ChannelTypeOhMyGPT   = 7
 | 
			
		||||
	ChannelTypeCustom    = 8
 | 
			
		||||
	ChannelTypeAILS      = 9
 | 
			
		||||
	ChannelTypeAIProxy   = 10
 | 
			
		||||
	ChannelTypePaLM      = 11
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ChannelBaseURLs = []string{
 | 
			
		||||
	"",                            // 0
 | 
			
		||||
	"https://api.openai.com",      // 1
 | 
			
		||||
	"https://openai.api2d.net",    // 2
 | 
			
		||||
	"https://oa.api2d.net",        // 2
 | 
			
		||||
	"",                            // 3
 | 
			
		||||
	"https://api.openai-asia.com", // 4
 | 
			
		||||
	"https://api.openai-sb.com",   // 5
 | 
			
		||||
	"https://api.openaimax.com",   // 6
 | 
			
		||||
	"https://api.ohmygpt.com",     // 7
 | 
			
		||||
	"",                            // 8
 | 
			
		||||
	"https://api.caipacity.com",   // 9
 | 
			
		||||
	"https://api.aiproxy.io",      // 10
 | 
			
		||||
	"",                            // 11
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										41
									
								
								controller/billing.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								controller/billing.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetSubscription(c *gin.Context) {
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	quota, err := model.GetUserQuota(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		openAIError := OpenAIError{
 | 
			
		||||
			Message: err.Error(),
 | 
			
		||||
			Type:    "one_api_error",
 | 
			
		||||
		}
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
			"error": openAIError,
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	subscription := OpenAISubscriptionResponse{
 | 
			
		||||
		Object:             "billing_subscription",
 | 
			
		||||
		HasPaymentMethod:   true,
 | 
			
		||||
		SoftLimitUSD:       float64(quota),
 | 
			
		||||
		HardLimitUSD:       float64(quota),
 | 
			
		||||
		SystemHardLimitUSD: float64(quota),
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, subscription)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUsage(c *gin.Context) {
 | 
			
		||||
	//userId := c.GetInt("id")
 | 
			
		||||
	// TODO: get usage from database
 | 
			
		||||
	usage := OpenAIUsageResponse{
 | 
			
		||||
		Object:     "list",
 | 
			
		||||
		TotalUsage: 0,
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, usage)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										175
									
								
								controller/channel-billing.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										175
									
								
								controller/channel-billing.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,175 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://github.com/songquanpeng/one-api/issues/79
 | 
			
		||||
 | 
			
		||||
type OpenAISubscriptionResponse struct {
 | 
			
		||||
	Object             string  `json:"object"`
 | 
			
		||||
	HasPaymentMethod   bool    `json:"has_payment_method"`
 | 
			
		||||
	SoftLimitUSD       float64 `json:"soft_limit_usd"`
 | 
			
		||||
	HardLimitUSD       float64 `json:"hard_limit_usd"`
 | 
			
		||||
	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIUsageDailyCost struct {
 | 
			
		||||
	Timestamp float64 `json:"timestamp"`
 | 
			
		||||
	LineItems []struct {
 | 
			
		||||
		Name string  `json:"name"`
 | 
			
		||||
		Cost float64 `json:"cost"`
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIUsageResponse struct {
 | 
			
		||||
	Object string `json:"object"`
 | 
			
		||||
	//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
 | 
			
		||||
	TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case common.ChannelTypeAzure:
 | 
			
		||||
		return 0, errors.New("尚未实现")
 | 
			
		||||
	case common.ChannelTypeCustom:
 | 
			
		||||
		baseURL = channel.BaseURL
 | 
			
		||||
	}
 | 
			
		||||
	url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
 | 
			
		||||
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	req, err := http.NewRequest("GET", url, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	auth := fmt.Sprintf("Bearer %s", channel.Key)
 | 
			
		||||
	req.Header.Add("Authorization", auth)
 | 
			
		||||
	res, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	body, err := io.ReadAll(res.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	err = res.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	subscription := OpenAISubscriptionResponse{}
 | 
			
		||||
	err = json.Unmarshal(body, &subscription)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
 | 
			
		||||
	//endDate := now.Format("2006-01-02")
 | 
			
		||||
	url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, "2023-06-01")
 | 
			
		||||
	req, err = http.NewRequest("GET", url, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Add("Authorization", auth)
 | 
			
		||||
	res, err = client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	body, err = io.ReadAll(res.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	err = res.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	usage := OpenAIUsageResponse{}
 | 
			
		||||
	err = json.Unmarshal(body, &usage)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	balance := subscription.HardLimitUSD - usage.TotalUsage/100
 | 
			
		||||
	channel.UpdateBalance(balance)
 | 
			
		||||
	return balance, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateChannelBalance(c *gin.Context) {
 | 
			
		||||
	id, err := strconv.Atoi(c.Param("id"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	channel, err := model.GetChannelById(id, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	balance, err := updateChannelBalance(channel)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"balance": balance,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateAllChannelsBalance() error {
 | 
			
		||||
	channels, err := model.GetAllChannels(0, 0, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	for _, channel := range channels {
 | 
			
		||||
		if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		balance, err := updateChannelBalance(channel)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		} else {
 | 
			
		||||
			// err is nil & balance <= 0 means quota is used up
 | 
			
		||||
			if balance <= 0 {
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, "余额不足")
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateAllChannelsBalance(c *gin.Context) {
 | 
			
		||||
	// TODO: make it async
 | 
			
		||||
	err := updateAllChannelsBalance()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@@ -201,7 +201,7 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if response.Error.Type != "" {
 | 
			
		||||
	if response.Error.Message != "" || response.Error.Code != "" {
 | 
			
		||||
		return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
@@ -210,11 +210,12 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
 | 
			
		||||
func buildTestRequest(c *gin.Context) *ChatRequest {
 | 
			
		||||
	model_ := c.Query("model")
 | 
			
		||||
	testRequest := &ChatRequest{
 | 
			
		||||
		Model: model_,
 | 
			
		||||
		Model:     model_,
 | 
			
		||||
		MaxTokens: 1,
 | 
			
		||||
	}
 | 
			
		||||
	testMessage := Message{
 | 
			
		||||
		Role:    "user",
 | 
			
		||||
		Content: "echo hi",
 | 
			
		||||
		Content: "hi",
 | 
			
		||||
	}
 | 
			
		||||
	testRequest.Messages = append(testRequest.Messages, testMessage)
 | 
			
		||||
	return testRequest
 | 
			
		||||
@@ -264,14 +265,14 @@ var testAllChannelsLock sync.Mutex
 | 
			
		||||
var testAllChannelsRunning bool = false
 | 
			
		||||
 | 
			
		||||
// disable & notify
 | 
			
		||||
func disableChannel(channelId int, channelName string, err error) {
 | 
			
		||||
func disableChannel(channelId int, channelName string, reason string) {
 | 
			
		||||
	if common.RootUserEmail == "" {
 | 
			
		||||
		common.RootUserEmail = model.GetRootUserEmail()
 | 
			
		||||
	}
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
 | 
			
		||||
	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
 | 
			
		||||
	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, err.Error())
 | 
			
		||||
	err = common.SendEmail(subject, common.RootUserEmail, content)
 | 
			
		||||
	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
 | 
			
		||||
	err := common.SendEmail(subject, common.RootUserEmail, content)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
 | 
			
		||||
	}
 | 
			
		||||
@@ -311,7 +312,7 @@ func testAllChannels(c *gin.Context) error {
 | 
			
		||||
				if milliseconds > disableThreshold {
 | 
			
		||||
					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
			
		||||
				}
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, err)
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			channel.UpdateResponseTime(milliseconds)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										153
									
								
								controller/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										153
									
								
								controller/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,153 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://platform.openai.com/docs/api-reference/models/list
 | 
			
		||||
 | 
			
		||||
type OpenAIModelPermission struct {
 | 
			
		||||
	Id                 string  `json:"id"`
 | 
			
		||||
	Object             string  `json:"object"`
 | 
			
		||||
	Created            int     `json:"created"`
 | 
			
		||||
	AllowCreateEngine  bool    `json:"allow_create_engine"`
 | 
			
		||||
	AllowSampling      bool    `json:"allow_sampling"`
 | 
			
		||||
	AllowLogprobs      bool    `json:"allow_logprobs"`
 | 
			
		||||
	AllowSearchIndices bool    `json:"allow_search_indices"`
 | 
			
		||||
	AllowView          bool    `json:"allow_view"`
 | 
			
		||||
	AllowFineTuning    bool    `json:"allow_fine_tuning"`
 | 
			
		||||
	Organization       string  `json:"organization"`
 | 
			
		||||
	Group              *string `json:"group"`
 | 
			
		||||
	IsBlocking         bool    `json:"is_blocking"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIModels struct {
 | 
			
		||||
	Id         string                `json:"id"`
 | 
			
		||||
	Object     string                `json:"object"`
 | 
			
		||||
	Created    int                   `json:"created"`
 | 
			
		||||
	OwnedBy    string                `json:"owned_by"`
 | 
			
		||||
	Permission OpenAIModelPermission `json:"permission"`
 | 
			
		||||
	Root       string                `json:"root"`
 | 
			
		||||
	Parent     *string               `json:"parent"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var openAIModels []OpenAIModels
 | 
			
		||||
var openAIModelsMap map[string]OpenAIModels
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	permission := OpenAIModelPermission{
 | 
			
		||||
		Id:                 "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
 | 
			
		||||
		Object:             "model_permission",
 | 
			
		||||
		Created:            1626777600,
 | 
			
		||||
		AllowCreateEngine:  true,
 | 
			
		||||
		AllowSampling:      true,
 | 
			
		||||
		AllowLogprobs:      true,
 | 
			
		||||
		AllowSearchIndices: false,
 | 
			
		||||
		AllowView:          true,
 | 
			
		||||
		AllowFineTuning:    false,
 | 
			
		||||
		Organization:       "*",
 | 
			
		||||
		Group:              nil,
 | 
			
		||||
		IsBlocking:         false,
 | 
			
		||||
	}
 | 
			
		||||
	// https://platform.openai.com/docs/models/model-endpoint-compatibility
 | 
			
		||||
	openAIModels = []OpenAIModels{
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-3.5-turbo",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-3.5-turbo",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-3.5-turbo-0301",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-3.5-turbo-0301",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-4",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-4",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-4-0314",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-4-0314",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-4-32k",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-4-32k",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-4-32k-0314",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-4-32k-0314",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-3.5-turbo",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-3.5-turbo",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "text-embedding-ada-002",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "text-embedding-ada-002",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	openAIModelsMap = make(map[string]OpenAIModels)
 | 
			
		||||
	for _, model := range openAIModels {
 | 
			
		||||
		openAIModelsMap[model.Id] = model
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ListModels(c *gin.Context) {
 | 
			
		||||
	c.JSON(200, openAIModels)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RetrieveModel(c *gin.Context) {
 | 
			
		||||
	modelId := c.Param("model")
 | 
			
		||||
	if model, ok := openAIModelsMap[modelId]; ok {
 | 
			
		||||
		c.JSON(200, model)
 | 
			
		||||
	} else {
 | 
			
		||||
		openAIError := OpenAIError{
 | 
			
		||||
			Message: fmt.Sprintf("The model '%s' does not exist", modelId),
 | 
			
		||||
			Type:    "invalid_request_error",
 | 
			
		||||
			Param:   "model",
 | 
			
		||||
			Code:    "model_not_found",
 | 
			
		||||
		}
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
			"error": openAIError,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										59
									
								
								controller/relay-palm.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								controller/relay-palm.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,59 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type PaLMChatMessage struct {
 | 
			
		||||
	Author  string `json:"author"`
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PaLMFilter struct {
 | 
			
		||||
	Reason  string `json:"reason"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 | 
			
		||||
type PaLMChatRequest struct {
 | 
			
		||||
	Prompt         []Message `json:"prompt"`
 | 
			
		||||
	Temperature    float64   `json:"temperature"`
 | 
			
		||||
	CandidateCount int       `json:"candidateCount"`
 | 
			
		||||
	TopP           float64   `json:"topP"`
 | 
			
		||||
	TopK           int       `json:"topK"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
 | 
			
		||||
type PaLMChatResponse struct {
 | 
			
		||||
	Candidates []Message    `json:"candidates"`
 | 
			
		||||
	Messages   []Message    `json:"messages"`
 | 
			
		||||
	Filters    []PaLMFilter `json:"filters"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
 | 
			
		||||
	messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages))
 | 
			
		||||
	for _, message := range openAIRequest.Messages {
 | 
			
		||||
		var author string
 | 
			
		||||
		if message.Role == "user" {
 | 
			
		||||
			author = "0"
 | 
			
		||||
		} else {
 | 
			
		||||
			author = "1"
 | 
			
		||||
		}
 | 
			
		||||
		messages = append(messages, PaLMChatMessage{
 | 
			
		||||
			Author:  author,
 | 
			
		||||
			Content: message.Content,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	request := PaLMChatRequest{
 | 
			
		||||
		Prompt:         nil,
 | 
			
		||||
		Temperature:    openAIRequest.Temperature,
 | 
			
		||||
		CandidateCount: openAIRequest.N,
 | 
			
		||||
		TopP:           openAIRequest.TopP,
 | 
			
		||||
		TopK:           openAIRequest.MaxTokens,
 | 
			
		||||
	}
 | 
			
		||||
	// TODO: forward request to PaLM & convert response
 | 
			
		||||
	fmt.Print(request)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										65
									
								
								controller/relay-utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								controller/relay-utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,65 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/pkoukk/tiktoken-go"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 | 
			
		||||
 | 
			
		||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
 | 
			
		||||
	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
 | 
			
		||||
		return tokenEncoder
 | 
			
		||||
	}
 | 
			
		||||
	tokenEncoder, err := tiktoken.EncodingForModel(model)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
 | 
			
		||||
		tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tokenEncoderMap[model] = tokenEncoder
 | 
			
		||||
	return tokenEncoder
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func countTokenMessages(messages []Message, model string) int {
 | 
			
		||||
	tokenEncoder := getTokenEncoder(model)
 | 
			
		||||
	// Reference:
 | 
			
		||||
	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 | 
			
		||||
	// https://github.com/pkoukk/tiktoken-go/issues/6
 | 
			
		||||
	//
 | 
			
		||||
	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
 | 
			
		||||
	var tokensPerMessage int
 | 
			
		||||
	var tokensPerName int
 | 
			
		||||
	if strings.HasPrefix(model, "gpt-3.5") {
 | 
			
		||||
		tokensPerMessage = 4
 | 
			
		||||
		tokensPerName = -1 // If there's a name, the role is omitted
 | 
			
		||||
	} else if strings.HasPrefix(model, "gpt-4") {
 | 
			
		||||
		tokensPerMessage = 3
 | 
			
		||||
		tokensPerName = 1
 | 
			
		||||
	} else {
 | 
			
		||||
		tokensPerMessage = 3
 | 
			
		||||
		tokensPerName = 1
 | 
			
		||||
	}
 | 
			
		||||
	tokenNum := 0
 | 
			
		||||
	for _, message := range messages {
 | 
			
		||||
		tokenNum += tokensPerMessage
 | 
			
		||||
		tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
 | 
			
		||||
		tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
 | 
			
		||||
		if message.Name != nil {
 | 
			
		||||
			tokenNum += tokensPerName
 | 
			
		||||
			tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
 | 
			
		||||
	return tokenNum
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func countTokenText(text string, model string) int {
 | 
			
		||||
	tokenEncoder := getTokenEncoder(model)
 | 
			
		||||
	token := tokenEncoder.Encode(text, nil, nil)
 | 
			
		||||
	return len(token)
 | 
			
		||||
}
 | 
			
		||||
@@ -4,10 +4,8 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkoukk/tiktoken-go"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
@@ -16,19 +14,35 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
	Role    string `json:"role"`
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
	Role    string  `json:"role"`
 | 
			
		||||
	Content string  `json:"content"`
 | 
			
		||||
	Name    *string `json:"name,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// https://platform.openai.com/docs/api-reference/chat
 | 
			
		||||
 | 
			
		||||
type GeneralOpenAIRequest struct {
 | 
			
		||||
	Model       string    `json:"model"`
 | 
			
		||||
	Messages    []Message `json:"messages"`
 | 
			
		||||
	Prompt      string    `json:"prompt"`
 | 
			
		||||
	Stream      bool      `json:"stream"`
 | 
			
		||||
	MaxTokens   int       `json:"max_tokens"`
 | 
			
		||||
	Temperature float64   `json:"temperature"`
 | 
			
		||||
	TopP        float64   `json:"top_p"`
 | 
			
		||||
	N           int       `json:"n"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatRequest struct {
 | 
			
		||||
	Model    string    `json:"model"`
 | 
			
		||||
	Messages []Message `json:"messages"`
 | 
			
		||||
	Model     string    `json:"model"`
 | 
			
		||||
	Messages  []Message `json:"messages"`
 | 
			
		||||
	MaxTokens int       `json:"max_tokens"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TextRequest struct {
 | 
			
		||||
	Model    string    `json:"model"`
 | 
			
		||||
	Messages []Message `json:"messages"`
 | 
			
		||||
	Prompt   string    `json:"prompt"`
 | 
			
		||||
	Model     string    `json:"model"`
 | 
			
		||||
	Messages  []Message `json:"messages"`
 | 
			
		||||
	Prompt    string    `json:"prompt"`
 | 
			
		||||
	MaxTokens int       `json:"max_tokens"`
 | 
			
		||||
	//Stream   bool      `json:"stream"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -45,6 +59,11 @@ type OpenAIError struct {
 | 
			
		||||
	Code    string `json:"code"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIErrorWithStatusCode struct {
 | 
			
		||||
	OpenAIError
 | 
			
		||||
	StatusCode int `json:"status_code"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TextResponse struct {
 | 
			
		||||
	Usage `json:"usage"`
 | 
			
		||||
	Error OpenAIError `json:"error"`
 | 
			
		||||
@@ -59,47 +78,55 @@ type StreamResponse struct {
 | 
			
		||||
	} `json:"choices"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
 | 
			
		||||
 | 
			
		||||
func countToken(text string) int {
 | 
			
		||||
	token := tokenEncoder.Encode(text, nil, nil)
 | 
			
		||||
	return len(token)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Relay(c *gin.Context) {
 | 
			
		||||
	err := relayHelper(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"error": gin.H{
 | 
			
		||||
				"message": err.Error(),
 | 
			
		||||
				"type":    "one_api_error",
 | 
			
		||||
			},
 | 
			
		||||
		if err.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
			err.OpenAIError.Message = "负载已满,请稍后再试,或升级账户以提升服务质量。"
 | 
			
		||||
		}
 | 
			
		||||
		c.JSON(err.StatusCode, gin.H{
 | 
			
		||||
			"error": err.OpenAIError,
 | 
			
		||||
		})
 | 
			
		||||
		if common.AutomaticDisableChannelEnabled {
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
			
		||||
		if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key") {
 | 
			
		||||
			channelId := c.GetInt("channel_id")
 | 
			
		||||
			channelName := c.GetString("channel_name")
 | 
			
		||||
			disableChannel(channelId, channelName, err)
 | 
			
		||||
			disableChannel(channelId, channelName, err.Message)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayHelper(c *gin.Context) error {
 | 
			
		||||
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	openAIError := OpenAIError{
 | 
			
		||||
		Message: err.Error(),
 | 
			
		||||
		Type:    "one_api_error",
 | 
			
		||||
		Code:    code,
 | 
			
		||||
	}
 | 
			
		||||
	return &OpenAIErrorWithStatusCode{
 | 
			
		||||
		OpenAIError: openAIError,
 | 
			
		||||
		StatusCode:  statusCode,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	consumeQuota := c.GetBool("consume_quota")
 | 
			
		||||
	var textRequest TextRequest
 | 
			
		||||
	if consumeQuota || channelType == common.ChannelTypeAzure {
 | 
			
		||||
	var textRequest GeneralOpenAIRequest
 | 
			
		||||
	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
 | 
			
		||||
		requestBody, err := io.ReadAll(c.Request.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		err = c.Request.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "close_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(requestBody, &textRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "unmarshal_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		// Reset request body
 | 
			
		||||
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
@@ -127,10 +154,27 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
		model_ = strings.TrimSuffix(model_, "-0301")
 | 
			
		||||
		model_ = strings.TrimSuffix(model_, "-0314")
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 | 
			
		||||
	} else if channelType == common.ChannelTypePaLM {
 | 
			
		||||
		err := relayPaLM(textRequest, c)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
 | 
			
		||||
	preConsumedTokens := common.PreConsumedQuota
 | 
			
		||||
	if textRequest.MaxTokens != 0 {
 | 
			
		||||
		preConsumedTokens = promptTokens + textRequest.MaxTokens
 | 
			
		||||
	}
 | 
			
		||||
	ratio := common.GetModelRatio(textRequest.Model)
 | 
			
		||||
	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "new_request_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	if channelType == common.ChannelTypeAzure {
 | 
			
		||||
		key := c.Request.Header.Get("Authorization")
 | 
			
		||||
@@ -145,18 +189,18 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "do_request_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	err = req.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	err = c.Request.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	var textResponse TextResponse
 | 
			
		||||
	isStream := resp.Header.Get("Content-Type") == "text/event-stream"
 | 
			
		||||
	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
	var streamResponseText string
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
@@ -168,18 +212,14 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
				completionRatio = 2
 | 
			
		||||
			}
 | 
			
		||||
			if isStream {
 | 
			
		||||
				var promptText string
 | 
			
		||||
				for _, message := range textRequest.Messages {
 | 
			
		||||
					promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
 | 
			
		||||
				}
 | 
			
		||||
				completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
 | 
			
		||||
				quota = countToken(promptText) + countToken(completionText)*completionRatio + 3
 | 
			
		||||
				responseTokens := countTokenText(streamResponseText, textRequest.Model)
 | 
			
		||||
				quota = promptTokens + responseTokens*completionRatio
 | 
			
		||||
			} else {
 | 
			
		||||
				quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 | 
			
		||||
			}
 | 
			
		||||
			ratio := common.GetModelRatio(textRequest.Model)
 | 
			
		||||
			quota = int(float64(quota) * ratio)
 | 
			
		||||
			err := model.DecreaseTokenQuota(tokenId, quota)
 | 
			
		||||
			quotaDelta := quota - preConsumedQuota
 | 
			
		||||
			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("Error consuming token remain quota: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
@@ -208,6 +248,10 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
		go func() {
 | 
			
		||||
			for scanner.Scan() {
 | 
			
		||||
				data := scanner.Text()
 | 
			
		||||
				if len(data) < 6 { // must be something wrong!
 | 
			
		||||
					common.SysError("Invalid stream response: " + data)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				dataChan <- data
 | 
			
		||||
				data = data[6:]
 | 
			
		||||
				if !strings.HasPrefix(data, "[DONE]") {
 | 
			
		||||
@@ -228,6 +272,7 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
		c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
		c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
		c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
		c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
		c.Stream(func(w io.Writer) bool {
 | 
			
		||||
			select {
 | 
			
		||||
			case data := <-dataChan:
 | 
			
		||||
@@ -242,50 +287,60 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
		})
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	} else {
 | 
			
		||||
		for k, v := range resp.Header {
 | 
			
		||||
			c.Writer.Header().Set(k, v[0])
 | 
			
		||||
		}
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
			responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
				return errorWrapper(err, "read_response_body_failed", http.StatusOK)
 | 
			
		||||
			}
 | 
			
		||||
			err = resp.Body.Close()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
				return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
			
		||||
			}
 | 
			
		||||
			err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
 | 
			
		||||
			}
 | 
			
		||||
			if textResponse.Error.Type != "" {
 | 
			
		||||
				return errors.New(fmt.Sprintf("type %s, code %s, message %s",
 | 
			
		||||
					textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message))
 | 
			
		||||
				return &OpenAIErrorWithStatusCode{
 | 
			
		||||
					OpenAIError: textResponse.Error,
 | 
			
		||||
					StatusCode:  resp.StatusCode,
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			// Reset response body
 | 
			
		||||
			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
		}
 | 
			
		||||
		// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
			
		||||
		// And then we will have to send an error response, but in this case, the header has already been set.
 | 
			
		||||
		// So the client will be confused by the response.
 | 
			
		||||
		// For example, Postman will report error, and we cannot check the response at all.
 | 
			
		||||
		for k, v := range resp.Header {
 | 
			
		||||
			c.Writer.Header().Set(k, v[0])
 | 
			
		||||
		}
 | 
			
		||||
		c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
		_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayNotImplemented(c *gin.Context) {
 | 
			
		||||
	err := OpenAIError{
 | 
			
		||||
		Message: "API not implemented",
 | 
			
		||||
		Type:    "one_api_error",
 | 
			
		||||
		Param:   "",
 | 
			
		||||
		Code:    "api_not_implemented",
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"error": gin.H{
 | 
			
		||||
			"message": "Not Implemented",
 | 
			
		||||
			"type":    "one_api_error",
 | 
			
		||||
		},
 | 
			
		||||
		"error": err,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -160,7 +160,6 @@ func DeleteToken(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateToken(c *gin.Context) {
 | 
			
		||||
	isAdmin := c.GetInt("role") >= common.RoleAdminUser
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	statusOnly := c.Query("status_only")
 | 
			
		||||
	token := model.Token{}
 | 
			
		||||
@@ -191,7 +190,7 @@ func UpdateToken(c *gin.Context) {
 | 
			
		||||
		if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "令牌可用次数已用尽,无法启用,请先修改令牌剩余次数,或者设置为无限次数",
 | 
			
		||||
				"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
@@ -202,10 +201,8 @@ func UpdateToken(c *gin.Context) {
 | 
			
		||||
		// If you add more fields, please also update token.Update()
 | 
			
		||||
		cleanToken.Name = token.Name
 | 
			
		||||
		cleanToken.ExpiredTime = token.ExpiredTime
 | 
			
		||||
		if isAdmin {
 | 
			
		||||
			cleanToken.RemainQuota = token.RemainQuota
 | 
			
		||||
			cleanToken.UnlimitedQuota = token.UnlimitedQuota
 | 
			
		||||
		}
 | 
			
		||||
		cleanToken.RemainQuota = token.RemainQuota
 | 
			
		||||
		cleanToken.UnlimitedQuota = token.UnlimitedQuota
 | 
			
		||||
	}
 | 
			
		||||
	err = cleanToken.Update()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -467,6 +467,13 @@ func CreateUser(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if err := common.Validate.Struct(&user); err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "输入不合法 " + err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if user.DisplayName == "" {
 | 
			
		||||
		user.DisplayName = user.Username
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										7
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								main.go
									
									
									
									
									
								
							@@ -47,6 +47,13 @@ func main() {
 | 
			
		||||
 | 
			
		||||
	// Initialize options
 | 
			
		||||
	model.InitOptionMap()
 | 
			
		||||
	if os.Getenv("SYNC_FREQUENCY") != "" {
 | 
			
		||||
		frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.FatalLog(err)
 | 
			
		||||
		}
 | 
			
		||||
		go model.SyncOptions(frequency)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Initialize HTTP server
 | 
			
		||||
	server := gin.Default()
 | 
			
		||||
 
 | 
			
		||||
@@ -85,6 +85,8 @@ func RootAuth() func(c *gin.Context) {
 | 
			
		||||
func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		key := c.Request.Header.Get("Authorization")
 | 
			
		||||
		key = strings.TrimPrefix(key, "Bearer ")
 | 
			
		||||
		key = strings.TrimPrefix(key, "sk-")
 | 
			
		||||
		parts := strings.Split(key, "-")
 | 
			
		||||
		key = parts[0]
 | 
			
		||||
		token, err := model.ValidateUserToken(key)
 | 
			
		||||
@@ -111,7 +113,7 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
		c.Set("id", token.UserId)
 | 
			
		||||
		c.Set("token_id", token.Id)
 | 
			
		||||
		requestURL := c.Request.URL.String()
 | 
			
		||||
		consumeQuota := !token.UnlimitedQuota
 | 
			
		||||
		consumeQuota := true
 | 
			
		||||
		if strings.HasPrefix(requestURL, "/v1/models") {
 | 
			
		||||
			consumeQuota = false
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,11 @@ import (
 | 
			
		||||
 | 
			
		||||
func Cache() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		c.Header("Cache-Control", "max-age=604800") // one week
 | 
			
		||||
		if c.Request.RequestURI == "/" {
 | 
			
		||||
			c.Header("Cache-Control", "no-cache")
 | 
			
		||||
		} else {
 | 
			
		||||
			c.Header("Cache-Control", "max-age=604800") // one week
 | 
			
		||||
		}
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,17 +6,19 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Channel struct {
 | 
			
		||||
	Id           int    `json:"id"`
 | 
			
		||||
	Type         int    `json:"type" gorm:"default:0"`
 | 
			
		||||
	Key          string `json:"key" gorm:"not null"`
 | 
			
		||||
	Status       int    `json:"status" gorm:"default:1"`
 | 
			
		||||
	Name         string `json:"name" gorm:"index"`
 | 
			
		||||
	Weight       int    `json:"weight"`
 | 
			
		||||
	CreatedTime  int64  `json:"created_time" gorm:"bigint"`
 | 
			
		||||
	TestTime     int64  `json:"test_time" gorm:"bigint"`
 | 
			
		||||
	ResponseTime int    `json:"response_time"` // in milliseconds
 | 
			
		||||
	BaseURL      string `json:"base_url" gorm:"column:base_url"`
 | 
			
		||||
	Other        string `json:"other"`
 | 
			
		||||
	Id                 int     `json:"id"`
 | 
			
		||||
	Type               int     `json:"type" gorm:"default:0"`
 | 
			
		||||
	Key                string  `json:"key" gorm:"not null"`
 | 
			
		||||
	Status             int     `json:"status" gorm:"default:1"`
 | 
			
		||||
	Name               string  `json:"name" gorm:"index"`
 | 
			
		||||
	Weight             int     `json:"weight"`
 | 
			
		||||
	CreatedTime        int64   `json:"created_time" gorm:"bigint"`
 | 
			
		||||
	TestTime           int64   `json:"test_time" gorm:"bigint"`
 | 
			
		||||
	ResponseTime       int     `json:"response_time"` // in milliseconds
 | 
			
		||||
	BaseURL            string  `json:"base_url" gorm:"column:base_url"`
 | 
			
		||||
	Other              string  `json:"other"`
 | 
			
		||||
	Balance            float64 `json:"balance"` // in USD
 | 
			
		||||
	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 | 
			
		||||
@@ -86,6 +88,16 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) UpdateBalance(balance float64) {
 | 
			
		||||
	err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
 | 
			
		||||
		BalanceUpdatedTime: common.GetTimestamp(),
 | 
			
		||||
		Balance:            balance,
 | 
			
		||||
	}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to update balance: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) Delete() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Delete(channel).Error
 | 
			
		||||
 
 | 
			
		||||
@@ -26,6 +26,7 @@ func createRootAccountIfNeed() error {
 | 
			
		||||
			Status:      common.UserStatusEnabled,
 | 
			
		||||
			DisplayName: "Root User",
 | 
			
		||||
			AccessToken: common.GetUUID(),
 | 
			
		||||
			Quota:       100000000,
 | 
			
		||||
		}
 | 
			
		||||
		DB.Create(&rootUser)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Option struct {
 | 
			
		||||
@@ -55,9 +56,14 @@ func InitOptionMap() {
 | 
			
		||||
	common.OptionMap["TurnstileSecretKey"] = ""
 | 
			
		||||
	common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
 | 
			
		||||
	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
 | 
			
		||||
	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
 | 
			
		||||
	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
 | 
			
		||||
	common.OptionMap["TopUpLink"] = common.TopUpLink
 | 
			
		||||
	common.OptionMapRWMutex.Unlock()
 | 
			
		||||
	loadOptionsFromDatabase()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func loadOptionsFromDatabase() {
 | 
			
		||||
	options, _ := AllOption()
 | 
			
		||||
	for _, option := range options {
 | 
			
		||||
		err := updateOptionMap(option.Key, option.Value)
 | 
			
		||||
@@ -67,6 +73,14 @@ func InitOptionMap() {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SyncOptions(frequency int) {
 | 
			
		||||
	for {
 | 
			
		||||
		time.Sleep(time.Duration(frequency) * time.Second)
 | 
			
		||||
		common.SysLog("Syncing options from database")
 | 
			
		||||
		loadOptionsFromDatabase()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateOption(key string, value string) error {
 | 
			
		||||
	// Save to database first
 | 
			
		||||
	option := Option{
 | 
			
		||||
@@ -159,6 +173,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
		common.QuotaForNewUser, _ = strconv.Atoi(value)
 | 
			
		||||
	case "QuotaRemindThreshold":
 | 
			
		||||
		common.QuotaRemindThreshold, _ = strconv.Atoi(value)
 | 
			
		||||
	case "PreConsumedQuota":
 | 
			
		||||
		common.PreConsumedQuota, _ = strconv.Atoi(value)
 | 
			
		||||
	case "ModelRatio":
 | 
			
		||||
		err = common.UpdateModelRatioByJSONString(value)
 | 
			
		||||
	case "TopUpLink":
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,6 @@ import (
 | 
			
		||||
	_ "gorm.io/driver/sqlite"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Token struct {
 | 
			
		||||
@@ -38,7 +37,6 @@ func ValidateUserToken(key string) (token *Token, err error) {
 | 
			
		||||
	if key == "" {
 | 
			
		||||
		return nil, errors.New("未提供 token")
 | 
			
		||||
	}
 | 
			
		||||
	key = strings.Replace(key, "Bearer ", "", 1)
 | 
			
		||||
	token = &Token{}
 | 
			
		||||
	err = DB.Where("`key` = ?", key).First(token).Error
 | 
			
		||||
	if err == nil {
 | 
			
		||||
@@ -130,7 +128,23 @@ func DeleteTokenById(id int, userId int) (err error) {
 | 
			
		||||
	return token.Delete()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DecreaseTokenQuota(tokenId int, quota int) (err error) {
 | 
			
		||||
func IncreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
	err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DecreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
	err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
@@ -138,7 +152,7 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if token.RemainQuota < quota {
 | 
			
		||||
	if !token.UnlimitedQuota && token.RemainQuota < quota {
 | 
			
		||||
		return errors.New("令牌额度不足")
 | 
			
		||||
	}
 | 
			
		||||
	userQuota, err := GetUserQuota(token.UserId)
 | 
			
		||||
@@ -163,17 +177,42 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
 | 
			
		||||
			if email != "" {
 | 
			
		||||
				topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
 | 
			
		||||
				err = common.SendEmail(prompt, email,
 | 
			
		||||
					fmt.Sprintf("%s,剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota-quota, topUpLink, topUpLink))
 | 
			
		||||
					fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("发送邮件失败:" + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	err = DB.Model(&Token{}).Where("id = ?", tokenId).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	if !token.UnlimitedQuota {
 | 
			
		||||
		err = DecreaseTokenQuota(tokenId, quota)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	err = DecreaseUserQuota(token.UserId, quota)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
 | 
			
		||||
	token, err := GetTokenById(tokenId)
 | 
			
		||||
	if quota > 0 {
 | 
			
		||||
		err = DecreaseUserQuota(token.UserId, quota)
 | 
			
		||||
	} else {
 | 
			
		||||
		err = IncreaseUserQuota(token.UserId, -quota)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if !token.UnlimitedQuota {
 | 
			
		||||
		if quota > 0 {
 | 
			
		||||
			err = DecreaseTokenQuota(tokenId, quota)
 | 
			
		||||
		} else {
 | 
			
		||||
			err = IncreaseTokenQuota(tokenId, -quota)
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -19,8 +19,7 @@ type User struct {
 | 
			
		||||
	Email            string `json:"email" gorm:"index" validate:"max=50"`
 | 
			
		||||
	GitHubId         string `json:"github_id" gorm:"column:github_id;index"`
 | 
			
		||||
	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"`
 | 
			
		||||
	VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
 | 
			
		||||
	Balance          int    `json:"balance" gorm:"type:int;default:0"`
 | 
			
		||||
	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database!
 | 
			
		||||
	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
 | 
			
		||||
	Quota            int    `json:"quota" gorm:"type:int;default:0"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -66,6 +66,8 @@ func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
			channelRoute.GET("/:id", controller.GetChannel)
 | 
			
		||||
			channelRoute.GET("/test", controller.TestAllChannels)
 | 
			
		||||
			channelRoute.GET("/test/:id", controller.TestChannel)
 | 
			
		||||
			channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance)
 | 
			
		||||
			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
 | 
			
		||||
			channelRoute.POST("/", controller.AddChannel)
 | 
			
		||||
			channelRoute.PUT("/", controller.UpdateChannel)
 | 
			
		||||
			channelRoute.DELETE("/:id", controller.DeleteChannel)
 | 
			
		||||
 
 | 
			
		||||
@@ -8,11 +8,14 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetDashboardRouter(router *gin.Engine) {
 | 
			
		||||
	apiRouter := router.Group("/dashboard")
 | 
			
		||||
	apiRouter := router.Group("/")
 | 
			
		||||
	apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
			
		||||
	apiRouter.Use(middleware.GlobalAPIRateLimit())
 | 
			
		||||
	apiRouter.Use(middleware.TokenAuth())
 | 
			
		||||
	{
 | 
			
		||||
		apiRouter.GET("/billing/credit_grants", controller.GetTokenStatus)
 | 
			
		||||
		apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription)
 | 
			
		||||
		apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription)
 | 
			
		||||
		apiRouter.GET("/dashboard/billing/usage", controller.GetUsage)
 | 
			
		||||
		apiRouter.GET("/v1/dashboard/billing/usage", controller.GetUsage)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,12 +2,24 @@ package router
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"embed"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 | 
			
		||||
	SetApiRouter(router)
 | 
			
		||||
	SetDashboardRouter(router)
 | 
			
		||||
	SetRelayRouter(router)
 | 
			
		||||
	setWebRouter(router, buildFS, indexPage)
 | 
			
		||||
	frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
 | 
			
		||||
	if frontendBaseUrl == "" {
 | 
			
		||||
		SetWebRouter(router, buildFS, indexPage)
 | 
			
		||||
	} else {
 | 
			
		||||
		frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/")
 | 
			
		||||
		router.NoRoute(func(c *gin.Context) {
 | 
			
		||||
			c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI))
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -11,8 +11,8 @@ func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
	relayV1Router := router.Group("/v1")
 | 
			
		||||
	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
 | 
			
		||||
	{
 | 
			
		||||
		relayV1Router.GET("/models", controller.Relay)
 | 
			
		||||
		relayV1Router.GET("/models/:model", controller.Relay)
 | 
			
		||||
		relayV1Router.GET("/models", controller.ListModels)
 | 
			
		||||
		relayV1Router.GET("/models/:model", controller.RetrieveModel)
 | 
			
		||||
		relayV1Router.POST("/completions", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/chat/completions", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/edits", controller.RelayNotImplemented)
 | 
			
		||||
 
 | 
			
		||||
@@ -10,12 +10,13 @@ import (
 | 
			
		||||
	"one-api/middleware"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func setWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 | 
			
		||||
func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 | 
			
		||||
	router.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
			
		||||
	router.Use(middleware.GlobalWebRateLimit())
 | 
			
		||||
	router.Use(middleware.Cache())
 | 
			
		||||
	router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build")))
 | 
			
		||||
	router.NoRoute(func(c *gin.Context) {
 | 
			
		||||
		c.Header("Cache-Control", "no-cache")
 | 
			
		||||
		c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,6 +32,7 @@ const ChannelsTable = () => {
 | 
			
		||||
  const [activePage, setActivePage] = useState(1);
 | 
			
		||||
  const [searchKeyword, setSearchKeyword] = useState('');
 | 
			
		||||
  const [searching, setSearching] = useState(false);
 | 
			
		||||
  const [updatingBalance, setUpdatingBalance] = useState(false);
 | 
			
		||||
 | 
			
		||||
  const loadChannels = async (startIdx) => {
 | 
			
		||||
    const res = await API.get(`/api/channel/?p=${startIdx}`);
 | 
			
		||||
@@ -63,7 +64,7 @@ const ChannelsTable = () => {
 | 
			
		||||
  const refresh = async () => {
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
    await loadChannels(0);
 | 
			
		||||
  }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    loadChannels(0)
 | 
			
		||||
@@ -127,7 +128,7 @@ const ChannelsTable = () => {
 | 
			
		||||
 | 
			
		||||
  const renderResponseTime = (responseTime) => {
 | 
			
		||||
    let time = responseTime / 1000;
 | 
			
		||||
    time = time.toFixed(2) + " 秒";
 | 
			
		||||
    time = time.toFixed(2) + ' 秒';
 | 
			
		||||
    if (responseTime === 0) {
 | 
			
		||||
      return <Label basic color='grey'>未测试</Label>;
 | 
			
		||||
    } else if (responseTime <= 1000) {
 | 
			
		||||
@@ -179,11 +180,38 @@ const ChannelsTable = () => {
 | 
			
		||||
    const res = await API.get(`/api/channel/test`);
 | 
			
		||||
    const { success, message } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      showInfo("已成功开始测试所有已启用通道,请刷新页面查看结果。");
 | 
			
		||||
      showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。');
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const updateChannelBalance = async (id, name, idx) => {
 | 
			
		||||
    const res = await API.get(`/api/channel/update_balance/${id}/`);
 | 
			
		||||
    const { success, message, balance } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      let newChannels = [...channels];
 | 
			
		||||
      let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
 | 
			
		||||
      newChannels[realIdx].balance = balance;
 | 
			
		||||
      newChannels[realIdx].balance_updated_time = Date.now() / 1000;
 | 
			
		||||
      setChannels(newChannels);
 | 
			
		||||
      showInfo(`通道 ${name} 余额更新成功!`);
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const updateAllChannelsBalance = async () => {
 | 
			
		||||
    setUpdatingBalance(true);
 | 
			
		||||
    const res = await API.get(`/api/channel/update_balance`);
 | 
			
		||||
    const { success, message } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      showInfo('已更新完毕所有已启用通道余额!');
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
    }
 | 
			
		||||
    setUpdatingBalance(false);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const handleKeywordChange = async (e, { value }) => {
 | 
			
		||||
    setSearchKeyword(value.trim());
 | 
			
		||||
@@ -263,10 +291,10 @@ const ChannelsTable = () => {
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
                sortChannel('test_time');
 | 
			
		||||
                sortChannel('balance');
 | 
			
		||||
              }}
 | 
			
		||||
            >
 | 
			
		||||
              测试时间
 | 
			
		||||
              余额
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell>操作</Table.HeaderCell>
 | 
			
		||||
          </Table.Row>
 | 
			
		||||
@@ -286,8 +314,22 @@ const ChannelsTable = () => {
 | 
			
		||||
                  <Table.Cell>{channel.name ? channel.name : '无'}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderType(channel.type)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderStatus(channel.status)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderResponseTime(channel.response_time)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{channel.test_time ? renderTimestamp(channel.test_time) : "未测试"}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <Popup
 | 
			
		||||
                      content={channel.test_time ? renderTimestamp(channel.test_time) : '未测试'}
 | 
			
		||||
                      key={channel.id}
 | 
			
		||||
                      trigger={renderResponseTime(channel.response_time)}
 | 
			
		||||
                      basic
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <Popup
 | 
			
		||||
                      content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
 | 
			
		||||
                      key={channel.id}
 | 
			
		||||
                      trigger={<span>${channel.balance.toFixed(2)}</span>}
 | 
			
		||||
                      basic
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <div>
 | 
			
		||||
                      <Button
 | 
			
		||||
@@ -299,6 +341,16 @@ const ChannelsTable = () => {
 | 
			
		||||
                      >
 | 
			
		||||
                        测试
 | 
			
		||||
                      </Button>
 | 
			
		||||
                      <Button
 | 
			
		||||
                        size={'small'}
 | 
			
		||||
                        positive
 | 
			
		||||
                        loading={updatingBalance}
 | 
			
		||||
                        onClick={() => {
 | 
			
		||||
                          updateChannelBalance(channel.id, channel.name, idx);
 | 
			
		||||
                        }}
 | 
			
		||||
                      >
 | 
			
		||||
                        更新余额
 | 
			
		||||
                      </Button>
 | 
			
		||||
                      <Popup
 | 
			
		||||
                        trigger={
 | 
			
		||||
                          <Button size='small' negative>
 | 
			
		||||
@@ -353,6 +405,7 @@ const ChannelsTable = () => {
 | 
			
		||||
              <Button size='small' loading={loading} onClick={testAllChannels}>
 | 
			
		||||
                测试所有已启用通道
 | 
			
		||||
              </Button>
 | 
			
		||||
              <Button size='small' onClick={updateAllChannelsBalance} loading={updatingBalance}>更新所有已启用通道余额</Button>
 | 
			
		||||
              <Pagination
 | 
			
		||||
                floated='right'
 | 
			
		||||
                activePage={activePage}
 | 
			
		||||
 
 | 
			
		||||
@@ -34,7 +34,6 @@ const headerButtons = [
 | 
			
		||||
    name: '充值',
 | 
			
		||||
    to: '/topup',
 | 
			
		||||
    icon: 'cart',
 | 
			
		||||
    admin: true,
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
    name: '用户',
 | 
			
		||||
 
 | 
			
		||||
@@ -112,13 +112,17 @@ const PersonalSetting = () => {
 | 
			
		||||
      <Button onClick={generateAccessToken}>生成系统访问令牌</Button>
 | 
			
		||||
      <Divider />
 | 
			
		||||
      <Header as='h3'>账号绑定</Header>
 | 
			
		||||
      <Button
 | 
			
		||||
        onClick={() => {
 | 
			
		||||
          setShowWeChatBindModal(true);
 | 
			
		||||
        }}
 | 
			
		||||
      >
 | 
			
		||||
        绑定微信账号
 | 
			
		||||
      </Button>
 | 
			
		||||
      {
 | 
			
		||||
        status.wechat_login && (
 | 
			
		||||
          <Button
 | 
			
		||||
            onClick={() => {
 | 
			
		||||
              setShowWeChatBindModal(true);
 | 
			
		||||
            }}
 | 
			
		||||
          >
 | 
			
		||||
            绑定微信账号
 | 
			
		||||
          </Button>
 | 
			
		||||
        )
 | 
			
		||||
      }
 | 
			
		||||
      <Modal
 | 
			
		||||
        onClose={() => setShowWeChatBindModal(false)}
 | 
			
		||||
        onOpen={() => setShowWeChatBindModal(true)}
 | 
			
		||||
@@ -148,7 +152,11 @@ const PersonalSetting = () => {
 | 
			
		||||
          </Modal.Description>
 | 
			
		||||
        </Modal.Content>
 | 
			
		||||
      </Modal>
 | 
			
		||||
      <Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
 | 
			
		||||
      {
 | 
			
		||||
        status.github_oauth && (
 | 
			
		||||
          <Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
 | 
			
		||||
        )
 | 
			
		||||
      }
 | 
			
		||||
      <Button
 | 
			
		||||
        onClick={() => {
 | 
			
		||||
          setShowEmailBindModal(true);
 | 
			
		||||
 
 | 
			
		||||
@@ -28,6 +28,7 @@ const SystemSetting = () => {
 | 
			
		||||
    RegisterEnabled: '',
 | 
			
		||||
    QuotaForNewUser: 0,
 | 
			
		||||
    QuotaRemindThreshold: 0,
 | 
			
		||||
    PreConsumedQuota: 0,
 | 
			
		||||
    ModelRatio: '',
 | 
			
		||||
    TopUpLink: '',
 | 
			
		||||
    AutomaticDisableChannelEnabled: '',
 | 
			
		||||
@@ -98,6 +99,7 @@ const SystemSetting = () => {
 | 
			
		||||
      name === 'TurnstileSecretKey' ||
 | 
			
		||||
      name === 'QuotaForNewUser' ||
 | 
			
		||||
      name === 'QuotaRemindThreshold' ||
 | 
			
		||||
      name === 'PreConsumedQuota' ||
 | 
			
		||||
      name === 'ModelRatio' ||
 | 
			
		||||
      name === 'TopUpLink'
 | 
			
		||||
    ) {
 | 
			
		||||
@@ -119,6 +121,9 @@ const SystemSetting = () => {
 | 
			
		||||
    if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
 | 
			
		||||
      await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
 | 
			
		||||
    }
 | 
			
		||||
    if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
 | 
			
		||||
      await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
 | 
			
		||||
    }
 | 
			
		||||
    if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
 | 
			
		||||
      if (!verifyJSON(inputs.ModelRatio)) {
 | 
			
		||||
        showError('模型倍率不是合法的 JSON 字符串');
 | 
			
		||||
@@ -272,7 +277,7 @@ const SystemSetting = () => {
 | 
			
		||||
          <Header as='h3'>
 | 
			
		||||
            运营设置
 | 
			
		||||
          </Header>
 | 
			
		||||
          <Form.Group widths={3}>
 | 
			
		||||
          <Form.Group widths={4}>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='新用户初始配额'
 | 
			
		||||
              name='QuotaForNewUser'
 | 
			
		||||
@@ -302,6 +307,16 @@ const SystemSetting = () => {
 | 
			
		||||
              min='0'
 | 
			
		||||
              placeholder='低于此额度时将发送邮件提醒用户'
 | 
			
		||||
            />
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='请求预扣费额度'
 | 
			
		||||
              name='PreConsumedQuota'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              value={inputs.PreConsumedQuota}
 | 
			
		||||
              type='number'
 | 
			
		||||
              min='0'
 | 
			
		||||
              placeholder='请求结束后多退少补'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Group widths='equal'>
 | 
			
		||||
            <Form.TextArea
 | 
			
		||||
@@ -321,7 +336,7 @@ const SystemSetting = () => {
 | 
			
		||||
          </Header>
 | 
			
		||||
          <Form.Group widths={3}>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='最长回应时间'
 | 
			
		||||
              label='最长响应时间'
 | 
			
		||||
              name='ChannelDisableThreshold'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import { Link } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showSuccess } from '../helpers';
 | 
			
		||||
 | 
			
		||||
import { ITEMS_PER_PAGE } from '../constants';
 | 
			
		||||
import { renderText } from '../helpers/render';
 | 
			
		||||
 | 
			
		||||
function renderRole(role) {
 | 
			
		||||
  switch (role) {
 | 
			
		||||
@@ -64,7 +65,7 @@ const UsersTable = () => {
 | 
			
		||||
    (async () => {
 | 
			
		||||
      const res = await API.post('/api/user/manage', {
 | 
			
		||||
        username,
 | 
			
		||||
        action,
 | 
			
		||||
        action
 | 
			
		||||
      });
 | 
			
		||||
      const { success, message } = res.data;
 | 
			
		||||
      if (success) {
 | 
			
		||||
@@ -161,18 +162,18 @@ const UsersTable = () => {
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
                sortUser('username');
 | 
			
		||||
                sortUser('id');
 | 
			
		||||
              }}
 | 
			
		||||
            >
 | 
			
		||||
              用户名
 | 
			
		||||
              ID
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
                sortUser('display_name');
 | 
			
		||||
                sortUser('username');
 | 
			
		||||
              }}
 | 
			
		||||
            >
 | 
			
		||||
              显示名称
 | 
			
		||||
              用户名
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
@@ -220,9 +221,17 @@ const UsersTable = () => {
 | 
			
		||||
              if (user.deleted) return <></>;
 | 
			
		||||
              return (
 | 
			
		||||
                <Table.Row key={user.id}>
 | 
			
		||||
                  <Table.Cell>{user.username}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.display_name}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.email ? user.email : '无'}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.id}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <Popup
 | 
			
		||||
                      content={user.email ? user.email : '未绑定邮箱地址'}
 | 
			
		||||
                      key={user.display_name}
 | 
			
		||||
                      header={user.display_name ? user.display_name : user.username}
 | 
			
		||||
                      trigger={<span>{renderText(user.username, 10)}</span>}
 | 
			
		||||
                      hoverable
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.email ? renderText(user.email, 30) : '无'}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.quota}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderRole(user.role)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderStatus(user.status)}</Table.Cell>
 | 
			
		||||
@@ -234,6 +243,7 @@ const UsersTable = () => {
 | 
			
		||||
                        onClick={() => {
 | 
			
		||||
                          manageUser(user.username, 'promote', idx);
 | 
			
		||||
                        }}
 | 
			
		||||
                        disabled={user.role === 100}
 | 
			
		||||
                      >
 | 
			
		||||
                        提升
 | 
			
		||||
                      </Button>
 | 
			
		||||
@@ -243,12 +253,13 @@ const UsersTable = () => {
 | 
			
		||||
                        onClick={() => {
 | 
			
		||||
                          manageUser(user.username, 'demote', idx);
 | 
			
		||||
                        }}
 | 
			
		||||
                        disabled={user.role === 100}
 | 
			
		||||
                      >
 | 
			
		||||
                        降级
 | 
			
		||||
                      </Button>
 | 
			
		||||
                      <Popup
 | 
			
		||||
                        trigger={
 | 
			
		||||
                          <Button size='small' negative>
 | 
			
		||||
                          <Button size='small' negative disabled={user.role === 100}>
 | 
			
		||||
                            删除
 | 
			
		||||
                          </Button>
 | 
			
		||||
                        }
 | 
			
		||||
@@ -274,6 +285,7 @@ const UsersTable = () => {
 | 
			
		||||
                            idx
 | 
			
		||||
                          );
 | 
			
		||||
                        }}
 | 
			
		||||
                        disabled={user.role === 100}
 | 
			
		||||
                      >
 | 
			
		||||
                        {user.status === 1 ? '禁用' : '启用'}
 | 
			
		||||
                      </Button>
 | 
			
		||||
@@ -281,6 +293,7 @@ const UsersTable = () => {
 | 
			
		||||
                        size={'small'}
 | 
			
		||||
                        as={Link}
 | 
			
		||||
                        to={'/user/edit/' + user.id}
 | 
			
		||||
                        disabled={user.role === 100}
 | 
			
		||||
                      >
 | 
			
		||||
                        编辑
 | 
			
		||||
                      </Button>
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,12 @@
 | 
			
		||||
export const CHANNEL_OPTIONS = [
 | 
			
		||||
  { key: 1, text: 'OpenAI', value: 1, color: 'green' },
 | 
			
		||||
  { key: 2, text: 'API2D', value: 2, color: 'blue' },
 | 
			
		||||
  { key: 8, text: '自定义', value: 8, color: 'pink' },
 | 
			
		||||
  { key: 3, text: 'Azure', value: 3, color: 'olive' },
 | 
			
		||||
  { key: 2, text: 'API2D', value: 2, color: 'blue' },
 | 
			
		||||
  { key: 4, text: 'CloseAI', value: 4, color: 'teal' },
 | 
			
		||||
  { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
 | 
			
		||||
  { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
 | 
			
		||||
  { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
 | 
			
		||||
  { key: 8, text: '自定义', value: 8, color: 'pink' }
 | 
			
		||||
  { key: 9, text: 'AI.LS', value: 9, color: 'yellow' },
 | 
			
		||||
  { key: 10, text: 'AI Proxy', value: 10, color: 'purple' }
 | 
			
		||||
];
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								web/src/helpers/render.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								web/src/helpers/render.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
			
		||||
export function renderText(text, limit) {
 | 
			
		||||
  if (text.length > limit) {
 | 
			
		||||
    return text.slice(0, limit - 3) + '...';
 | 
			
		||||
  }
 | 
			
		||||
  return text;
 | 
			
		||||
}
 | 
			
		||||
@@ -46,6 +46,9 @@ const EditChannel = () => {
 | 
			
		||||
    if (localInputs.base_url.endsWith('/')) {
 | 
			
		||||
      localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
 | 
			
		||||
    }
 | 
			
		||||
    if (localInputs.type === 3 && localInputs.other === '') {
 | 
			
		||||
      localInputs.other = '2023-03-15-preview';
 | 
			
		||||
    }
 | 
			
		||||
    let res;
 | 
			
		||||
    if (isEdit) {
 | 
			
		||||
      res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,5 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Button, Form, Header, Segment } from 'semantic-ui-react';
 | 
			
		||||
import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
 | 
			
		||||
import { useParams } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showSuccess, timestamp2string } from '../../helpers';
 | 
			
		||||
 | 
			
		||||
@@ -106,6 +106,7 @@ const EditToken = () => {
 | 
			
		||||
              required={!isEdit}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Message>注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。</Message>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='额度'
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Button, Form, Grid, Header, Segment, Statistic } from 'semantic-ui-react';
 | 
			
		||||
import { API, showError, showSuccess } from '../../helpers';
 | 
			
		||||
import { API, showError, showInfo, showSuccess } from '../../helpers';
 | 
			
		||||
 | 
			
		||||
const TopUp = () => {
 | 
			
		||||
  const [redemptionCode, setRedemptionCode] = useState('');
 | 
			
		||||
@@ -9,6 +9,7 @@ const TopUp = () => {
 | 
			
		||||
 | 
			
		||||
  const topUp = async () => {
 | 
			
		||||
    if (redemptionCode === '') {
 | 
			
		||||
      showInfo('请输入充值码!')
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    const res = await API.post('/api/user/topup', {
 | 
			
		||||
@@ -80,7 +81,7 @@ const TopUp = () => {
 | 
			
		||||
        <Grid.Column>
 | 
			
		||||
          <Statistic.Group widths='one'>
 | 
			
		||||
            <Statistic>
 | 
			
		||||
              <Statistic.Value>{userQuota}</Statistic.Value>
 | 
			
		||||
              <Statistic.Value>{userQuota.toLocaleString()}</Statistic.Value>
 | 
			
		||||
              <Statistic.Label>剩余额度</Statistic.Label>
 | 
			
		||||
            </Statistic>
 | 
			
		||||
          </Statistic.Group>
 | 
			
		||||
 
 | 
			
		||||
@@ -14,8 +14,9 @@ const EditUser = () => {
 | 
			
		||||
    github_id: '',
 | 
			
		||||
    wechat_id: '',
 | 
			
		||||
    email: '',
 | 
			
		||||
    quota: 0,
 | 
			
		||||
  });
 | 
			
		||||
  const { username, display_name, password, github_id, wechat_id, email } =
 | 
			
		||||
  const { username, display_name, password, github_id, wechat_id, email, quota } =
 | 
			
		||||
    inputs;
 | 
			
		||||
  const handleInputChange = (e, { name, value }) => {
 | 
			
		||||
    setInputs((inputs) => ({ ...inputs, [name]: value }));
 | 
			
		||||
@@ -44,7 +45,11 @@ const EditUser = () => {
 | 
			
		||||
  const submit = async () => {
 | 
			
		||||
    let res = undefined;
 | 
			
		||||
    if (userId) {
 | 
			
		||||
      res = await API.put(`/api/user/`, { ...inputs, id: parseInt(userId) });
 | 
			
		||||
      let data = { ...inputs, id: parseInt(userId) };
 | 
			
		||||
      if (typeof data.quota === 'string') {
 | 
			
		||||
        data.quota = parseInt(data.quota);
 | 
			
		||||
      }
 | 
			
		||||
      res = await API.put(`/api/user/`, data);
 | 
			
		||||
    } else {
 | 
			
		||||
      res = await API.put(`/api/user/self`, inputs);
 | 
			
		||||
    }
 | 
			
		||||
@@ -92,6 +97,21 @@ const EditUser = () => {
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          {
 | 
			
		||||
            userId && (
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='剩余额度'
 | 
			
		||||
                  name='quota'
 | 
			
		||||
                  placeholder={'请输入新的剩余额度'}
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
                  value={quota}
 | 
			
		||||
                  type={'number'}
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='已绑定的 GitHub 账户'
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user