mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	Compare commits
	
		
			53 Commits
		
	
	
		
			v0.3.0-alp
			...
			v0.3.4
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					1e1c6a828f | ||
| 
						 | 
					2847a08852 | ||
| 
						 | 
					98f1a627f0 | ||
| 
						 | 
					333e4216d2 | ||
| 
						 | 
					7e80e2da3a | ||
| 
						 | 
					139624b8a4 | ||
| 
						 | 
					2f44aaa645 | ||
| 
						 | 
					0f6958c57a | ||
| 
						 | 
					5f045f8cf5 | ||
| 
						 | 
					f19ee05351 | ||
| 
						 | 
					fa71daa8a7 | ||
| 
						 | 
					54215dc303 | ||
| 
						 | 
					f9f42997b2 | ||
| 
						 | 
					25eab0b224 | ||
| 
						 | 
					34bce5b464 | ||
| 
						 | 
					d4794fc051 | ||
| 
						 | 
					8b43e0dd3f | ||
| 
						 | 
					92c88fa273 | ||
| 
						 | 
					38191d55be | ||
| 
						 | 
					d9e39f5906 | ||
| 
						 | 
					17b7646c12 | ||
| 
						 | 
					171b818504 | ||
| 
						 | 
					bcca0cc0bc | ||
| 
						 | 
					b92ec5e54c | ||
| 
						 | 
					fa79e8b7a3 | ||
| 
						 | 
					1cc7c20183 | ||
| 
						 | 
					2eee97e9b6 | ||
| 
						 | 
					a3a1b612b0 | ||
| 
						 | 
					61e682ca47 | ||
| 
						 | 
					b383983106 | ||
| 
						 | 
					cfd587117e | ||
| 
						 | 
					ef9dca28f5 | ||
| 
						 | 
					741c0b9c18 | ||
| 
						 | 
					3711f4a741 | ||
| 
						 | 
					7c6bf3e97b | ||
| 
						 | 
					481ba41fbd | ||
| 
						 | 
					2779d6629c | ||
| 
						 | 
					e509899daf | ||
| 
						 | 
					b53cdbaf05 | ||
| 
						 | 
					ced89398a5 | ||
| 
						 | 
					09c2e3bcec | ||
| 
						 | 
					5cba800fa6 | ||
| 
						 | 
					2d39a135f2 | ||
| 
						 | 
					3c6834a79c | ||
| 
						 | 
					6da3410823 | ||
| 
						 | 
					ceb289cb4d | ||
| 
						 | 
					6f8cc712b0 | ||
| 
						 | 
					ad01e1f3b3 | ||
| 
						 | 
					cc1ef2ffd5 | ||
| 
						 | 
					7201bd1c97 | ||
| 
						 | 
					73d5e0f283 | ||
| 
						 | 
					efc744ca35 | ||
| 
						 | 
					e8da98139f | 
							
								
								
									
										49
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										49
									
								
								README.md
									
									
									
									
									
								
							@@ -38,6 +38,8 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 | 
			
		||||
  <a href="https://github.com/songquanpeng/one-api#截图展示">截图展示</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="https://openai.justsong.cn/">在线演示</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="https://github.com/songquanpeng/one-api#常见问题">常见问题</a>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
> **Warning**:从 `v0.2` 版本升级到 `v0.3` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.2-v0.3.sql)。
 | 
			
		||||
@@ -48,26 +50,29 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 | 
			
		||||
   + [x] OpenAI 官方通道
 | 
			
		||||
   + [x] **Azure OpenAI API**
 | 
			
		||||
   + [x] [API2D](https://api2d.com/r/197971)
 | 
			
		||||
   + [x] [CloseAI](https://console.openai-asia.com)
 | 
			
		||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
 | 
			
		||||
   + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
 | 
			
		||||
   + [x] [AI.LS](https://ai.ls)
 | 
			
		||||
   + [x] [OpenAI Max](https://openaimax.com)
 | 
			
		||||
   + [x] [OhMyGPT](https://www.ohmygpt.com)
 | 
			
		||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
   + [x] [CloseAI](https://console.openai-asia.com/r/2412)
 | 
			
		||||
   + [x] 自定义渠道:例如使用自行搭建的 OpenAI 代理
 | 
			
		||||
2. 支持通过**负载均衡**的方式访问多个渠道。
 | 
			
		||||
3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
			
		||||
4. 支持**令牌管理**,设置令牌的过期时间和使用次数。
 | 
			
		||||
5. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为令牌进行充值。
 | 
			
		||||
6. 支持**通道管理**,批量创建通道。
 | 
			
		||||
7. 支持发布公告,设置充值链接,设置新用户初始额度。
 | 
			
		||||
8. 支持丰富的**自定义**设置,
 | 
			
		||||
4. 支持**多机部署**,[详见此处](#多机部署)。
 | 
			
		||||
5. 支持**令牌管理**,设置令牌的过期时间和使用次数。
 | 
			
		||||
6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
 | 
			
		||||
7. 支持**通道管理**,批量创建通道。
 | 
			
		||||
8. 支持发布公告,设置充值链接,设置新用户初始额度。
 | 
			
		||||
9. 支持丰富的**自定义**设置,
 | 
			
		||||
   1. 支持自定义系统名称,logo 以及页脚。
 | 
			
		||||
   2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
 | 
			
		||||
9. 支持通过系统访问令牌访问管理 API。
 | 
			
		||||
10. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
10. 支持通过系统访问令牌访问管理 API。
 | 
			
		||||
11. 支持用户管理,支持**多种用户登录注册方式**:
 | 
			
		||||
    + 邮箱登录注册以及通过邮箱进行密码重置。
 | 
			
		||||
    + [GitHub 开放授权](https://github.com/settings/applications/new)。
 | 
			
		||||
    + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
 | 
			
		||||
11. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
 | 
			
		||||
12. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
 | 
			
		||||
 | 
			
		||||
## 部署
 | 
			
		||||
### 基于 Docker 进行部署
 | 
			
		||||
@@ -90,13 +95,10 @@ server{
 | 
			
		||||
          proxy_set_header X-Forwarded-For $remote_addr;
 | 
			
		||||
          proxy_cache_bypass $http_upgrade;
 | 
			
		||||
          proxy_set_header Accept-Encoding gzip;
 | 
			
		||||
          proxy_buffering off;  # 重要:关闭代理缓冲
 | 
			
		||||
   }
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
注意,为了 SSE 正常工作,需要关闭 Nginx 的代理缓冲。
 | 
			
		||||
 | 
			
		||||
之后使用 Let's Encrypt 的 certbot 配置 HTTPS:
 | 
			
		||||
```bash
 | 
			
		||||
# Ubuntu 安装 certbot:
 | 
			
		||||
@@ -133,6 +135,14 @@ sudo service nginx restart
 | 
			
		||||
 | 
			
		||||
更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。
 | 
			
		||||
 | 
			
		||||
### 多机部署
 | 
			
		||||
1. 所有服务器 `SESSION_SECRET` 设置一样的值。
 | 
			
		||||
2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite,请自行配置主备数据库同步。
 | 
			
		||||
3. 所有从服务器必须设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。
 | 
			
		||||
4. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。
 | 
			
		||||
 | 
			
		||||
环境变量的具体使用方法详见[此处](#环境变量)。
 | 
			
		||||
 | 
			
		||||
## 配置
 | 
			
		||||
系统本身开箱即用。
 | 
			
		||||
 | 
			
		||||
@@ -157,6 +167,10 @@ sudo service nginx restart
 | 
			
		||||
   + 例子:`SESSION_SECRET=random_string`
 | 
			
		||||
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite。
 | 
			
		||||
   + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/one-api`
 | 
			
		||||
4. `FRONTEND_BASE_URL`:设置之后将使用指定的前端地址,而非后端地址。
 | 
			
		||||
   + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
 | 
			
		||||
5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。
 | 
			
		||||
   + 例子:`SYNC_FREQUENCY=60`
 | 
			
		||||
 | 
			
		||||
### 命令行参数
 | 
			
		||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 | 
			
		||||
@@ -174,3 +188,10 @@ https://openai.justsong.cn
 | 
			
		||||
### 截图展示
 | 
			
		||||

 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
## 常见问题
 | 
			
		||||
1. 账户额度足够为什么提示额度不足?
 | 
			
		||||
   + 请检查你的令牌额度是否足够,这个和账户额度是分开的。
 | 
			
		||||
   + 令牌额度仅供用户设置最大使用量,用户可自由设置。
 | 
			
		||||
2. 宝塔部署后访问出现空白页面?
 | 
			
		||||
   + 自动配置的问题,详见[#97](https://github.com/songquanpeng/one-api/issues/97)。
 | 
			
		||||
@@ -127,16 +127,22 @@ const (
 | 
			
		||||
	ChannelTypeOpenAIMax = 6
 | 
			
		||||
	ChannelTypeOhMyGPT   = 7
 | 
			
		||||
	ChannelTypeCustom    = 8
 | 
			
		||||
	ChannelTypeAILS      = 9
 | 
			
		||||
	ChannelTypeAIProxy   = 10
 | 
			
		||||
	ChannelTypePaLM      = 11
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ChannelBaseURLs = []string{
 | 
			
		||||
	"",                            // 0
 | 
			
		||||
	"https://api.openai.com",      // 1
 | 
			
		||||
	"https://openai.api2d.net",    // 2
 | 
			
		||||
	"https://oa.api2d.net",        // 2
 | 
			
		||||
	"",                            // 3
 | 
			
		||||
	"https://api.openai-asia.com", // 4
 | 
			
		||||
	"https://api.openai-sb.com",   // 5
 | 
			
		||||
	"https://api.openaimax.com",   // 6
 | 
			
		||||
	"https://api.ohmygpt.com",     // 7
 | 
			
		||||
	"",                            // 8
 | 
			
		||||
	"https://api.caipacity.com",   // 9
 | 
			
		||||
	"https://api.aiproxy.io",      // 10
 | 
			
		||||
	"",                            // 11
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"log"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"runtime"
 | 
			
		||||
@@ -133,6 +134,29 @@ func GetUUID() string {
 | 
			
		||||
	return code
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	rand.Seed(time.Now().UnixNano())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GenerateKey() string {
 | 
			
		||||
	rand.Seed(time.Now().UnixNano())
 | 
			
		||||
	key := make([]byte, 48)
 | 
			
		||||
	for i := 0; i < 16; i++ {
 | 
			
		||||
		key[i] = keyChars[rand.Intn(len(keyChars))]
 | 
			
		||||
	}
 | 
			
		||||
	uuid_ := GetUUID()
 | 
			
		||||
	for i := 0; i < 32; i++ {
 | 
			
		||||
		c := uuid_[i]
 | 
			
		||||
		if i%2 == 0 && c >= 'a' && c <= 'z' {
 | 
			
		||||
			c = c - 'a' + 'A'
 | 
			
		||||
		}
 | 
			
		||||
		key[i+16] = c
 | 
			
		||||
	}
 | 
			
		||||
	return string(key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetTimestamp() int64 {
 | 
			
		||||
	return time.Now().Unix()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										41
									
								
								controller/billing.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								controller/billing.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetSubscription(c *gin.Context) {
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	quota, err := model.GetUserQuota(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		openAIError := OpenAIError{
 | 
			
		||||
			Message: err.Error(),
 | 
			
		||||
			Type:    "one_api_error",
 | 
			
		||||
		}
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
			"error": openAIError,
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	subscription := OpenAISubscriptionResponse{
 | 
			
		||||
		Object:             "billing_subscription",
 | 
			
		||||
		HasPaymentMethod:   true,
 | 
			
		||||
		SoftLimitUSD:       float64(quota),
 | 
			
		||||
		HardLimitUSD:       float64(quota),
 | 
			
		||||
		SystemHardLimitUSD: float64(quota),
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, subscription)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUsage(c *gin.Context) {
 | 
			
		||||
	//userId := c.GetInt("id")
 | 
			
		||||
	// TODO: get usage from database
 | 
			
		||||
	usage := OpenAIUsageResponse{
 | 
			
		||||
		Object:     "list",
 | 
			
		||||
		TotalUsage: 0,
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, usage)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										186
									
								
								controller/channel-billing.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								controller/channel-billing.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,186 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://github.com/songquanpeng/one-api/issues/79
 | 
			
		||||
 | 
			
		||||
type OpenAISubscriptionResponse struct {
 | 
			
		||||
	Object             string  `json:"object"`
 | 
			
		||||
	HasPaymentMethod   bool    `json:"has_payment_method"`
 | 
			
		||||
	SoftLimitUSD       float64 `json:"soft_limit_usd"`
 | 
			
		||||
	HardLimitUSD       float64 `json:"hard_limit_usd"`
 | 
			
		||||
	SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIUsageDailyCost struct {
 | 
			
		||||
	Timestamp float64 `json:"timestamp"`
 | 
			
		||||
	LineItems []struct {
 | 
			
		||||
		Name string  `json:"name"`
 | 
			
		||||
		Cost float64 `json:"cost"`
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIUsageResponse struct {
 | 
			
		||||
	Object string `json:"object"`
 | 
			
		||||
	//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
 | 
			
		||||
	TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case common.ChannelTypeOpenAI:
 | 
			
		||||
		// do nothing
 | 
			
		||||
	case common.ChannelTypeAzure:
 | 
			
		||||
		return 0, errors.New("尚未实现")
 | 
			
		||||
	case common.ChannelTypeCustom:
 | 
			
		||||
		baseURL = channel.BaseURL
 | 
			
		||||
	default:
 | 
			
		||||
		return 0, errors.New("尚未实现")
 | 
			
		||||
	}
 | 
			
		||||
	url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
 | 
			
		||||
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	req, err := http.NewRequest("GET", url, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	auth := fmt.Sprintf("Bearer %s", channel.Key)
 | 
			
		||||
	req.Header.Add("Authorization", auth)
 | 
			
		||||
	res, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	body, err := io.ReadAll(res.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	err = res.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	subscription := OpenAISubscriptionResponse{}
 | 
			
		||||
	err = json.Unmarshal(body, &subscription)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
 | 
			
		||||
	endDate := now.Format("2006-01-02")
 | 
			
		||||
	if !subscription.HasPaymentMethod {
 | 
			
		||||
		startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
 | 
			
		||||
	}
 | 
			
		||||
	url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
 | 
			
		||||
	req, err = http.NewRequest("GET", url, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Add("Authorization", auth)
 | 
			
		||||
	res, err = client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	body, err = io.ReadAll(res.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	err = res.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	usage := OpenAIUsageResponse{}
 | 
			
		||||
	err = json.Unmarshal(body, &usage)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	balance := subscription.HardLimitUSD - usage.TotalUsage/100
 | 
			
		||||
	channel.UpdateBalance(balance)
 | 
			
		||||
	return balance, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateChannelBalance(c *gin.Context) {
 | 
			
		||||
	id, err := strconv.Atoi(c.Param("id"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	channel, err := model.GetChannelById(id, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	balance, err := updateChannelBalance(channel)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"balance": balance,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateAllChannelsBalance() error {
 | 
			
		||||
	channels, err := model.GetAllChannels(0, 0, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	for _, channel := range channels {
 | 
			
		||||
		if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		// TODO: support Azure
 | 
			
		||||
		if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		balance, err := updateChannelBalance(channel)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		} else {
 | 
			
		||||
			// err is nil & balance <= 0 means quota is used up
 | 
			
		||||
			if balance <= 0 {
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, "余额不足")
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateAllChannelsBalance(c *gin.Context) {
 | 
			
		||||
	// TODO: make it async
 | 
			
		||||
	err := updateAllChannelsBalance()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										202
									
								
								controller/channel-test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										202
									
								
								controller/channel-test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,202 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func testChannel(channel *model.Channel, request *ChatRequest) error {
 | 
			
		||||
	if request.Model == "" {
 | 
			
		||||
		request.Model = "gpt-3.5-turbo"
 | 
			
		||||
		if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
			request.Model = "gpt-35-turbo"
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	requestURL := common.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
 | 
			
		||||
	} else {
 | 
			
		||||
		if channel.Type == common.ChannelTypeCustom {
 | 
			
		||||
			requestURL = channel.BaseURL
 | 
			
		||||
		}
 | 
			
		||||
		requestURL += "/v1/chat/completions"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	jsonData, err := json.Marshal(request)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
		req.Header.Set("api-key", channel.Key)
 | 
			
		||||
	} else {
 | 
			
		||||
		req.Header.Set("Authorization", "Bearer "+channel.Key)
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Content-Type", "application/json")
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	var response TextResponse
 | 
			
		||||
	err = json.NewDecoder(resp.Body).Decode(&response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if response.Error.Message != "" || response.Error.Code != "" {
 | 
			
		||||
		return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func buildTestRequest(c *gin.Context) *ChatRequest {
 | 
			
		||||
	model_ := c.Query("model")
 | 
			
		||||
	testRequest := &ChatRequest{
 | 
			
		||||
		Model:     model_,
 | 
			
		||||
		MaxTokens: 1,
 | 
			
		||||
	}
 | 
			
		||||
	testMessage := Message{
 | 
			
		||||
		Role:    "user",
 | 
			
		||||
		Content: "hi",
 | 
			
		||||
	}
 | 
			
		||||
	testRequest.Messages = append(testRequest.Messages, testMessage)
 | 
			
		||||
	return testRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChannel(c *gin.Context) {
 | 
			
		||||
	id, err := strconv.Atoi(c.Param("id"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	channel, err := model.GetChannelById(id, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	testRequest := buildTestRequest(c)
 | 
			
		||||
	tik := time.Now()
 | 
			
		||||
	err = testChannel(channel, testRequest)
 | 
			
		||||
	tok := time.Now()
 | 
			
		||||
	milliseconds := tok.Sub(tik).Milliseconds()
 | 
			
		||||
	go channel.UpdateResponseTime(milliseconds)
 | 
			
		||||
	consumedTime := float64(milliseconds) / 1000.0
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
			"time":    consumedTime,
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"time":    consumedTime,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var testAllChannelsLock sync.Mutex
 | 
			
		||||
var testAllChannelsRunning bool = false
 | 
			
		||||
 | 
			
		||||
// disable & notify
 | 
			
		||||
func disableChannel(channelId int, channelName string, reason string) {
 | 
			
		||||
	if common.RootUserEmail == "" {
 | 
			
		||||
		common.RootUserEmail = model.GetRootUserEmail()
 | 
			
		||||
	}
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
 | 
			
		||||
	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
 | 
			
		||||
	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
 | 
			
		||||
	err := common.SendEmail(subject, common.RootUserEmail, content)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testAllChannels(c *gin.Context) error {
 | 
			
		||||
	if common.RootUserEmail == "" {
 | 
			
		||||
		common.RootUserEmail = model.GetRootUserEmail()
 | 
			
		||||
	}
 | 
			
		||||
	testAllChannelsLock.Lock()
 | 
			
		||||
	if testAllChannelsRunning {
 | 
			
		||||
		testAllChannelsLock.Unlock()
 | 
			
		||||
		return errors.New("测试已在运行中")
 | 
			
		||||
	}
 | 
			
		||||
	testAllChannelsRunning = true
 | 
			
		||||
	testAllChannelsLock.Unlock()
 | 
			
		||||
	channels, err := model.GetAllChannels(0, 0, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	testRequest := buildTestRequest(c)
 | 
			
		||||
	var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
 | 
			
		||||
	if disableThreshold == 0 {
 | 
			
		||||
		disableThreshold = 10000000 // a impossible value
 | 
			
		||||
	}
 | 
			
		||||
	go func() {
 | 
			
		||||
		for _, channel := range channels {
 | 
			
		||||
			if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			tik := time.Now()
 | 
			
		||||
			err := testChannel(channel, testRequest)
 | 
			
		||||
			tok := time.Now()
 | 
			
		||||
			milliseconds := tok.Sub(tik).Milliseconds()
 | 
			
		||||
			if err != nil || milliseconds > disableThreshold {
 | 
			
		||||
				if milliseconds > disableThreshold {
 | 
			
		||||
					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
			
		||||
				}
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			channel.UpdateResponseTime(milliseconds)
 | 
			
		||||
		}
 | 
			
		||||
		err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
 | 
			
		||||
		}
 | 
			
		||||
		testAllChannelsLock.Lock()
 | 
			
		||||
		testAllChannelsRunning = false
 | 
			
		||||
		testAllChannelsLock.Unlock()
 | 
			
		||||
	}()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAllChannels(c *gin.Context) {
 | 
			
		||||
	err := testAllChannels(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@@ -1,18 +1,12 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetAllChannels(c *gin.Context) {
 | 
			
		||||
@@ -158,187 +152,3 @@ func UpdateChannel(c *gin.Context) {
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testChannel(channel *model.Channel, request *ChatRequest) error {
 | 
			
		||||
	if request.Model == "" {
 | 
			
		||||
		request.Model = "gpt-3.5-turbo"
 | 
			
		||||
		if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
			request.Model = "gpt-35-turbo"
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	requestURL := common.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
 | 
			
		||||
	} else {
 | 
			
		||||
		if channel.Type == common.ChannelTypeCustom {
 | 
			
		||||
			requestURL = channel.BaseURL
 | 
			
		||||
		}
 | 
			
		||||
		requestURL += "/v1/chat/completions"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	jsonData, err := json.Marshal(request)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
		req.Header.Set("api-key", channel.Key)
 | 
			
		||||
	} else {
 | 
			
		||||
		req.Header.Set("Authorization", "Bearer "+channel.Key)
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Content-Type", "application/json")
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	var response TextResponse
 | 
			
		||||
	err = json.NewDecoder(resp.Body).Decode(&response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if response.Error.Type != "" {
 | 
			
		||||
		return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func buildTestRequest(c *gin.Context) *ChatRequest {
 | 
			
		||||
	model_ := c.Query("model")
 | 
			
		||||
	testRequest := &ChatRequest{
 | 
			
		||||
		Model:     model_,
 | 
			
		||||
		MaxTokens: 1,
 | 
			
		||||
	}
 | 
			
		||||
	testMessage := Message{
 | 
			
		||||
		Role:    "user",
 | 
			
		||||
		Content: "hi",
 | 
			
		||||
	}
 | 
			
		||||
	testRequest.Messages = append(testRequest.Messages, testMessage)
 | 
			
		||||
	return testRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChannel(c *gin.Context) {
 | 
			
		||||
	id, err := strconv.Atoi(c.Param("id"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	channel, err := model.GetChannelById(id, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	testRequest := buildTestRequest(c)
 | 
			
		||||
	tik := time.Now()
 | 
			
		||||
	err = testChannel(channel, testRequest)
 | 
			
		||||
	tok := time.Now()
 | 
			
		||||
	milliseconds := tok.Sub(tik).Milliseconds()
 | 
			
		||||
	go channel.UpdateResponseTime(milliseconds)
 | 
			
		||||
	consumedTime := float64(milliseconds) / 1000.0
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
			"time":    consumedTime,
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"time":    consumedTime,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var testAllChannelsLock sync.Mutex
 | 
			
		||||
var testAllChannelsRunning bool = false
 | 
			
		||||
 | 
			
		||||
// disable & notify
 | 
			
		||||
func disableChannel(channelId int, channelName string, err error) {
 | 
			
		||||
	if common.RootUserEmail == "" {
 | 
			
		||||
		common.RootUserEmail = model.GetRootUserEmail()
 | 
			
		||||
	}
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
 | 
			
		||||
	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
 | 
			
		||||
	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, err.Error())
 | 
			
		||||
	err = common.SendEmail(subject, common.RootUserEmail, content)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testAllChannels(c *gin.Context) error {
 | 
			
		||||
	testAllChannelsLock.Lock()
 | 
			
		||||
	if testAllChannelsRunning {
 | 
			
		||||
		testAllChannelsLock.Unlock()
 | 
			
		||||
		return errors.New("测试已在运行中")
 | 
			
		||||
	}
 | 
			
		||||
	testAllChannelsRunning = true
 | 
			
		||||
	testAllChannelsLock.Unlock()
 | 
			
		||||
	channels, err := model.GetAllChannels(0, 0, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	testRequest := buildTestRequest(c)
 | 
			
		||||
	var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
 | 
			
		||||
	if disableThreshold == 0 {
 | 
			
		||||
		disableThreshold = 10000000 // a impossible value
 | 
			
		||||
	}
 | 
			
		||||
	go func() {
 | 
			
		||||
		for _, channel := range channels {
 | 
			
		||||
			if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			tik := time.Now()
 | 
			
		||||
			err := testChannel(channel, testRequest)
 | 
			
		||||
			tok := time.Now()
 | 
			
		||||
			milliseconds := tok.Sub(tik).Milliseconds()
 | 
			
		||||
			if err != nil || milliseconds > disableThreshold {
 | 
			
		||||
				if milliseconds > disableThreshold {
 | 
			
		||||
					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
			
		||||
				}
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, err)
 | 
			
		||||
			}
 | 
			
		||||
			channel.UpdateResponseTime(milliseconds)
 | 
			
		||||
		}
 | 
			
		||||
		err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
 | 
			
		||||
		}
 | 
			
		||||
		testAllChannelsLock.Lock()
 | 
			
		||||
		testAllChannelsRunning = false
 | 
			
		||||
		testAllChannelsLock.Unlock()
 | 
			
		||||
	}()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAllChannels(c *gin.Context) {
 | 
			
		||||
	err := testAllChannels(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										148
									
								
								controller/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										148
									
								
								controller/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,148 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://platform.openai.com/docs/api-reference/models/list
 | 
			
		||||
 | 
			
		||||
type OpenAIModelPermission struct {
 | 
			
		||||
	Id                 string  `json:"id"`
 | 
			
		||||
	Object             string  `json:"object"`
 | 
			
		||||
	Created            int     `json:"created"`
 | 
			
		||||
	AllowCreateEngine  bool    `json:"allow_create_engine"`
 | 
			
		||||
	AllowSampling      bool    `json:"allow_sampling"`
 | 
			
		||||
	AllowLogprobs      bool    `json:"allow_logprobs"`
 | 
			
		||||
	AllowSearchIndices bool    `json:"allow_search_indices"`
 | 
			
		||||
	AllowView          bool    `json:"allow_view"`
 | 
			
		||||
	AllowFineTuning    bool    `json:"allow_fine_tuning"`
 | 
			
		||||
	Organization       string  `json:"organization"`
 | 
			
		||||
	Group              *string `json:"group"`
 | 
			
		||||
	IsBlocking         bool    `json:"is_blocking"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIModels struct {
 | 
			
		||||
	Id         string                  `json:"id"`
 | 
			
		||||
	Object     string                  `json:"object"`
 | 
			
		||||
	Created    int                     `json:"created"`
 | 
			
		||||
	OwnedBy    string                  `json:"owned_by"`
 | 
			
		||||
	Permission []OpenAIModelPermission `json:"permission"`
 | 
			
		||||
	Root       string                  `json:"root"`
 | 
			
		||||
	Parent     *string                 `json:"parent"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var openAIModels []OpenAIModels
 | 
			
		||||
var openAIModelsMap map[string]OpenAIModels
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	var permission []OpenAIModelPermission
 | 
			
		||||
	permission = append(permission, OpenAIModelPermission{
 | 
			
		||||
		Id:                 "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
 | 
			
		||||
		Object:             "model_permission",
 | 
			
		||||
		Created:            1626777600,
 | 
			
		||||
		AllowCreateEngine:  true,
 | 
			
		||||
		AllowSampling:      true,
 | 
			
		||||
		AllowLogprobs:      true,
 | 
			
		||||
		AllowSearchIndices: false,
 | 
			
		||||
		AllowView:          true,
 | 
			
		||||
		AllowFineTuning:    false,
 | 
			
		||||
		Organization:       "*",
 | 
			
		||||
		Group:              nil,
 | 
			
		||||
		IsBlocking:         false,
 | 
			
		||||
	})
 | 
			
		||||
	// https://platform.openai.com/docs/models/model-endpoint-compatibility
 | 
			
		||||
	openAIModels = []OpenAIModels{
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-3.5-turbo",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-3.5-turbo",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-3.5-turbo-0301",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-3.5-turbo-0301",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-4",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-4",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-4-0314",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-4-0314",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-4-32k",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-4-32k",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-4-32k-0314",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "gpt-4-32k-0314",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "text-embedding-ada-002",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "text-embedding-ada-002",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	openAIModelsMap = make(map[string]OpenAIModels)
 | 
			
		||||
	for _, model := range openAIModels {
 | 
			
		||||
		openAIModelsMap[model.Id] = model
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ListModels(c *gin.Context) {
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
		"object": "list",
 | 
			
		||||
		"data":   openAIModels,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RetrieveModel(c *gin.Context) {
 | 
			
		||||
	modelId := c.Param("model")
 | 
			
		||||
	if model, ok := openAIModelsMap[modelId]; ok {
 | 
			
		||||
		c.JSON(200, model)
 | 
			
		||||
	} else {
 | 
			
		||||
		openAIError := OpenAIError{
 | 
			
		||||
			Message: fmt.Sprintf("The model '%s' does not exist", modelId),
 | 
			
		||||
			Type:    "invalid_request_error",
 | 
			
		||||
			Param:   "model",
 | 
			
		||||
			Code:    "model_not_found",
 | 
			
		||||
		}
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
			"error": openAIError,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										59
									
								
								controller/relay-palm.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								controller/relay-palm.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,59 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type PaLMChatMessage struct {
 | 
			
		||||
	Author  string `json:"author"`
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PaLMFilter struct {
 | 
			
		||||
	Reason  string `json:"reason"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 | 
			
		||||
type PaLMChatRequest struct {
 | 
			
		||||
	Prompt         []Message `json:"prompt"`
 | 
			
		||||
	Temperature    float64   `json:"temperature"`
 | 
			
		||||
	CandidateCount int       `json:"candidateCount"`
 | 
			
		||||
	TopP           float64   `json:"topP"`
 | 
			
		||||
	TopK           int       `json:"topK"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
 | 
			
		||||
type PaLMChatResponse struct {
 | 
			
		||||
	Candidates []Message    `json:"candidates"`
 | 
			
		||||
	Messages   []Message    `json:"messages"`
 | 
			
		||||
	Filters    []PaLMFilter `json:"filters"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
 | 
			
		||||
	messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages))
 | 
			
		||||
	for _, message := range openAIRequest.Messages {
 | 
			
		||||
		var author string
 | 
			
		||||
		if message.Role == "user" {
 | 
			
		||||
			author = "0"
 | 
			
		||||
		} else {
 | 
			
		||||
			author = "1"
 | 
			
		||||
		}
 | 
			
		||||
		messages = append(messages, PaLMChatMessage{
 | 
			
		||||
			Author:  author,
 | 
			
		||||
			Content: message.Content,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	request := PaLMChatRequest{
 | 
			
		||||
		Prompt:         nil,
 | 
			
		||||
		Temperature:    openAIRequest.Temperature,
 | 
			
		||||
		CandidateCount: openAIRequest.N,
 | 
			
		||||
		TopP:           openAIRequest.TopP,
 | 
			
		||||
		TopK:           openAIRequest.MaxTokens,
 | 
			
		||||
	}
 | 
			
		||||
	// TODO: forward request to PaLM & convert response
 | 
			
		||||
	fmt.Print(request)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										65
									
								
								controller/relay-utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								controller/relay-utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,65 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/pkoukk/tiktoken-go"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 | 
			
		||||
 | 
			
		||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
 | 
			
		||||
	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
 | 
			
		||||
		return tokenEncoder
 | 
			
		||||
	}
 | 
			
		||||
	tokenEncoder, err := tiktoken.EncodingForModel(model)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
 | 
			
		||||
		tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tokenEncoderMap[model] = tokenEncoder
 | 
			
		||||
	return tokenEncoder
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func countTokenMessages(messages []Message, model string) int {
 | 
			
		||||
	tokenEncoder := getTokenEncoder(model)
 | 
			
		||||
	// Reference:
 | 
			
		||||
	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 | 
			
		||||
	// https://github.com/pkoukk/tiktoken-go/issues/6
 | 
			
		||||
	//
 | 
			
		||||
	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
 | 
			
		||||
	var tokensPerMessage int
 | 
			
		||||
	var tokensPerName int
 | 
			
		||||
	if strings.HasPrefix(model, "gpt-3.5") {
 | 
			
		||||
		tokensPerMessage = 4
 | 
			
		||||
		tokensPerName = -1 // If there's a name, the role is omitted
 | 
			
		||||
	} else if strings.HasPrefix(model, "gpt-4") {
 | 
			
		||||
		tokensPerMessage = 3
 | 
			
		||||
		tokensPerName = 1
 | 
			
		||||
	} else {
 | 
			
		||||
		tokensPerMessage = 3
 | 
			
		||||
		tokensPerName = 1
 | 
			
		||||
	}
 | 
			
		||||
	tokenNum := 0
 | 
			
		||||
	for _, message := range messages {
 | 
			
		||||
		tokenNum += tokensPerMessage
 | 
			
		||||
		tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
 | 
			
		||||
		tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
 | 
			
		||||
		if message.Name != nil {
 | 
			
		||||
			tokenNum += tokensPerName
 | 
			
		||||
			tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
 | 
			
		||||
	return tokenNum
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func countTokenText(text string, model string) int {
 | 
			
		||||
	tokenEncoder := getTokenEncoder(model)
 | 
			
		||||
	token := tokenEncoder.Encode(text, nil, nil)
 | 
			
		||||
	return len(token)
 | 
			
		||||
}
 | 
			
		||||
@@ -4,10 +4,8 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkoukk/tiktoken-go"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
@@ -16,8 +14,22 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
	Role    string `json:"role"`
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
	Role    string  `json:"role"`
 | 
			
		||||
	Content string  `json:"content"`
 | 
			
		||||
	Name    *string `json:"name,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// https://platform.openai.com/docs/api-reference/chat
 | 
			
		||||
 | 
			
		||||
type GeneralOpenAIRequest struct {
 | 
			
		||||
	Model       string    `json:"model"`
 | 
			
		||||
	Messages    []Message `json:"messages"`
 | 
			
		||||
	Prompt      string    `json:"prompt"`
 | 
			
		||||
	Stream      bool      `json:"stream"`
 | 
			
		||||
	MaxTokens   int       `json:"max_tokens"`
 | 
			
		||||
	Temperature float64   `json:"temperature"`
 | 
			
		||||
	TopP        float64   `json:"top_p"`
 | 
			
		||||
	N           int       `json:"n"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatRequest struct {
 | 
			
		||||
@@ -47,6 +59,11 @@ type OpenAIError struct {
 | 
			
		||||
	Code    string `json:"code"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIErrorWithStatusCode struct {
 | 
			
		||||
	OpenAIError
 | 
			
		||||
	StatusCode int `json:"status_code"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TextResponse struct {
 | 
			
		||||
	Usage `json:"usage"`
 | 
			
		||||
	Error OpenAIError `json:"error"`
 | 
			
		||||
@@ -61,47 +78,55 @@ type StreamResponse struct {
 | 
			
		||||
	} `json:"choices"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
 | 
			
		||||
 | 
			
		||||
func countToken(text string) int {
 | 
			
		||||
	token := tokenEncoder.Encode(text, nil, nil)
 | 
			
		||||
	return len(token)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Relay(c *gin.Context) {
 | 
			
		||||
	err := relayHelper(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"error": gin.H{
 | 
			
		||||
				"message": err.Error(),
 | 
			
		||||
				"type":    "one_api_error",
 | 
			
		||||
			},
 | 
			
		||||
		if err.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
			err.OpenAIError.Message = "负载已满,请稍后再试,或升级账户以提升服务质量。"
 | 
			
		||||
		}
 | 
			
		||||
		c.JSON(err.StatusCode, gin.H{
 | 
			
		||||
			"error": err.OpenAIError,
 | 
			
		||||
		})
 | 
			
		||||
		if common.AutomaticDisableChannelEnabled {
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
			
		||||
		if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key") {
 | 
			
		||||
			channelId := c.GetInt("channel_id")
 | 
			
		||||
			channelName := c.GetString("channel_name")
 | 
			
		||||
			disableChannel(channelId, channelName, err)
 | 
			
		||||
			disableChannel(channelId, channelName, err.Message)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayHelper(c *gin.Context) error {
 | 
			
		||||
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	openAIError := OpenAIError{
 | 
			
		||||
		Message: err.Error(),
 | 
			
		||||
		Type:    "one_api_error",
 | 
			
		||||
		Code:    code,
 | 
			
		||||
	}
 | 
			
		||||
	return &OpenAIErrorWithStatusCode{
 | 
			
		||||
		OpenAIError: openAIError,
 | 
			
		||||
		StatusCode:  statusCode,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	consumeQuota := c.GetBool("consume_quota")
 | 
			
		||||
	var textRequest TextRequest
 | 
			
		||||
	if consumeQuota || channelType == common.ChannelTypeAzure {
 | 
			
		||||
	var textRequest GeneralOpenAIRequest
 | 
			
		||||
	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
 | 
			
		||||
		requestBody, err := io.ReadAll(c.Request.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		err = c.Request.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "close_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(requestBody, &textRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "unmarshal_request_body_failed", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		// Reset request body
 | 
			
		||||
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 | 
			
		||||
@@ -129,12 +154,12 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
		model_ = strings.TrimSuffix(model_, "-0301")
 | 
			
		||||
		model_ = strings.TrimSuffix(model_, "-0314")
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 | 
			
		||||
	} else if channelType == common.ChannelTypePaLM {
 | 
			
		||||
		err := relayPaLM(textRequest, c)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	var promptText string
 | 
			
		||||
	for _, message := range textRequest.Messages {
 | 
			
		||||
		promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
 | 
			
		||||
	}
 | 
			
		||||
	promptTokens := countToken(promptText) + 3
 | 
			
		||||
 | 
			
		||||
	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
 | 
			
		||||
	preConsumedTokens := common.PreConsumedQuota
 | 
			
		||||
	if textRequest.MaxTokens != 0 {
 | 
			
		||||
		preConsumedTokens = promptTokens + textRequest.MaxTokens
 | 
			
		||||
@@ -144,12 +169,12 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "new_request_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	if channelType == common.ChannelTypeAzure {
 | 
			
		||||
		key := c.Request.Header.Get("Authorization")
 | 
			
		||||
@@ -164,18 +189,18 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "do_request_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	err = req.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	err = c.Request.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusOK)
 | 
			
		||||
	}
 | 
			
		||||
	var textResponse TextResponse
 | 
			
		||||
	isStream := resp.Header.Get("Content-Type") == "text/event-stream"
 | 
			
		||||
	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
	var streamResponseText string
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
@@ -187,8 +212,8 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
				completionRatio = 2
 | 
			
		||||
			}
 | 
			
		||||
			if isStream {
 | 
			
		||||
				completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
 | 
			
		||||
				quota = promptTokens + countToken(completionText)*completionRatio
 | 
			
		||||
				responseTokens := countTokenText(streamResponseText, textRequest.Model)
 | 
			
		||||
				quota = promptTokens + responseTokens*completionRatio
 | 
			
		||||
			} else {
 | 
			
		||||
				quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 | 
			
		||||
			}
 | 
			
		||||
@@ -223,6 +248,10 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
		go func() {
 | 
			
		||||
			for scanner.Scan() {
 | 
			
		||||
				data := scanner.Text()
 | 
			
		||||
				if len(data) < 6 { // must be something wrong!
 | 
			
		||||
					common.SysError("Invalid stream response: " + data)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				dataChan <- data
 | 
			
		||||
				data = data[6:]
 | 
			
		||||
				if !strings.HasPrefix(data, "[DONE]") {
 | 
			
		||||
@@ -243,6 +272,7 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
		c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
		c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
		c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
		c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
		c.Stream(func(w io.Writer) bool {
 | 
			
		||||
			select {
 | 
			
		||||
			case data := <-dataChan:
 | 
			
		||||
@@ -257,50 +287,60 @@ func relayHelper(c *gin.Context) error {
 | 
			
		||||
		})
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	} else {
 | 
			
		||||
		for k, v := range resp.Header {
 | 
			
		||||
			c.Writer.Header().Set(k, v[0])
 | 
			
		||||
		}
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
			responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
				return errorWrapper(err, "read_response_body_failed", http.StatusOK)
 | 
			
		||||
			}
 | 
			
		||||
			err = resp.Body.Close()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
				return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
			
		||||
			}
 | 
			
		||||
			err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
 | 
			
		||||
			}
 | 
			
		||||
			if textResponse.Error.Type != "" {
 | 
			
		||||
				return errors.New(fmt.Sprintf("type %s, code %s, message %s",
 | 
			
		||||
					textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message))
 | 
			
		||||
				return &OpenAIErrorWithStatusCode{
 | 
			
		||||
					OpenAIError: textResponse.Error,
 | 
			
		||||
					StatusCode:  resp.StatusCode,
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			// Reset response body
 | 
			
		||||
			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
		}
 | 
			
		||||
		// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
			
		||||
		// And then we will have to send an error response, but in this case, the header has already been set.
 | 
			
		||||
		// So the client will be confused by the response.
 | 
			
		||||
		// For example, Postman will report error, and we cannot check the response at all.
 | 
			
		||||
		for k, v := range resp.Header {
 | 
			
		||||
			c.Writer.Header().Set(k, v[0])
 | 
			
		||||
		}
 | 
			
		||||
		c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
		_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusOK)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayNotImplemented(c *gin.Context) {
 | 
			
		||||
	err := OpenAIError{
 | 
			
		||||
		Message: "API not implemented",
 | 
			
		||||
		Type:    "one_api_error",
 | 
			
		||||
		Param:   "",
 | 
			
		||||
		Code:    "api_not_implemented",
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"error": gin.H{
 | 
			
		||||
			"message": "Not Implemented",
 | 
			
		||||
			"type":    "one_api_error",
 | 
			
		||||
		},
 | 
			
		||||
		"error": err,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -119,7 +119,7 @@ func AddToken(c *gin.Context) {
 | 
			
		||||
	cleanToken := model.Token{
 | 
			
		||||
		UserId:         c.GetInt("id"),
 | 
			
		||||
		Name:           token.Name,
 | 
			
		||||
		Key:            common.GetUUID(),
 | 
			
		||||
		Key:            common.GenerateKey(),
 | 
			
		||||
		CreatedTime:    common.GetTimestamp(),
 | 
			
		||||
		AccessedTime:   common.GetTimestamp(),
 | 
			
		||||
		ExpiredTime:    token.ExpiredTime,
 | 
			
		||||
 
 | 
			
		||||
@@ -467,6 +467,13 @@ func CreateUser(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if err := common.Validate.Struct(&user); err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "输入不合法 " + err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if user.DisplayName == "" {
 | 
			
		||||
		user.DisplayName = user.Username
 | 
			
		||||
	}
 | 
			
		||||
@@ -648,6 +655,9 @@ func EmailBind(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if user.Role == common.RoleRootUser {
 | 
			
		||||
		common.RootUserEmail = email
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
 
 | 
			
		||||
@@ -9,8 +9,8 @@ services:
 | 
			
		||||
    ports:
 | 
			
		||||
      - "3000:3000"
 | 
			
		||||
    volumes:
 | 
			
		||||
      - /home/ubuntu/data/one-api:/data
 | 
			
		||||
      - /home/ubuntu/data/one-api/logs:/app/logs
 | 
			
		||||
      - ./data:/data
 | 
			
		||||
      - ./logs:/app/logs
 | 
			
		||||
    # environment:
 | 
			
		||||
    #   REDIS_CONN_STRING: redis://default:redispw@localhost:49153
 | 
			
		||||
    #   SESSION_SECRET: random_string
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										19
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								go.mod
									
									
									
									
									
								
							@@ -8,12 +8,12 @@ require (
 | 
			
		||||
	github.com/gin-contrib/gzip v0.0.6
 | 
			
		||||
	github.com/gin-contrib/sessions v0.0.5
 | 
			
		||||
	github.com/gin-contrib/static v0.0.1
 | 
			
		||||
	github.com/gin-gonic/gin v1.9.0
 | 
			
		||||
	github.com/go-playground/validator/v10 v10.12.0
 | 
			
		||||
	github.com/gin-gonic/gin v1.9.1
 | 
			
		||||
	github.com/go-playground/validator/v10 v10.14.0
 | 
			
		||||
	github.com/go-redis/redis/v8 v8.11.5
 | 
			
		||||
	github.com/google/uuid v1.3.0
 | 
			
		||||
	github.com/pkoukk/tiktoken-go v0.1.1
 | 
			
		||||
	golang.org/x/crypto v0.8.0
 | 
			
		||||
	golang.org/x/crypto v0.9.0
 | 
			
		||||
	gorm.io/driver/mysql v1.4.3
 | 
			
		||||
	gorm.io/driver/sqlite v1.4.3
 | 
			
		||||
	gorm.io/gorm v1.24.0
 | 
			
		||||
@@ -21,11 +21,12 @@ require (
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect
 | 
			
		||||
	github.com/bytedance/sonic v1.8.8 // indirect
 | 
			
		||||
	github.com/bytedance/sonic v1.9.1 // indirect
 | 
			
		||||
	github.com/cespare/xxhash/v2 v2.1.2 // indirect
 | 
			
		||||
	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
 | 
			
		||||
	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
 | 
			
		||||
	github.com/dlclark/regexp2 v1.8.1 // indirect
 | 
			
		||||
	github.com/gabriel-vasile/mimetype v1.4.2 // indirect
 | 
			
		||||
	github.com/gin-contrib/sse v0.1.0 // indirect
 | 
			
		||||
	github.com/go-playground/locales v0.14.1 // indirect
 | 
			
		||||
	github.com/go-playground/universal-translator v0.18.1 // indirect
 | 
			
		||||
@@ -39,17 +40,17 @@ require (
 | 
			
		||||
	github.com/jinzhu/now v1.1.5 // indirect
 | 
			
		||||
	github.com/json-iterator/go v1.1.12 // indirect
 | 
			
		||||
	github.com/klauspost/cpuid/v2 v2.2.4 // indirect
 | 
			
		||||
	github.com/leodido/go-urn v1.2.3 // indirect
 | 
			
		||||
	github.com/mattn/go-isatty v0.0.18 // indirect
 | 
			
		||||
	github.com/leodido/go-urn v1.2.4 // indirect
 | 
			
		||||
	github.com/mattn/go-isatty v0.0.19 // indirect
 | 
			
		||||
	github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
 | 
			
		||||
	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 | 
			
		||||
	github.com/modern-go/reflect2 v1.0.2 // indirect
 | 
			
		||||
	github.com/pelletier/go-toml/v2 v2.0.7 // indirect
 | 
			
		||||
	github.com/pelletier/go-toml/v2 v2.0.8 // indirect
 | 
			
		||||
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 | 
			
		||||
	github.com/ugorji/go/codec v1.2.11 // indirect
 | 
			
		||||
	golang.org/x/arch v0.3.0 // indirect
 | 
			
		||||
	golang.org/x/net v0.9.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.7.0 // indirect
 | 
			
		||||
	golang.org/x/net v0.10.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.8.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.9.0 // indirect
 | 
			
		||||
	google.golang.org/protobuf v1.30.0 // indirect
 | 
			
		||||
	gopkg.in/yaml.v3 v3.0.1 // indirect
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										41
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										41
									
								
								go.sum
									
									
									
									
									
								
							@@ -1,8 +1,8 @@
 | 
			
		||||
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff h1:RmdPFa+slIr4SCBg4st/l/vZWVe9QJKMXGO60Bxbe04=
 | 
			
		||||
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw=
 | 
			
		||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 | 
			
		||||
github.com/bytedance/sonic v1.8.8 h1:Kj4AYbZSeENfyXicsYppYKO0K2YWab+i2UTSY7Ukz9Q=
 | 
			
		||||
github.com/bytedance/sonic v1.8.8/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
 | 
			
		||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
 | 
			
		||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
 | 
			
		||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
 | 
			
		||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 | 
			
		||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
 | 
			
		||||
@@ -17,6 +17,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
 | 
			
		||||
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
 | 
			
		||||
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
 | 
			
		||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
 | 
			
		||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
 | 
			
		||||
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
 | 
			
		||||
github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
 | 
			
		||||
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
 | 
			
		||||
@@ -29,8 +31,8 @@ github.com/gin-contrib/static v0.0.1 h1:JVxuvHPuUfkoul12N7dtQw7KRn/pSMq7Ue1Va9Sw
 | 
			
		||||
github.com/gin-contrib/static v0.0.1/go.mod h1:CSxeF+wep05e0kCOsqWdAWbSszmc31zTIbD8TvWl7Hs=
 | 
			
		||||
github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
 | 
			
		||||
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
 | 
			
		||||
github.com/gin-gonic/gin v1.9.0 h1:OjyFBKICoexlu99ctXNR2gg+c5pKrKMuyjgARg9qeY8=
 | 
			
		||||
github.com/gin-gonic/gin v1.9.0/go.mod h1:W1Me9+hsUSyj3CePGrd1/QrKJMSJ1Tu/0hFEH89961k=
 | 
			
		||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
 | 
			
		||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
 | 
			
		||||
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
 | 
			
		||||
@@ -43,8 +45,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
 | 
			
		||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.12.0 h1:E4gtWgxWxp8YSxExrQFv5BpCahla0PVF2oTTEYaWQGI=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.12.0/go.mod h1:hCAPuzYvKdP33pxWa+2+6AIKXEKqjIUyqsNCtbsSJrA=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
 | 
			
		||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
 | 
			
		||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
 | 
			
		||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
 | 
			
		||||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
 | 
			
		||||
@@ -89,12 +91,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
 | 
			
		||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
 | 
			
		||||
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
 | 
			
		||||
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
 | 
			
		||||
github.com/leodido/go-urn v1.2.3 h1:6BE2vPT0lqoz3fmOesHZiaiFh7889ssCo2GMvLCfiuA=
 | 
			
		||||
github.com/leodido/go-urn v1.2.3/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
 | 
			
		||||
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
 | 
			
		||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
 | 
			
		||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 | 
			
		||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
 | 
			
		||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
 | 
			
		||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
 | 
			
		||||
@@ -108,8 +110,8 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
 | 
			
		||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
 | 
			
		||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
 | 
			
		||||
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
 | 
			
		||||
github.com/pelletier/go-toml/v2 v2.0.7 h1:muncTPStnKRos5dpVKULv2FVd4bMOhNePj9CjgDb8Us=
 | 
			
		||||
github.com/pelletier/go-toml/v2 v2.0.7/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek=
 | 
			
		||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
 | 
			
		||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
 | 
			
		||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
 | 
			
		||||
github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6nXo=
 | 
			
		||||
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
 | 
			
		||||
@@ -128,8 +130,9 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
 | 
			
		||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 | 
			
		||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
 | 
			
		||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 | 
			
		||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
 | 
			
		||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 | 
			
		||||
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
 | 
			
		||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 | 
			
		||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
 | 
			
		||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
 | 
			
		||||
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
 | 
			
		||||
@@ -142,11 +145,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
 | 
			
		||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
 | 
			
		||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 | 
			
		||||
golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
 | 
			
		||||
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
 | 
			
		||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
 | 
			
		||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
 | 
			
		||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 | 
			
		||||
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
 | 
			
		||||
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
 | 
			
		||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
 | 
			
		||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 | 
			
		||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
@@ -154,8 +157,8 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
 | 
			
		||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
 | 
			
		||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
 | 
			
		||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 | 
			
		||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 | 
			
		||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										7
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								main.go
									
									
									
									
									
								
							@@ -47,6 +47,13 @@ func main() {
 | 
			
		||||
 | 
			
		||||
	// Initialize options
 | 
			
		||||
	model.InitOptionMap()
 | 
			
		||||
	if os.Getenv("SYNC_FREQUENCY") != "" {
 | 
			
		||||
		frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.FatalLog(err)
 | 
			
		||||
		}
 | 
			
		||||
		go model.SyncOptions(frequency)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Initialize HTTP server
 | 
			
		||||
	server := gin.Default()
 | 
			
		||||
 
 | 
			
		||||
@@ -85,6 +85,8 @@ func RootAuth() func(c *gin.Context) {
 | 
			
		||||
func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		key := c.Request.Header.Get("Authorization")
 | 
			
		||||
		key = strings.TrimPrefix(key, "Bearer ")
 | 
			
		||||
		key = strings.TrimPrefix(key, "sk-")
 | 
			
		||||
		parts := strings.Split(key, "-")
 | 
			
		||||
		key = parts[0]
 | 
			
		||||
		token, err := model.ValidateUserToken(key)
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,11 @@ import (
 | 
			
		||||
 | 
			
		||||
func Cache() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		c.Header("Cache-Control", "max-age=604800") // one week
 | 
			
		||||
		if c.Request.RequestURI == "/" {
 | 
			
		||||
			c.Header("Cache-Control", "no-cache")
 | 
			
		||||
		} else {
 | 
			
		||||
			c.Header("Cache-Control", "max-age=604800") // one week
 | 
			
		||||
		}
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,17 +6,19 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Channel struct {
 | 
			
		||||
	Id           int    `json:"id"`
 | 
			
		||||
	Type         int    `json:"type" gorm:"default:0"`
 | 
			
		||||
	Key          string `json:"key" gorm:"not null"`
 | 
			
		||||
	Status       int    `json:"status" gorm:"default:1"`
 | 
			
		||||
	Name         string `json:"name" gorm:"index"`
 | 
			
		||||
	Weight       int    `json:"weight"`
 | 
			
		||||
	CreatedTime  int64  `json:"created_time" gorm:"bigint"`
 | 
			
		||||
	TestTime     int64  `json:"test_time" gorm:"bigint"`
 | 
			
		||||
	ResponseTime int    `json:"response_time"` // in milliseconds
 | 
			
		||||
	BaseURL      string `json:"base_url" gorm:"column:base_url"`
 | 
			
		||||
	Other        string `json:"other"`
 | 
			
		||||
	Id                 int     `json:"id"`
 | 
			
		||||
	Type               int     `json:"type" gorm:"default:0"`
 | 
			
		||||
	Key                string  `json:"key" gorm:"not null"`
 | 
			
		||||
	Status             int     `json:"status" gorm:"default:1"`
 | 
			
		||||
	Name               string  `json:"name" gorm:"index"`
 | 
			
		||||
	Weight             int     `json:"weight"`
 | 
			
		||||
	CreatedTime        int64   `json:"created_time" gorm:"bigint"`
 | 
			
		||||
	TestTime           int64   `json:"test_time" gorm:"bigint"`
 | 
			
		||||
	ResponseTime       int     `json:"response_time"` // in milliseconds
 | 
			
		||||
	BaseURL            string  `json:"base_url" gorm:"column:base_url"`
 | 
			
		||||
	Other              string  `json:"other"`
 | 
			
		||||
	Balance            float64 `json:"balance"` // in USD
 | 
			
		||||
	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 | 
			
		||||
@@ -86,6 +88,16 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) UpdateBalance(balance float64) {
 | 
			
		||||
	err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
 | 
			
		||||
		BalanceUpdatedTime: common.GetTimestamp(),
 | 
			
		||||
		Balance:            balance,
 | 
			
		||||
	}).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to update balance: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) Delete() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Delete(channel).Error
 | 
			
		||||
 
 | 
			
		||||
@@ -26,6 +26,7 @@ func createRootAccountIfNeed() error {
 | 
			
		||||
			Status:      common.UserStatusEnabled,
 | 
			
		||||
			DisplayName: "Root User",
 | 
			
		||||
			AccessToken: common.GetUUID(),
 | 
			
		||||
			Quota:       100000000,
 | 
			
		||||
		}
 | 
			
		||||
		DB.Create(&rootUser)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Option struct {
 | 
			
		||||
@@ -59,6 +60,10 @@ func InitOptionMap() {
 | 
			
		||||
	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
 | 
			
		||||
	common.OptionMap["TopUpLink"] = common.TopUpLink
 | 
			
		||||
	common.OptionMapRWMutex.Unlock()
 | 
			
		||||
	loadOptionsFromDatabase()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func loadOptionsFromDatabase() {
 | 
			
		||||
	options, _ := AllOption()
 | 
			
		||||
	for _, option := range options {
 | 
			
		||||
		err := updateOptionMap(option.Key, option.Value)
 | 
			
		||||
@@ -68,6 +73,14 @@ func InitOptionMap() {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SyncOptions(frequency int) {
 | 
			
		||||
	for {
 | 
			
		||||
		time.Sleep(time.Duration(frequency) * time.Second)
 | 
			
		||||
		common.SysLog("Syncing options from database")
 | 
			
		||||
		loadOptionsFromDatabase()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateOption(key string, value string) error {
 | 
			
		||||
	// Save to database first
 | 
			
		||||
	option := Option{
 | 
			
		||||
 
 | 
			
		||||
@@ -6,13 +6,12 @@ import (
 | 
			
		||||
	_ "gorm.io/driver/sqlite"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Token struct {
 | 
			
		||||
	Id             int    `json:"id"`
 | 
			
		||||
	UserId         int    `json:"user_id"`
 | 
			
		||||
	Key            string `json:"key" gorm:"type:char(32);uniqueIndex"`
 | 
			
		||||
	Key            string `json:"key" gorm:"type:char(48);uniqueIndex"`
 | 
			
		||||
	Status         int    `json:"status" gorm:"default:1"`
 | 
			
		||||
	Name           string `json:"name" gorm:"index" `
 | 
			
		||||
	CreatedTime    int64  `json:"created_time" gorm:"bigint"`
 | 
			
		||||
@@ -38,7 +37,6 @@ func ValidateUserToken(key string) (token *Token, err error) {
 | 
			
		||||
	if key == "" {
 | 
			
		||||
		return nil, errors.New("未提供 token")
 | 
			
		||||
	}
 | 
			
		||||
	key = strings.Replace(key, "Bearer ", "", 1)
 | 
			
		||||
	token = &Token{}
 | 
			
		||||
	err = DB.Where("`key` = ?", key).First(token).Error
 | 
			
		||||
	if err == nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -19,8 +19,7 @@ type User struct {
 | 
			
		||||
	Email            string `json:"email" gorm:"index" validate:"max=50"`
 | 
			
		||||
	GitHubId         string `json:"github_id" gorm:"column:github_id;index"`
 | 
			
		||||
	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"`
 | 
			
		||||
	VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
 | 
			
		||||
	Balance          int    `json:"balance" gorm:"type:int;default:0"`
 | 
			
		||||
	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database!
 | 
			
		||||
	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
 | 
			
		||||
	Quota            int    `json:"quota" gorm:"type:int;default:0"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -66,6 +66,8 @@ func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
			channelRoute.GET("/:id", controller.GetChannel)
 | 
			
		||||
			channelRoute.GET("/test", controller.TestAllChannels)
 | 
			
		||||
			channelRoute.GET("/test/:id", controller.TestChannel)
 | 
			
		||||
			channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance)
 | 
			
		||||
			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
 | 
			
		||||
			channelRoute.POST("/", controller.AddChannel)
 | 
			
		||||
			channelRoute.PUT("/", controller.UpdateChannel)
 | 
			
		||||
			channelRoute.DELETE("/:id", controller.DeleteChannel)
 | 
			
		||||
 
 | 
			
		||||
@@ -8,11 +8,14 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetDashboardRouter(router *gin.Engine) {
 | 
			
		||||
	apiRouter := router.Group("/dashboard")
 | 
			
		||||
	apiRouter := router.Group("/")
 | 
			
		||||
	apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
			
		||||
	apiRouter.Use(middleware.GlobalAPIRateLimit())
 | 
			
		||||
	apiRouter.Use(middleware.TokenAuth())
 | 
			
		||||
	{
 | 
			
		||||
		apiRouter.GET("/billing/credit_grants", controller.GetTokenStatus)
 | 
			
		||||
		apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription)
 | 
			
		||||
		apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription)
 | 
			
		||||
		apiRouter.GET("/dashboard/billing/usage", controller.GetUsage)
 | 
			
		||||
		apiRouter.GET("/v1/dashboard/billing/usage", controller.GetUsage)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,12 +2,24 @@ package router
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"embed"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 | 
			
		||||
	SetApiRouter(router)
 | 
			
		||||
	SetDashboardRouter(router)
 | 
			
		||||
	SetRelayRouter(router)
 | 
			
		||||
	setWebRouter(router, buildFS, indexPage)
 | 
			
		||||
	frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
 | 
			
		||||
	if frontendBaseUrl == "" {
 | 
			
		||||
		SetWebRouter(router, buildFS, indexPage)
 | 
			
		||||
	} else {
 | 
			
		||||
		frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/")
 | 
			
		||||
		router.NoRoute(func(c *gin.Context) {
 | 
			
		||||
			c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI))
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -11,8 +11,8 @@ func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
	relayV1Router := router.Group("/v1")
 | 
			
		||||
	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
 | 
			
		||||
	{
 | 
			
		||||
		relayV1Router.GET("/models", controller.Relay)
 | 
			
		||||
		relayV1Router.GET("/models/:model", controller.Relay)
 | 
			
		||||
		relayV1Router.GET("/models", controller.ListModels)
 | 
			
		||||
		relayV1Router.GET("/models/:model", controller.RetrieveModel)
 | 
			
		||||
		relayV1Router.POST("/completions", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/chat/completions", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/edits", controller.RelayNotImplemented)
 | 
			
		||||
 
 | 
			
		||||
@@ -10,12 +10,13 @@ import (
 | 
			
		||||
	"one-api/middleware"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func setWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 | 
			
		||||
func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 | 
			
		||||
	router.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
			
		||||
	router.Use(middleware.GlobalWebRateLimit())
 | 
			
		||||
	router.Use(middleware.Cache())
 | 
			
		||||
	router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build")))
 | 
			
		||||
	router.NoRoute(func(c *gin.Context) {
 | 
			
		||||
		c.Header("Cache-Control", "no-cache")
 | 
			
		||||
		c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,6 +32,7 @@ const ChannelsTable = () => {
 | 
			
		||||
  const [activePage, setActivePage] = useState(1);
 | 
			
		||||
  const [searchKeyword, setSearchKeyword] = useState('');
 | 
			
		||||
  const [searching, setSearching] = useState(false);
 | 
			
		||||
  const [updatingBalance, setUpdatingBalance] = useState(false);
 | 
			
		||||
 | 
			
		||||
  const loadChannels = async (startIdx) => {
 | 
			
		||||
    const res = await API.get(`/api/channel/?p=${startIdx}`);
 | 
			
		||||
@@ -63,7 +64,7 @@ const ChannelsTable = () => {
 | 
			
		||||
  const refresh = async () => {
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
    await loadChannels(0);
 | 
			
		||||
  }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    loadChannels(0)
 | 
			
		||||
@@ -127,7 +128,7 @@ const ChannelsTable = () => {
 | 
			
		||||
 | 
			
		||||
  const renderResponseTime = (responseTime) => {
 | 
			
		||||
    let time = responseTime / 1000;
 | 
			
		||||
    time = time.toFixed(2) + " 秒";
 | 
			
		||||
    time = time.toFixed(2) + ' 秒';
 | 
			
		||||
    if (responseTime === 0) {
 | 
			
		||||
      return <Label basic color='grey'>未测试</Label>;
 | 
			
		||||
    } else if (responseTime <= 1000) {
 | 
			
		||||
@@ -179,11 +180,38 @@ const ChannelsTable = () => {
 | 
			
		||||
    const res = await API.get(`/api/channel/test`);
 | 
			
		||||
    const { success, message } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      showInfo("已成功开始测试所有已启用通道,请刷新页面查看结果。");
 | 
			
		||||
      showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。');
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const updateChannelBalance = async (id, name, idx) => {
 | 
			
		||||
    const res = await API.get(`/api/channel/update_balance/${id}/`);
 | 
			
		||||
    const { success, message, balance } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      let newChannels = [...channels];
 | 
			
		||||
      let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
 | 
			
		||||
      newChannels[realIdx].balance = balance;
 | 
			
		||||
      newChannels[realIdx].balance_updated_time = Date.now() / 1000;
 | 
			
		||||
      setChannels(newChannels);
 | 
			
		||||
      showInfo(`通道 ${name} 余额更新成功!`);
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const updateAllChannelsBalance = async () => {
 | 
			
		||||
    setUpdatingBalance(true);
 | 
			
		||||
    const res = await API.get(`/api/channel/update_balance`);
 | 
			
		||||
    const { success, message } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      showInfo('已更新完毕所有已启用通道余额!');
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
    }
 | 
			
		||||
    setUpdatingBalance(false);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const handleKeywordChange = async (e, { value }) => {
 | 
			
		||||
    setSearchKeyword(value.trim());
 | 
			
		||||
@@ -263,10 +291,10 @@ const ChannelsTable = () => {
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
                sortChannel('test_time');
 | 
			
		||||
                sortChannel('balance');
 | 
			
		||||
              }}
 | 
			
		||||
            >
 | 
			
		||||
              测试时间
 | 
			
		||||
              余额
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell>操作</Table.HeaderCell>
 | 
			
		||||
          </Table.Row>
 | 
			
		||||
@@ -286,8 +314,22 @@ const ChannelsTable = () => {
 | 
			
		||||
                  <Table.Cell>{channel.name ? channel.name : '无'}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderType(channel.type)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderStatus(channel.status)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderResponseTime(channel.response_time)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{channel.test_time ? renderTimestamp(channel.test_time) : "未测试"}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <Popup
 | 
			
		||||
                      content={channel.test_time ? renderTimestamp(channel.test_time) : '未测试'}
 | 
			
		||||
                      key={channel.id}
 | 
			
		||||
                      trigger={renderResponseTime(channel.response_time)}
 | 
			
		||||
                      basic
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <Popup
 | 
			
		||||
                      content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
 | 
			
		||||
                      key={channel.id}
 | 
			
		||||
                      trigger={<span>${channel.balance.toFixed(2)}</span>}
 | 
			
		||||
                      basic
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <div>
 | 
			
		||||
                      <Button
 | 
			
		||||
@@ -299,6 +341,16 @@ const ChannelsTable = () => {
 | 
			
		||||
                      >
 | 
			
		||||
                        测试
 | 
			
		||||
                      </Button>
 | 
			
		||||
                      <Button
 | 
			
		||||
                        size={'small'}
 | 
			
		||||
                        positive
 | 
			
		||||
                        loading={updatingBalance}
 | 
			
		||||
                        onClick={() => {
 | 
			
		||||
                          updateChannelBalance(channel.id, channel.name, idx);
 | 
			
		||||
                        }}
 | 
			
		||||
                      >
 | 
			
		||||
                        更新余额
 | 
			
		||||
                      </Button>
 | 
			
		||||
                      <Popup
 | 
			
		||||
                        trigger={
 | 
			
		||||
                          <Button size='small' negative>
 | 
			
		||||
@@ -353,6 +405,7 @@ const ChannelsTable = () => {
 | 
			
		||||
              <Button size='small' loading={loading} onClick={testAllChannels}>
 | 
			
		||||
                测试所有已启用通道
 | 
			
		||||
              </Button>
 | 
			
		||||
              <Button size='small' onClick={updateAllChannelsBalance} loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
 | 
			
		||||
              <Pagination
 | 
			
		||||
                floated='right'
 | 
			
		||||
                activePage={activePage}
 | 
			
		||||
 
 | 
			
		||||
@@ -112,13 +112,17 @@ const PersonalSetting = () => {
 | 
			
		||||
      <Button onClick={generateAccessToken}>生成系统访问令牌</Button>
 | 
			
		||||
      <Divider />
 | 
			
		||||
      <Header as='h3'>账号绑定</Header>
 | 
			
		||||
      <Button
 | 
			
		||||
        onClick={() => {
 | 
			
		||||
          setShowWeChatBindModal(true);
 | 
			
		||||
        }}
 | 
			
		||||
      >
 | 
			
		||||
        绑定微信账号
 | 
			
		||||
      </Button>
 | 
			
		||||
      {
 | 
			
		||||
        status.wechat_login && (
 | 
			
		||||
          <Button
 | 
			
		||||
            onClick={() => {
 | 
			
		||||
              setShowWeChatBindModal(true);
 | 
			
		||||
            }}
 | 
			
		||||
          >
 | 
			
		||||
            绑定微信账号
 | 
			
		||||
          </Button>
 | 
			
		||||
        )
 | 
			
		||||
      }
 | 
			
		||||
      <Modal
 | 
			
		||||
        onClose={() => setShowWeChatBindModal(false)}
 | 
			
		||||
        onOpen={() => setShowWeChatBindModal(true)}
 | 
			
		||||
@@ -148,7 +152,11 @@ const PersonalSetting = () => {
 | 
			
		||||
          </Modal.Description>
 | 
			
		||||
        </Modal.Content>
 | 
			
		||||
      </Modal>
 | 
			
		||||
      <Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
 | 
			
		||||
      {
 | 
			
		||||
        status.github_oauth && (
 | 
			
		||||
          <Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
 | 
			
		||||
        )
 | 
			
		||||
      }
 | 
			
		||||
      <Button
 | 
			
		||||
        onClick={() => {
 | 
			
		||||
          setShowEmailBindModal(true);
 | 
			
		||||
 
 | 
			
		||||
@@ -238,11 +238,12 @@ const TokensTable = () => {
 | 
			
		||||
                        size={'small'}
 | 
			
		||||
                        positive
 | 
			
		||||
                        onClick={async () => {
 | 
			
		||||
                          if (await copy(token.key)) {
 | 
			
		||||
                          let key = "sk-" + token.key;
 | 
			
		||||
                          if (await copy(key)) {
 | 
			
		||||
                            showSuccess('已复制到剪贴板!');
 | 
			
		||||
                          } else {
 | 
			
		||||
                            showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。');
 | 
			
		||||
                            setSearchKeyword(token.key);
 | 
			
		||||
                            setSearchKeyword(key);
 | 
			
		||||
                          }
 | 
			
		||||
                        }}
 | 
			
		||||
                      >
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import { Link } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showSuccess } from '../helpers';
 | 
			
		||||
 | 
			
		||||
import { ITEMS_PER_PAGE } from '../constants';
 | 
			
		||||
import { renderText } from '../helpers/render';
 | 
			
		||||
 | 
			
		||||
function renderRole(role) {
 | 
			
		||||
  switch (role) {
 | 
			
		||||
@@ -64,7 +65,7 @@ const UsersTable = () => {
 | 
			
		||||
    (async () => {
 | 
			
		||||
      const res = await API.post('/api/user/manage', {
 | 
			
		||||
        username,
 | 
			
		||||
        action,
 | 
			
		||||
        action
 | 
			
		||||
      });
 | 
			
		||||
      const { success, message } = res.data;
 | 
			
		||||
      if (success) {
 | 
			
		||||
@@ -161,18 +162,18 @@ const UsersTable = () => {
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
                sortUser('username');
 | 
			
		||||
                sortUser('id');
 | 
			
		||||
              }}
 | 
			
		||||
            >
 | 
			
		||||
              用户名
 | 
			
		||||
              ID
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
              onClick={() => {
 | 
			
		||||
                sortUser('display_name');
 | 
			
		||||
                sortUser('username');
 | 
			
		||||
              }}
 | 
			
		||||
            >
 | 
			
		||||
              显示名称
 | 
			
		||||
              用户名
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
              style={{ cursor: 'pointer' }}
 | 
			
		||||
@@ -220,9 +221,17 @@ const UsersTable = () => {
 | 
			
		||||
              if (user.deleted) return <></>;
 | 
			
		||||
              return (
 | 
			
		||||
                <Table.Row key={user.id}>
 | 
			
		||||
                  <Table.Cell>{user.username}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.display_name}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.email ? user.email : '无'}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.id}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <Popup
 | 
			
		||||
                      content={user.email ? user.email : '未绑定邮箱地址'}
 | 
			
		||||
                      key={user.display_name}
 | 
			
		||||
                      header={user.display_name ? user.display_name : user.username}
 | 
			
		||||
                      trigger={<span>{renderText(user.username, 10)}</span>}
 | 
			
		||||
                      hoverable
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.email ? renderText(user.email, 30) : '无'}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{user.quota}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderRole(user.role)}</Table.Cell>
 | 
			
		||||
                  <Table.Cell>{renderStatus(user.status)}</Table.Cell>
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,12 @@
 | 
			
		||||
export const CHANNEL_OPTIONS = [
 | 
			
		||||
  { key: 1, text: 'OpenAI', value: 1, color: 'green' },
 | 
			
		||||
  { key: 2, text: 'API2D', value: 2, color: 'blue' },
 | 
			
		||||
  { key: 8, text: '自定义', value: 8, color: 'pink' },
 | 
			
		||||
  { key: 3, text: 'Azure', value: 3, color: 'olive' },
 | 
			
		||||
  { key: 2, text: 'API2D', value: 2, color: 'blue' },
 | 
			
		||||
  { key: 4, text: 'CloseAI', value: 4, color: 'teal' },
 | 
			
		||||
  { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
 | 
			
		||||
  { key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
 | 
			
		||||
  { key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
 | 
			
		||||
  { key: 8, text: '自定义', value: 8, color: 'pink' }
 | 
			
		||||
  { key: 9, text: 'AI.LS', value: 9, color: 'yellow' },
 | 
			
		||||
  { key: 10, text: 'AI Proxy', value: 10, color: 'purple' }
 | 
			
		||||
];
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								web/src/helpers/render.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								web/src/helpers/render.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
			
		||||
export function renderText(text, limit) {
 | 
			
		||||
  if (text.length > limit) {
 | 
			
		||||
    return text.slice(0, limit - 3) + '...';
 | 
			
		||||
  }
 | 
			
		||||
  return text;
 | 
			
		||||
}
 | 
			
		||||
@@ -46,6 +46,9 @@ const EditChannel = () => {
 | 
			
		||||
    if (localInputs.base_url.endsWith('/')) {
 | 
			
		||||
      localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
 | 
			
		||||
    }
 | 
			
		||||
    if (localInputs.type === 3 && localInputs.other === '') {
 | 
			
		||||
      localInputs.other = '2023-03-15-preview';
 | 
			
		||||
    }
 | 
			
		||||
    let res;
 | 
			
		||||
    if (isEdit) {
 | 
			
		||||
      res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
 | 
			
		||||
@@ -83,7 +86,9 @@ const EditChannel = () => {
 | 
			
		||||
            inputs.type === 3 && (
 | 
			
		||||
              <>
 | 
			
		||||
                <Message>
 | 
			
		||||
                  注意,<strong>模型部署名称必须和模型名称保持一致</strong>,因为 One API 会把请求体中的 model 参数替换为你的部署名称(模型名称中的点会被剔除)。
 | 
			
		||||
                  注意,<strong>模型部署名称必须和模型名称保持一致</strong>,因为 One API 会把请求体中的 model
 | 
			
		||||
                  参数替换为你的部署名称(模型名称中的点会被剔除),<a target='_blank'
 | 
			
		||||
                                                                    href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。
 | 
			
		||||
                </Message>
 | 
			
		||||
                <Form.Field>
 | 
			
		||||
                  <Form.Input
 | 
			
		||||
@@ -151,7 +156,7 @@ const EditChannel = () => {
 | 
			
		||||
                onChange={handleInputChange}
 | 
			
		||||
                value={inputs.key}
 | 
			
		||||
                autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              />
 | 
			
		||||
            </Form.Field>
 | 
			
		||||
          }
 | 
			
		||||
          {
 | 
			
		||||
@@ -164,7 +169,7 @@ const EditChannel = () => {
 | 
			
		||||
              />
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          <Button onClick={submit}>提交</Button>
 | 
			
		||||
          <Button positive onClick={submit}>提交</Button>
 | 
			
		||||
        </Form>
 | 
			
		||||
      </Segment>
 | 
			
		||||
    </>
 | 
			
		||||
 
 | 
			
		||||
@@ -111,7 +111,7 @@ const EditRedemption = () => {
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            </>
 | 
			
		||||
          }
 | 
			
		||||
          <Button onClick={submit}>提交</Button>
 | 
			
		||||
          <Button positive onClick={submit}>提交</Button>
 | 
			
		||||
        </Form>
 | 
			
		||||
      </Segment>
 | 
			
		||||
    </>
 | 
			
		||||
 
 | 
			
		||||
@@ -106,6 +106,34 @@ const EditToken = () => {
 | 
			
		||||
              required={!isEdit}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='过期时间'
 | 
			
		||||
              name='expired_time'
 | 
			
		||||
              placeholder={'请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss,-1 表示无限制'}
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              value={expired_time}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              type='datetime-local'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <div style={{ lineHeight: '40px' }}>
 | 
			
		||||
            <Button type={'button'} onClick={() => {
 | 
			
		||||
              setExpiredTime(0, 0, 0, 0);
 | 
			
		||||
            }}>永不过期</Button>
 | 
			
		||||
            <Button type={'button'} onClick={() => {
 | 
			
		||||
              setExpiredTime(1, 0, 0, 0);
 | 
			
		||||
            }}>一个月后过期</Button>
 | 
			
		||||
            <Button type={'button'} onClick={() => {
 | 
			
		||||
              setExpiredTime(0, 1, 0, 0);
 | 
			
		||||
            }}>一天后过期</Button>
 | 
			
		||||
            <Button type={'button'} onClick={() => {
 | 
			
		||||
              setExpiredTime(0, 0, 1, 0);
 | 
			
		||||
            }}>一小时后过期</Button>
 | 
			
		||||
            <Button type={'button'} onClick={() => {
 | 
			
		||||
              setExpiredTime(0, 0, 0, 1);
 | 
			
		||||
            }}>一分钟后过期</Button>
 | 
			
		||||
          </div>
 | 
			
		||||
          <Message>注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。</Message>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
@@ -119,36 +147,10 @@ const EditToken = () => {
 | 
			
		||||
              disabled={unlimited_quota}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Button type={'button'} style={{ marginBottom: '14px' }} onClick={() => {
 | 
			
		||||
          <Button type={'button'} onClick={() => {
 | 
			
		||||
            setUnlimitedQuota();
 | 
			
		||||
          }}>{unlimited_quota ? '取消无限额度' : '设置为无限额度'}</Button>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='过期时间'
 | 
			
		||||
              name='expired_time'
 | 
			
		||||
              placeholder={'请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss,-1 表示无限制'}
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              value={expired_time}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              type='datetime-local'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Button type={'button'} onClick={() => {
 | 
			
		||||
            setExpiredTime(0, 0, 0, 0);
 | 
			
		||||
          }}>永不过期</Button>
 | 
			
		||||
          <Button type={'button'} onClick={() => {
 | 
			
		||||
            setExpiredTime(1, 0, 0, 0);
 | 
			
		||||
          }}>一个月后过期</Button>
 | 
			
		||||
          <Button type={'button'} onClick={() => {
 | 
			
		||||
            setExpiredTime(0, 1, 0, 0);
 | 
			
		||||
          }}>一天后过期</Button>
 | 
			
		||||
          <Button type={'button'} onClick={() => {
 | 
			
		||||
            setExpiredTime(0, 0, 1, 0);
 | 
			
		||||
          }}>一小时后过期</Button>
 | 
			
		||||
          <Button type={'button'} onClick={() => {
 | 
			
		||||
            setExpiredTime(0, 0, 0, 1);
 | 
			
		||||
          }}>一分钟后过期</Button>
 | 
			
		||||
          <Button onClick={submit}>提交</Button>
 | 
			
		||||
          <Button positive onClick={submit}>提交</Button>
 | 
			
		||||
        </Form>
 | 
			
		||||
      </Segment>
 | 
			
		||||
    </>
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Button, Form, Grid, Header, Segment, Statistic } from 'semantic-ui-react';
 | 
			
		||||
import { API, showError, showSuccess } from '../../helpers';
 | 
			
		||||
import { API, showError, showInfo, showSuccess } from '../../helpers';
 | 
			
		||||
 | 
			
		||||
const TopUp = () => {
 | 
			
		||||
  const [redemptionCode, setRedemptionCode] = useState('');
 | 
			
		||||
@@ -9,6 +9,7 @@ const TopUp = () => {
 | 
			
		||||
 | 
			
		||||
  const topUp = async () => {
 | 
			
		||||
    if (redemptionCode === '') {
 | 
			
		||||
      showInfo('请输入充值码!')
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    const res = await API.post('/api/user/topup', {
 | 
			
		||||
@@ -80,7 +81,7 @@ const TopUp = () => {
 | 
			
		||||
        <Grid.Column>
 | 
			
		||||
          <Statistic.Group widths='one'>
 | 
			
		||||
            <Statistic>
 | 
			
		||||
              <Statistic.Value>{userQuota}</Statistic.Value>
 | 
			
		||||
              <Statistic.Value>{userQuota.toLocaleString()}</Statistic.Value>
 | 
			
		||||
              <Statistic.Label>剩余额度</Statistic.Label>
 | 
			
		||||
            </Statistic>
 | 
			
		||||
          </Statistic.Group>
 | 
			
		||||
 
 | 
			
		||||
@@ -65,7 +65,7 @@ const AddUser = () => {
 | 
			
		||||
              required
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Button type={'submit'} onClick={submit}>
 | 
			
		||||
          <Button positive type={'submit'} onClick={submit}>
 | 
			
		||||
            提交
 | 
			
		||||
          </Button>
 | 
			
		||||
        </Form>
 | 
			
		||||
 
 | 
			
		||||
@@ -14,8 +14,9 @@ const EditUser = () => {
 | 
			
		||||
    github_id: '',
 | 
			
		||||
    wechat_id: '',
 | 
			
		||||
    email: '',
 | 
			
		||||
    quota: 0,
 | 
			
		||||
  });
 | 
			
		||||
  const { username, display_name, password, github_id, wechat_id, email } =
 | 
			
		||||
  const { username, display_name, password, github_id, wechat_id, email, quota } =
 | 
			
		||||
    inputs;
 | 
			
		||||
  const handleInputChange = (e, { name, value }) => {
 | 
			
		||||
    setInputs((inputs) => ({ ...inputs, [name]: value }));
 | 
			
		||||
@@ -44,7 +45,11 @@ const EditUser = () => {
 | 
			
		||||
  const submit = async () => {
 | 
			
		||||
    let res = undefined;
 | 
			
		||||
    if (userId) {
 | 
			
		||||
      res = await API.put(`/api/user/`, { ...inputs, id: parseInt(userId) });
 | 
			
		||||
      let data = { ...inputs, id: parseInt(userId) };
 | 
			
		||||
      if (typeof data.quota === 'string') {
 | 
			
		||||
        data.quota = parseInt(data.quota);
 | 
			
		||||
      }
 | 
			
		||||
      res = await API.put(`/api/user/`, data);
 | 
			
		||||
    } else {
 | 
			
		||||
      res = await API.put(`/api/user/self`, inputs);
 | 
			
		||||
    }
 | 
			
		||||
@@ -92,6 +97,21 @@ const EditUser = () => {
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          {
 | 
			
		||||
            userId && (
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='剩余额度'
 | 
			
		||||
                  name='quota'
 | 
			
		||||
                  placeholder={'请输入新的剩余额度'}
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
                  value={quota}
 | 
			
		||||
                  type={'number'}
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='已绑定的 GitHub 账户'
 | 
			
		||||
@@ -122,7 +142,7 @@ const EditUser = () => {
 | 
			
		||||
              readOnly
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Button onClick={submit}>提交</Button>
 | 
			
		||||
          <Button positive onClick={submit}>提交</Button>
 | 
			
		||||
        </Form>
 | 
			
		||||
      </Segment>
 | 
			
		||||
    </>
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user