mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 07:43:41 +08:00 
			
		
		
		
	Compare commits
	
		
			21 Commits
		
	
	
		
			v0.5.5-alp
			...
			v0.5.6-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					1d258cc898 | ||
| 
						 | 
					37e09d764c | ||
| 
						 | 
					159b9e3369 | ||
| 
						 | 
					92001986db | ||
| 
						 | 
					a5647b1ea7 | ||
| 
						 | 
					215e54fc96 | ||
| 
						 | 
					ecf8a6d875 | ||
| 
						 | 
					24df3e5f62 | ||
| 
						 | 
					12ef9679a7 | ||
| 
						 | 
					328aa68255 | ||
| 
						 | 
					4335f005a6 | ||
| 
						 | 
					fe26a1448d | ||
| 
						 | 
					42451d9d02 | ||
| 
						 | 
					25c4c111ab | ||
| 
						 | 
					0d50ad4b2b | ||
| 
						 | 
					959bcdef88 | ||
| 
						 | 
					39ae8075e4 | ||
| 
						 | 
					b57a0eca16 | ||
| 
						 | 
					1b4cc78890 | ||
| 
						 | 
					420c375140 | ||
| 
						 | 
					01863d3e44 | 
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -4,4 +4,5 @@ upload
 | 
			
		||||
*.exe
 | 
			
		||||
*.db
 | 
			
		||||
build
 | 
			
		||||
*.db-journal
 | 
			
		||||
*.db-journal
 | 
			
		||||
logs
 | 
			
		||||
							
								
								
									
										19
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								README.md
									
									
									
									
									
								
							@@ -71,10 +71,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
   + [x] [360 智脑](https://ai.360.cn)
 | 
			
		||||
2. 支持配置镜像以及众多第三方代理服务:
 | 
			
		||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
   + [x] [CloseAI](https://console.closeai-asia.com/r/2412)
 | 
			
		||||
   + [x] [API2D](https://api2d.com/r/197971)
 | 
			
		||||
   + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
 | 
			
		||||
   + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
 | 
			
		||||
   + [x] [CloseAI](https://console.closeai-asia.com/r/2412)
 | 
			
		||||
   + [x] 自定义渠道:例如各种未收录的第三方代理服务
 | 
			
		||||
3. 支持通过**负载均衡**的方式访问多个渠道。
 | 
			
		||||
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
			
		||||
@@ -211,6 +211,13 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 | 
			
		||||
 | 
			
		||||
注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。
 | 
			
		||||
 | 
			
		||||
#### QChatGPT - QQ机器人
 | 
			
		||||
项目主页:https://github.com/RockChinQ/QChatGPT
 | 
			
		||||
 | 
			
		||||
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key,并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
 | 
			
		||||
 | 
			
		||||
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
 | 
			
		||||
 | 
			
		||||
### 部署到第三方平台
 | 
			
		||||
<details>
 | 
			
		||||
<summary><strong>部署到 Sealos </strong></summary>
 | 
			
		||||
@@ -262,6 +269,12 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 | 
			
		||||
 | 
			
		||||
注意,具体的 API Base 的格式取决于你所使用的客户端。
 | 
			
		||||
 | 
			
		||||
例如对于 OpenAI 的官方库:
 | 
			
		||||
```bash
 | 
			
		||||
OPENAI_API_KEY="sk-xxxxxx"
 | 
			
		||||
OPENAI_API_BASE="https://<HOST>:<PORT>/v1" 
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```mermaid
 | 
			
		||||
graph LR
 | 
			
		||||
    A(用户)
 | 
			
		||||
@@ -318,7 +331,7 @@ graph LR
 | 
			
		||||
### 命令行参数
 | 
			
		||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 | 
			
		||||
   + 例子:`--port 3000`
 | 
			
		||||
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存。
 | 
			
		||||
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。
 | 
			
		||||
   + 例子:`--log-dir ./logs`
 | 
			
		||||
3. `--version`: 打印系统版本号并退出。
 | 
			
		||||
4. `--help`: 查看命令的使用帮助和参数说明。
 | 
			
		||||
@@ -364,4 +377,4 @@ https://openai.justsong.cn
 | 
			
		||||
 | 
			
		||||
同样适用于基于本项目的二开项目。
 | 
			
		||||
 | 
			
		||||
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。
 | 
			
		||||
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。
 | 
			
		||||
 
 | 
			
		||||
@@ -97,6 +97,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU
 | 
			
		||||
var BatchUpdateEnabled = false
 | 
			
		||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RequestIdKey = "X-Oneapi-Request-Id"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RoleGuestUser  = 0
 | 
			
		||||
	RoleCommonUser = 1
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,7 @@ var (
 | 
			
		||||
	Port         = flag.Int("port", 3000, "the listening port")
 | 
			
		||||
	PrintVersion = flag.Bool("version", false, "print version and exit")
 | 
			
		||||
	PrintHelp    = flag.Bool("help", false, "print help and exit")
 | 
			
		||||
	LogDir       = flag.String("log-dir", "", "specify the log directory")
 | 
			
		||||
	LogDir       = flag.String("log-dir", "./logs", "specify the log directory")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func printHelp() {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,29 +1,47 @@
 | 
			
		||||
package common
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetupGinLog() {
 | 
			
		||||
const (
 | 
			
		||||
	loggerINFO  = "INFO"
 | 
			
		||||
	loggerWarn  = "WARN"
 | 
			
		||||
	loggerError = "ERR"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const maxLogCount = 1000000
 | 
			
		||||
 | 
			
		||||
var logCount int
 | 
			
		||||
var setupLogLock sync.Mutex
 | 
			
		||||
var setupLogWorking bool
 | 
			
		||||
 | 
			
		||||
func SetupLogger() {
 | 
			
		||||
	if *LogDir != "" {
 | 
			
		||||
		commonLogPath := filepath.Join(*LogDir, "common.log")
 | 
			
		||||
		errorLogPath := filepath.Join(*LogDir, "error.log")
 | 
			
		||||
		commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
			
		||||
		ok := setupLogLock.TryLock()
 | 
			
		||||
		if !ok {
 | 
			
		||||
			log.Println("setup log is already working")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		defer func() {
 | 
			
		||||
			setupLogLock.Unlock()
 | 
			
		||||
			setupLogWorking = false
 | 
			
		||||
		}()
 | 
			
		||||
		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
 | 
			
		||||
		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Fatal("failed to open log file")
 | 
			
		||||
		}
 | 
			
		||||
		errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Fatal("failed to open log file")
 | 
			
		||||
		}
 | 
			
		||||
		gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
 | 
			
		||||
		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
 | 
			
		||||
		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
 | 
			
		||||
		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -37,6 +55,36 @@ func SysError(s string) {
 | 
			
		||||
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LogInfo(ctx context.Context, msg string) {
 | 
			
		||||
	logHelper(ctx, loggerINFO, msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LogWarn(ctx context.Context, msg string) {
 | 
			
		||||
	logHelper(ctx, loggerWarn, msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func LogError(ctx context.Context, msg string) {
 | 
			
		||||
	logHelper(ctx, loggerError, msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func logHelper(ctx context.Context, level string, msg string) {
 | 
			
		||||
	writer := gin.DefaultErrorWriter
 | 
			
		||||
	if level == loggerINFO {
 | 
			
		||||
		writer = gin.DefaultWriter
 | 
			
		||||
	}
 | 
			
		||||
	id := ctx.Value(RequestIdKey)
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
 | 
			
		||||
	logCount++ // we don't need accurate count, so no lock here
 | 
			
		||||
	if logCount > maxLogCount && !setupLogWorking {
 | 
			
		||||
		logCount = 0
 | 
			
		||||
		setupLogWorking = true
 | 
			
		||||
		go func() {
 | 
			
		||||
			SetupLogger()
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FatalLog(v ...any) {
 | 
			
		||||
	t := time.Now()
 | 
			
		||||
	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
 | 
			
		||||
 
 | 
			
		||||
@@ -171,6 +171,11 @@ func GetTimestamp() int64 {
 | 
			
		||||
	return time.Now().Unix()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetTimeString() string {
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Max(a int, b int) int {
 | 
			
		||||
	if a >= b {
 | 
			
		||||
		return a
 | 
			
		||||
@@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
 | 
			
		||||
	}
 | 
			
		||||
	return num
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func MessageWithRequestId(message string, id string) string {
 | 
			
		||||
	return fmt.Sprintf("%s (request id: %s)", message, id)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		openAIError := OpenAIError{
 | 
			
		||||
			Message: err.Error(),
 | 
			
		||||
			Type:    "one_api_error",
 | 
			
		||||
			Type:    "upstream_error",
 | 
			
		||||
		}
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
			"error": openAIError,
 | 
			
		||||
 
 | 
			
		||||
@@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
 | 
			
		||||
	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
 | 
			
		||||
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
 | 
			
		||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	if channel.BaseURL == "" {
 | 
			
		||||
		channel.BaseURL = baseURL
 | 
			
		||||
	if channel.GetBaseURL() == "" {
 | 
			
		||||
		channel.BaseURL = &baseURL
 | 
			
		||||
	}
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case common.ChannelTypeOpenAI:
 | 
			
		||||
		if channel.BaseURL != "" {
 | 
			
		||||
			baseURL = channel.BaseURL
 | 
			
		||||
		if channel.GetBaseURL() != "" {
 | 
			
		||||
			baseURL = channel.GetBaseURL()
 | 
			
		||||
		}
 | 
			
		||||
	case common.ChannelTypeAzure:
 | 
			
		||||
		return 0, errors.New("尚未实现")
 | 
			
		||||
	case common.ChannelTypeCustom:
 | 
			
		||||
		baseURL = channel.BaseURL
 | 
			
		||||
		baseURL = channel.GetBaseURL()
 | 
			
		||||
	case common.ChannelTypeCloseAI:
 | 
			
		||||
		return updateChannelCloseAIBalance(channel)
 | 
			
		||||
	case common.ChannelTypeOpenAISB:
 | 
			
		||||
 
 | 
			
		||||
@@ -42,10 +42,10 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 | 
			
		||||
	}
 | 
			
		||||
	requestURL := common.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
 | 
			
		||||
		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
 | 
			
		||||
	} else {
 | 
			
		||||
		if channel.BaseURL != "" {
 | 
			
		||||
			requestURL = channel.BaseURL
 | 
			
		||||
		if channel.GetBaseURL() != "" {
 | 
			
		||||
			requestURL = channel.GetBaseURL()
 | 
			
		||||
		}
 | 
			
		||||
		requestURL += "/v1/chat/completions"
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -79,6 +79,14 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
 | 
			
		||||
 | 
			
		||||
func GitHubOAuth(c *gin.Context) {
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	state := c.Query("state")
 | 
			
		||||
	if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
 | 
			
		||||
		c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "state is empty or not same",
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	username := session.Get("username")
 | 
			
		||||
	if username != nil {
 | 
			
		||||
		GitHubBind(c)
 | 
			
		||||
@@ -205,3 +213,22 @@ func GitHubBind(c *gin.Context) {
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GenerateOAuthCode(c *gin.Context) {
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	state := common.GetRandomString(12)
 | 
			
		||||
	session.Set("oauth_state", state)
 | 
			
		||||
	err := session.Save()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    state,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -18,19 +19,21 @@ func GetAllLogs(c *gin.Context) {
 | 
			
		||||
	username := c.Query("username")
 | 
			
		||||
	tokenName := c.Query("token_name")
 | 
			
		||||
	modelName := c.Query("model_name")
 | 
			
		||||
	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
 | 
			
		||||
	channel, _ := strconv.Atoi(c.Query("channel"))
 | 
			
		||||
	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserLogs(c *gin.Context) {
 | 
			
		||||
@@ -46,34 +49,36 @@ func GetUserLogs(c *gin.Context) {
 | 
			
		||||
	modelName := c.Query("model_name")
 | 
			
		||||
	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchAllLogs(c *gin.Context) {
 | 
			
		||||
	keyword := c.Query("keyword")
 | 
			
		||||
	logs, err := model.SearchAllLogs(keyword)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchUserLogs(c *gin.Context) {
 | 
			
		||||
@@ -81,17 +86,18 @@ func SearchUserLogs(c *gin.Context) {
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	logs, err := model.SearchUserLogs(userId, keyword)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(200, gin.H{
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    logs,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetLogsStat(c *gin.Context) {
 | 
			
		||||
@@ -101,9 +107,10 @@ func GetLogsStat(c *gin.Context) {
 | 
			
		||||
	tokenName := c.Query("token_name")
 | 
			
		||||
	username := c.Query("username")
 | 
			
		||||
	modelName := c.Query("model_name")
 | 
			
		||||
	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 | 
			
		||||
	channel, _ := strconv.Atoi(c.Query("channel"))
 | 
			
		||||
	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
 | 
			
		||||
	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data": gin.H{
 | 
			
		||||
@@ -111,6 +118,7 @@ func GetLogsStat(c *gin.Context) {
 | 
			
		||||
			//"token": tokenNum,
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetLogsSelfStat(c *gin.Context) {
 | 
			
		||||
@@ -120,9 +128,10 @@ func GetLogsSelfStat(c *gin.Context) {
 | 
			
		||||
	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 | 
			
		||||
	tokenName := c.Query("token_name")
 | 
			
		||||
	modelName := c.Query("model_name")
 | 
			
		||||
	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 | 
			
		||||
	channel, _ := strconv.Atoi(c.Query("channel"))
 | 
			
		||||
	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
 | 
			
		||||
	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data": gin.H{
 | 
			
		||||
@@ -130,4 +139,30 @@ func GetLogsSelfStat(c *gin.Context) {
 | 
			
		||||
			//"token": tokenNum,
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteHistoryLogs(c *gin.Context) {
 | 
			
		||||
	targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
 | 
			
		||||
	if targetTimestamp == 0 {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "target timestamp is required",
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	count, err := model.DeleteOldLog(targetTimestamp)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": err.Error(),
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    count,
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
@@ -17,6 +18,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
 | 
			
		||||
@@ -91,7 +93,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	}
 | 
			
		||||
	var audioResponse AudioResponse
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		go func() {
 | 
			
		||||
			quota := countTokenText(audioResponse.Text, audioModel)
 | 
			
		||||
			quotaDelta := quota - preConsumedQuota
 | 
			
		||||
@@ -106,13 +108,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				tokenName := c.GetString("token_name")
 | 
			
		||||
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
				model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
 | 
			
		||||
				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
				channelId := c.GetInt("channel_id")
 | 
			
		||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}()
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -18,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	consumeQuota := c.GetBool("consume_quota")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
@@ -124,7 +126,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
	}
 | 
			
		||||
	var textResponse ImageResponse
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
			err := model.PostConsumeTokenQuota(tokenId, quota)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
@@ -137,13 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				tokenName := c.GetString("token_name")
 | 
			
		||||
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
				model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
 | 
			
		||||
				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
				channelId := c.GetInt("channel_id")
 | 
			
		||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -37,6 +38,7 @@ func init() {
 | 
			
		||||
 | 
			
		||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	consumeQuota := c.GetBool("consume_quota")
 | 
			
		||||
@@ -210,6 +212,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		// in this case, we do not pre-consume quota
 | 
			
		||||
		// because the user has enough quota
 | 
			
		||||
		preConsumedQuota = 0
 | 
			
		||||
		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
 | 
			
		||||
	}
 | 
			
		||||
	if consumeQuota && preConsumedQuota > 0 {
 | 
			
		||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
			
		||||
@@ -347,15 +350,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
 | 
			
		||||
		if resp.StatusCode != http.StatusOK {
 | 
			
		||||
			if preConsumedQuota != 0 {
 | 
			
		||||
				go func(ctx context.Context) {
 | 
			
		||||
					// return pre-consumed quota
 | 
			
		||||
					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
 | 
			
		||||
					}
 | 
			
		||||
				}(c.Request.Context())
 | 
			
		||||
			}
 | 
			
		||||
			return relayErrorHandler(resp)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var textResponse TextResponse
 | 
			
		||||
	tokenName := c.GetString("token_name")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
	defer func(ctx context.Context) {
 | 
			
		||||
		// c.Writer.Flush()
 | 
			
		||||
		go func() {
 | 
			
		||||
			if consumeQuota {
 | 
			
		||||
@@ -378,21 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
				quotaDelta := quota - preConsumedQuota
 | 
			
		||||
				err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
					common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 | 
			
		||||
				}
 | 
			
		||||
				err = model.CacheUpdateUserQuota(userId)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
					common.LogError(ctx, "error update user quota cache: "+err.Error())
 | 
			
		||||
				}
 | 
			
		||||
				if quota != 0 {
 | 
			
		||||
					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 | 
			
		||||
					model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 | 
			
		||||
					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
					model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}()
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	case APITypeOpenAI:
 | 
			
		||||
		if isStream {
 | 
			
		||||
@@ -530,24 +541,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeXunfei:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			auth := c.Request.Header.Get("Authorization")
 | 
			
		||||
			auth = strings.TrimPrefix(auth, "Bearer ")
 | 
			
		||||
			splits := strings.Split(auth, "|")
 | 
			
		||||
			if len(splits) != 3 {
 | 
			
		||||
				return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 | 
			
		||||
			}
 | 
			
		||||
			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
 | 
			
		||||
		auth := c.Request.Header.Get("Authorization")
 | 
			
		||||
		auth = strings.TrimPrefix(auth, "Bearer ")
 | 
			
		||||
		splits := strings.Split(auth, "|")
 | 
			
		||||
		if len(splits) != 3 {
 | 
			
		||||
			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
		var err *OpenAIErrorWithStatusCode
 | 
			
		||||
		var usage *Usage
 | 
			
		||||
		if isStream {
 | 
			
		||||
			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		if usage != nil {
 | 
			
		||||
			textResponse.Usage = *usage
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	case APITypeAIProxyLibrary:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			err, usage := aiProxyLibraryStreamHandler(c, resp)
 | 
			
		||||
 
 | 
			
		||||
@@ -146,7 +146,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
 | 
			
		||||
		StatusCode: resp.StatusCode,
 | 
			
		||||
		OpenAIError: OpenAIError{
 | 
			
		||||
			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
 | 
			
		||||
			Type:    "one_api_error",
 | 
			
		||||
			Type:    "upstream_error",
 | 
			
		||||
			Code:    "bad_response_status_code",
 | 
			
		||||
			Param:   strconv.Itoa(resp.StatusCode),
 | 
			
		||||
		},
 | 
			
		||||
 
 | 
			
		||||
@@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
 | 
			
		||||
			Role:    "assistant",
 | 
			
		||||
			Content: response.Payload.Choices.Text[0].Content,
 | 
			
		||||
		},
 | 
			
		||||
		FinishReason: stopFinishReason,
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := OpenAITextResponse{
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
@@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 | 
			
		||||
	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	query := c.Request.URL.Query()
 | 
			
		||||
	apiVersion := query.Get("api-version")
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = c.GetString("api_version")
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case xunfeiResponse := <-dataChan:
 | 
			
		||||
			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
 | 
			
		||||
			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
 | 
			
		||||
			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
 | 
			
		||||
			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
 | 
			
		||||
	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = "v1.1"
 | 
			
		||||
		common.SysLog("api_version not found, use default: " + apiVersion)
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	var content string
 | 
			
		||||
	var xunfeiResponse XunfeiChatResponse
 | 
			
		||||
	stop := false
 | 
			
		||||
	for !stop {
 | 
			
		||||
		select {
 | 
			
		||||
		case xunfeiResponse = <-dataChan:
 | 
			
		||||
			content += xunfeiResponse.Payload.Choices.Text[0].Content
 | 
			
		||||
			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
 | 
			
		||||
			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
 | 
			
		||||
			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
 | 
			
		||||
		case stop = <-stopChan:
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	domain := "general"
 | 
			
		||||
	if apiVersion == "v2.1" {
 | 
			
		||||
		domain = "generalv2"
 | 
			
		||||
 | 
			
		||||
	xunfeiResponse.Payload.Choices.Text[0].Content = content
 | 
			
		||||
 | 
			
		||||
	response := responseXunfei2OpenAI(&xunfeiResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	_, _ = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
 | 
			
		||||
	d := websocket.Dialer{
 | 
			
		||||
		HandshakeTimeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
 | 
			
		||||
	conn, resp, err := d.Dial(authUrl, nil)
 | 
			
		||||
	if err != nil || resp.StatusCode != 101 {
 | 
			
		||||
		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
	data := requestOpenAI2Xunfei(textRequest, appId, domain)
 | 
			
		||||
	err = conn.WriteJSON(data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dataChan := make(chan XunfeiChatResponse)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
@@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case xunfeiResponse := <-dataChan:
 | 
			
		||||
			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
 | 
			
		||||
			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
 | 
			
		||||
			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
 | 
			
		||||
			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
 | 
			
		||||
	return dataChan, stopChan, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var xunfeiResponse XunfeiChatResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
 | 
			
		||||
	query := c.Request.URL.Query()
 | 
			
		||||
	apiVersion := query.Get("api-version")
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = c.GetString("api_version")
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = "v1.1"
 | 
			
		||||
		common.SysLog("api_version not found, use default: " + apiVersion)
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &xunfeiResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	domain := "general"
 | 
			
		||||
	if apiVersion == "v2.1" {
 | 
			
		||||
		domain = "generalv2"
 | 
			
		||||
	}
 | 
			
		||||
	if xunfeiResponse.Header.Code != 0 {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: xunfeiResponse.Header.Message,
 | 
			
		||||
				Type:    "xunfei_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
				Code:    xunfeiResponse.Header.Code,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
 | 
			
		||||
	return domain, authUrl
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -196,6 +196,7 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		err = relayTextHelper(c, relayMode)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		requestId := c.GetString(common.RequestIdKey)
 | 
			
		||||
		retryTimesStr := c.Query("retry")
 | 
			
		||||
		retryTimes, _ := strconv.Atoi(retryTimesStr)
 | 
			
		||||
		if retryTimesStr == "" {
 | 
			
		||||
@@ -207,12 +208,13 @@ func Relay(c *gin.Context) {
 | 
			
		||||
			if err.StatusCode == http.StatusTooManyRequests {
 | 
			
		||||
				err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
 | 
			
		||||
			}
 | 
			
		||||
			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
 | 
			
		||||
			c.JSON(err.StatusCode, gin.H{
 | 
			
		||||
				"error": err.OpenAIError,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
			
		||||
		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
 | 
			
		||||
			channelId := c.GetInt("channel_id")
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										9
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								main.go
									
									
									
									
									
								
							@@ -21,7 +21,7 @@ var buildFS embed.FS
 | 
			
		||||
var indexPage []byte
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	common.SetupGinLog()
 | 
			
		||||
	common.SetupLogger()
 | 
			
		||||
	common.SysLog("One API " + common.Version + " started")
 | 
			
		||||
	if os.Getenv("GIN_MODE") != "debug" {
 | 
			
		||||
		gin.SetMode(gin.ReleaseMode)
 | 
			
		||||
@@ -85,11 +85,12 @@ func main() {
 | 
			
		||||
	controller.InitTokenEncoders()
 | 
			
		||||
 | 
			
		||||
	// Initialize HTTP server
 | 
			
		||||
	server := gin.Default()
 | 
			
		||||
	server := gin.New()
 | 
			
		||||
	server.Use(gin.Recovery())
 | 
			
		||||
	// This will cause SSE not to work!!!
 | 
			
		||||
	//server.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
			
		||||
	server.Use(middleware.CORS())
 | 
			
		||||
 | 
			
		||||
	server.Use(middleware.RequestId())
 | 
			
		||||
	middleware.SetUpLogger(server)
 | 
			
		||||
	// Initialize session store
 | 
			
		||||
	store := cookie.NewStore([]byte(common.SessionSecret))
 | 
			
		||||
	server.Use(sessions.Sessions("session", store))
 | 
			
		||||
 
 | 
			
		||||
@@ -91,34 +91,16 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
		key = parts[0]
 | 
			
		||||
		token, err := model.ValidateUserToken(key)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusUnauthorized, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
					"type":    "one_api_error",
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			abortWithMessage(c, http.StatusUnauthorized, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		userEnabled, err := model.IsUserEnabled(token.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
					"type":    "one_api_error",
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			abortWithMessage(c, http.StatusInternalServerError, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if !userEnabled {
 | 
			
		||||
			c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": "用户已被封禁",
 | 
			
		||||
					"type":    "one_api_error",
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("id", token.UserId)
 | 
			
		||||
@@ -134,13 +116,7 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
			if model.IsAdmin(token.UserId) {
 | 
			
		||||
				c.Set("channelId", parts[1])
 | 
			
		||||
			} else {
 | 
			
		||||
				c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "普通用户不支持指定渠道",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
		if ok {
 | 
			
		||||
			id, err := strconv.Atoi(channelId.(string))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "无效的渠道 ID",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			channel, err = model.GetChannelById(id, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "无效的渠道 ID",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
				c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "该渠道已被禁用",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
@@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
				err = common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": "无效的请求",
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的请求")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
			
		||||
@@ -99,22 +75,16 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 | 
			
		||||
					message = "数据库一致性已被破坏,请联系管理员"
 | 
			
		||||
				}
 | 
			
		||||
				c.JSON(http.StatusServiceUnavailable, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
						"message": message,
 | 
			
		||||
						"type":    "one_api_error",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				abortWithMessage(c, http.StatusServiceUnavailable, message)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("channel", channel.Type)
 | 
			
		||||
		c.Set("channel_id", channel.Id)
 | 
			
		||||
		c.Set("channel_name", channel.Name)
 | 
			
		||||
		c.Set("model_mapping", channel.ModelMapping)
 | 
			
		||||
		c.Set("model_mapping", channel.GetModelMapping())
 | 
			
		||||
		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
			
		||||
		c.Set("base_url", channel.BaseURL)
 | 
			
		||||
		c.Set("base_url", channel.GetBaseURL())
 | 
			
		||||
		switch channel.Type {
 | 
			
		||||
		case common.ChannelTypeAzure:
 | 
			
		||||
			c.Set("api_version", channel.Other)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								middleware/logger.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetUpLogger(server *gin.Engine) {
 | 
			
		||||
	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
 | 
			
		||||
		var requestID string
 | 
			
		||||
		if param.Keys != nil {
 | 
			
		||||
			requestID = param.Keys[common.RequestIdKey].(string)
 | 
			
		||||
		}
 | 
			
		||||
		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
 | 
			
		||||
			param.TimeStamp.Format("2006/01/02 - 15:04:05"),
 | 
			
		||||
			requestID,
 | 
			
		||||
			param.StatusCode,
 | 
			
		||||
			param.Latency,
 | 
			
		||||
			param.ClientIP,
 | 
			
		||||
			param.Method,
 | 
			
		||||
			param.Path,
 | 
			
		||||
		)
 | 
			
		||||
	}))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								middleware/request-id.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func RequestId() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		id := common.GetTimeString() + common.GetRandomString(8)
 | 
			
		||||
		c.Set(common.RequestIdKey, id)
 | 
			
		||||
		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
 | 
			
		||||
		c.Request = c.Request.WithContext(ctx)
 | 
			
		||||
		c.Header(common.RequestIdKey, id)
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								middleware/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
			
		||||
package middleware
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
 | 
			
		||||
	c.JSON(statusCode, gin.H{
 | 
			
		||||
		"error": gin.H{
 | 
			
		||||
			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
 | 
			
		||||
			"type":    "one_api_error",
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	c.Abort()
 | 
			
		||||
	common.LogError(c.Request.Context(), message)
 | 
			
		||||
}
 | 
			
		||||
@@ -10,15 +10,18 @@ type Ability struct {
 | 
			
		||||
	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"`
 | 
			
		||||
	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
 | 
			
		||||
	Enabled   bool   `json:"enabled"`
 | 
			
		||||
	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 | 
			
		||||
	ability := Ability{}
 | 
			
		||||
	var err error = nil
 | 
			
		||||
	if common.UsingSQLite {
 | 
			
		||||
		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
 | 
			
		||||
		maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
 | 
			
		||||
		err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RANDOM()").Limit(1).First(&ability).Error
 | 
			
		||||
	} else {
 | 
			
		||||
		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
 | 
			
		||||
		maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("group = ? and model = ? and enabled = 1", group, model)
 | 
			
		||||
		err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RAND()").Limit(1).First(&ability).Error
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -40,6 +43,7 @@ func (channel *Channel) AddAbilities() error {
 | 
			
		||||
				Model:     model,
 | 
			
		||||
				ChannelId: channel.Id,
 | 
			
		||||
				Enabled:   channel.Status == common.ChannelStatusEnabled,
 | 
			
		||||
				Priority:  channel.Priority,
 | 
			
		||||
			}
 | 
			
		||||
			abilities = append(abilities, ability)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
@@ -159,6 +160,17 @@ func InitChannelCache() {
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// sort by priority
 | 
			
		||||
	for group, model2channels := range newGroup2model2channels {
 | 
			
		||||
		for model, channels := range model2channels {
 | 
			
		||||
			sort.Slice(channels, func(i, j int) bool {
 | 
			
		||||
				return channels[i].GetPriority() > channels[j].GetPriority()
 | 
			
		||||
			})
 | 
			
		||||
			newGroup2model2channels[group][model] = channels
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	channelSyncLock.Lock()
 | 
			
		||||
	group2model2channels = newGroup2model2channels
 | 
			
		||||
	channelSyncLock.Unlock()
 | 
			
		||||
@@ -183,6 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 | 
			
		||||
	if len(channels) == 0 {
 | 
			
		||||
		return nil, errors.New("channel not found")
 | 
			
		||||
	}
 | 
			
		||||
	idx := rand.Intn(len(channels))
 | 
			
		||||
	endIdx := len(channels)
 | 
			
		||||
	// choose by priority
 | 
			
		||||
	firstChannel := channels[0]
 | 
			
		||||
	if firstChannel.GetPriority() > 0 {
 | 
			
		||||
		for i := range channels {
 | 
			
		||||
			if channels[i].GetPriority() != firstChannel.GetPriority() {
 | 
			
		||||
				endIdx = i
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	idx := rand.Intn(endIdx)
 | 
			
		||||
	return channels[idx], nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -15,14 +15,15 @@ type Channel struct {
 | 
			
		||||
	CreatedTime        int64   `json:"created_time" gorm:"bigint"`
 | 
			
		||||
	TestTime           int64   `json:"test_time" gorm:"bigint"`
 | 
			
		||||
	ResponseTime       int     `json:"response_time"` // in milliseconds
 | 
			
		||||
	BaseURL            string  `json:"base_url" gorm:"column:base_url"`
 | 
			
		||||
	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"`
 | 
			
		||||
	Other              string  `json:"other"`
 | 
			
		||||
	Balance            float64 `json:"balance"` // in USD
 | 
			
		||||
	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
 | 
			
		||||
	Models             string  `json:"models"`
 | 
			
		||||
	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
 | 
			
		||||
	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
 | 
			
		||||
	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 | 
			
		||||
	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 | 
			
		||||
	Priority           *int64  `json:"priority" gorm:"bigint;default:0"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 | 
			
		||||
@@ -78,6 +79,27 @@ func BatchInsertChannels(channels []Channel) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) GetPriority() int64 {
 | 
			
		||||
	if channel.Priority == nil {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	return *channel.Priority
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) GetBaseURL() string {
 | 
			
		||||
	if channel.BaseURL == nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	return *channel.BaseURL
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) GetModelMapping() string {
 | 
			
		||||
	if channel.ModelMapping == nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	return *channel.ModelMapping
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (channel *Channel) Insert() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	err = DB.Create(channel).Error
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										25
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								model/log.go
									
									
									
									
									
								
							@@ -1,6 +1,8 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
@@ -17,6 +19,7 @@ type Log struct {
 | 
			
		||||
	Quota            int    `json:"quota" gorm:"default:0"`
 | 
			
		||||
	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"`
 | 
			
		||||
	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"`
 | 
			
		||||
	Channel          int    `json:"channel" gorm:"default:0"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -44,7 +47,9 @@ func RecordLog(userId int, logType int, content string) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
 | 
			
		||||
 | 
			
		||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
 | 
			
		||||
	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 | 
			
		||||
	if !common.LogConsumeEnabled {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -59,14 +64,15 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
 | 
			
		||||
		TokenName:        tokenName,
 | 
			
		||||
		ModelName:        modelName,
 | 
			
		||||
		Quota:            quota,
 | 
			
		||||
		Channel:          channelId,
 | 
			
		||||
	}
 | 
			
		||||
	err := DB.Create(log).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to record log: " + err.Error())
 | 
			
		||||
		common.LogError(ctx, "failed to record log: "+err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
 | 
			
		||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
 | 
			
		||||
	var tx *gorm.DB
 | 
			
		||||
	if logType == LogTypeUnknown {
 | 
			
		||||
		tx = DB
 | 
			
		||||
@@ -88,6 +94,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 | 
			
		||||
	if endTimestamp != 0 {
 | 
			
		||||
		tx = tx.Where("created_at <= ?", endTimestamp)
 | 
			
		||||
	}
 | 
			
		||||
	if channel != 0 {
 | 
			
		||||
		tx = tx.Where("channel = ?", channel)
 | 
			
		||||
	}
 | 
			
		||||
	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
 | 
			
		||||
	return logs, err
 | 
			
		||||
}
 | 
			
		||||
@@ -125,7 +134,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
 | 
			
		||||
	return logs, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
 | 
			
		||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
 | 
			
		||||
	tx := DB.Table("logs").Select("sum(quota)")
 | 
			
		||||
	if username != "" {
 | 
			
		||||
		tx = tx.Where("username = ?", username)
 | 
			
		||||
@@ -142,6 +151,9 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 | 
			
		||||
	if modelName != "" {
 | 
			
		||||
		tx = tx.Where("model_name = ?", modelName)
 | 
			
		||||
	}
 | 
			
		||||
	if channel != 0 {
 | 
			
		||||
		tx = tx.Where("channel = ?", channel)
 | 
			
		||||
	}
 | 
			
		||||
	tx.Where("type = ?", LogTypeConsume).Scan("a)
 | 
			
		||||
	return quota
 | 
			
		||||
}
 | 
			
		||||
@@ -166,3 +178,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
 | 
			
		||||
	tx.Where("type = ?", LogTypeConsume).Scan(&token)
 | 
			
		||||
	return token
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteOldLog(targetTimestamp int64) (int64, error) {
 | 
			
		||||
	result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
 | 
			
		||||
	return result.RowsAffected, result.Error
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
		apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
 | 
			
		||||
		apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
 | 
			
		||||
		apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
 | 
			
		||||
		apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
 | 
			
		||||
		apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
 | 
			
		||||
		apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
 | 
			
		||||
		apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
 | 
			
		||||
@@ -97,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
		}
 | 
			
		||||
		logRoute := apiRouter.Group("/log")
 | 
			
		||||
		logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
 | 
			
		||||
		logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
 | 
			
		||||
		logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
 | 
			
		||||
		logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
 | 
			
		||||
		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
	router.Use(middleware.CORS())
 | 
			
		||||
	// https://platform.openai.com/docs/api-reference/introduction
 | 
			
		||||
	modelsRouter := router.Group("/v1/models")
 | 
			
		||||
	modelsRouter.Use(middleware.TokenAuth())
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
 | 
			
		||||
import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
 | 
			
		||||
import { Link } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
 | 
			
		||||
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
 | 
			
		||||
 | 
			
		||||
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
 | 
			
		||||
import { renderGroup, renderNumber } from '../helpers/render';
 | 
			
		||||
@@ -24,7 +24,7 @@ function renderType(type) {
 | 
			
		||||
    }
 | 
			
		||||
    type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
 | 
			
		||||
  }
 | 
			
		||||
  return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
 | 
			
		||||
  return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
function renderBalance(type, balance) {
 | 
			
		||||
@@ -96,7 +96,7 @@ const ChannelsTable = () => {
 | 
			
		||||
      });
 | 
			
		||||
  }, []);
 | 
			
		||||
 | 
			
		||||
  const manageChannel = async (id, action, idx) => {
 | 
			
		||||
  const manageChannel = async (id, action, idx, priority) => {
 | 
			
		||||
    let data = { id };
 | 
			
		||||
    let res;
 | 
			
		||||
    switch (action) {
 | 
			
		||||
@@ -111,6 +111,13 @@ const ChannelsTable = () => {
 | 
			
		||||
        data.status = 2;
 | 
			
		||||
        res = await API.put('/api/channel/', data);
 | 
			
		||||
        break;
 | 
			
		||||
      case 'priority':
 | 
			
		||||
        if (priority === '') {
 | 
			
		||||
          return;
 | 
			
		||||
        }
 | 
			
		||||
        data.priority = parseInt(priority);
 | 
			
		||||
        res = await API.put('/api/channel/', data);
 | 
			
		||||
        break;
 | 
			
		||||
    }
 | 
			
		||||
    const { success, message } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
@@ -195,6 +202,7 @@ const ChannelsTable = () => {
 | 
			
		||||
      showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
 | 
			
		||||
    } else {
 | 
			
		||||
      showError(message);
 | 
			
		||||
      showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。")
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
@@ -334,6 +342,14 @@ const ChannelsTable = () => {
 | 
			
		||||
            >
 | 
			
		||||
              余额
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell
 | 
			
		||||
                style={{ cursor: 'pointer' }}
 | 
			
		||||
                onClick={() => {
 | 
			
		||||
                  sortChannel('priority');
 | 
			
		||||
                }}
 | 
			
		||||
            >
 | 
			
		||||
              优先级
 | 
			
		||||
            </Table.HeaderCell>
 | 
			
		||||
            <Table.HeaderCell>操作</Table.HeaderCell>
 | 
			
		||||
          </Table.Row>
 | 
			
		||||
        </Table.Header>
 | 
			
		||||
@@ -372,6 +388,22 @@ const ChannelsTable = () => {
 | 
			
		||||
                      basic
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <Popup
 | 
			
		||||
                        trigger={<Input type="number"  defaultValue={channel.priority} onBlur={(event) => {
 | 
			
		||||
                          manageChannel(
 | 
			
		||||
                              channel.id,
 | 
			
		||||
                              'priority',
 | 
			
		||||
                              idx,
 | 
			
		||||
                              event.target.value,
 | 
			
		||||
                          );
 | 
			
		||||
                        }}>
 | 
			
		||||
                          <input style={{maxWidth:'60px'}} />
 | 
			
		||||
                        </Input>}
 | 
			
		||||
                        content='渠道选择优先级,越高越优先'
 | 
			
		||||
                        basic
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
                  <Table.Cell>
 | 
			
		||||
                    <div>
 | 
			
		||||
                      <Button
 | 
			
		||||
@@ -440,7 +472,7 @@ const ChannelsTable = () => {
 | 
			
		||||
 | 
			
		||||
        <Table.Footer>
 | 
			
		||||
          <Table.Row>
 | 
			
		||||
            <Table.HeaderCell colSpan='8'>
 | 
			
		||||
            <Table.HeaderCell colSpan='9'>
 | 
			
		||||
              <Button size='small' as={Link} to='/channel/add' loading={loading}>
 | 
			
		||||
                添加新的渠道
 | 
			
		||||
              </Button>
 | 
			
		||||
 
 | 
			
		||||
@@ -13,8 +13,8 @@ const GitHubOAuth = () => {
 | 
			
		||||
 | 
			
		||||
  let navigate = useNavigate();
 | 
			
		||||
 | 
			
		||||
  const sendCode = async (code, count) => {
 | 
			
		||||
    const res = await API.get(`/api/oauth/github?code=${code}`);
 | 
			
		||||
  const sendCode = async (code, state, count) => {
 | 
			
		||||
    const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
 | 
			
		||||
    const { success, message, data } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      if (message === 'bind') {
 | 
			
		||||
@@ -36,13 +36,14 @@ const GitHubOAuth = () => {
 | 
			
		||||
      count++;
 | 
			
		||||
      setPrompt(`出现错误,第 ${count} 次重试中...`);
 | 
			
		||||
      await new Promise((resolve) => setTimeout(resolve, count * 2000));
 | 
			
		||||
      await sendCode(code, count);
 | 
			
		||||
      await sendCode(code, state, count);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    let code = searchParams.get('code');
 | 
			
		||||
    sendCode(code, 0).then();
 | 
			
		||||
    let state = searchParams.get('state');
 | 
			
		||||
    sendCode(code, state, 0).then();
 | 
			
		||||
  }, []);
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f
 | 
			
		||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
 | 
			
		||||
import { UserContext } from '../context/User';
 | 
			
		||||
import { API, getLogo, showError, showSuccess } from '../helpers';
 | 
			
		||||
import { getOAuthState, onGitHubOAuthClicked } from './utils';
 | 
			
		||||
 | 
			
		||||
const LoginForm = () => {
 | 
			
		||||
  const [inputs, setInputs] = useState({
 | 
			
		||||
@@ -31,12 +32,6 @@ const LoginForm = () => {
 | 
			
		||||
 | 
			
		||||
  const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
 | 
			
		||||
 | 
			
		||||
  const onGitHubOAuthClicked = () => {
 | 
			
		||||
    window.open(
 | 
			
		||||
      `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
 | 
			
		||||
    );
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const onWeChatLoginClicked = () => {
 | 
			
		||||
    setShowWeChatLoginModal(true);
 | 
			
		||||
  };
 | 
			
		||||
@@ -131,7 +126,7 @@ const LoginForm = () => {
 | 
			
		||||
                circular
 | 
			
		||||
                color='black'
 | 
			
		||||
                icon='github'
 | 
			
		||||
                onClick={onGitHubOAuthClicked}
 | 
			
		||||
                onClick={()=>onGitHubOAuthClicked(status.github_client_id)}
 | 
			
		||||
              />
 | 
			
		||||
            ) : (
 | 
			
		||||
              <></>
 | 
			
		||||
 
 | 
			
		||||
@@ -56,9 +56,10 @@ const LogsTable = () => {
 | 
			
		||||
    token_name: '',
 | 
			
		||||
    model_name: '',
 | 
			
		||||
    start_timestamp: timestamp2string(0),
 | 
			
		||||
    end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
 | 
			
		||||
    end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
 | 
			
		||||
    channel: ''
 | 
			
		||||
  });
 | 
			
		||||
  const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
 | 
			
		||||
  const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
 | 
			
		||||
 | 
			
		||||
  const [stat, setStat] = useState({
 | 
			
		||||
    quota: 0,
 | 
			
		||||
@@ -84,7 +85,7 @@ const LogsTable = () => {
 | 
			
		||||
  const getLogStat = async () => {
 | 
			
		||||
    let localStartTimestamp = Date.parse(start_timestamp) / 1000;
 | 
			
		||||
    let localEndTimestamp = Date.parse(end_timestamp) / 1000;
 | 
			
		||||
    let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
 | 
			
		||||
    let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
 | 
			
		||||
    const { success, message, data } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      setStat(data);
 | 
			
		||||
@@ -109,7 +110,7 @@ const LogsTable = () => {
 | 
			
		||||
    let localStartTimestamp = Date.parse(start_timestamp) / 1000;
 | 
			
		||||
    let localEndTimestamp = Date.parse(end_timestamp) / 1000;
 | 
			
		||||
    if (isAdminUser) {
 | 
			
		||||
      url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
 | 
			
		||||
      url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
 | 
			
		||||
    } else {
 | 
			
		||||
      url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
 | 
			
		||||
    }
 | 
			
		||||
@@ -205,16 +206,9 @@ const LogsTable = () => {
 | 
			
		||||
        </Header>
 | 
			
		||||
        <Form>
 | 
			
		||||
          <Form.Group>
 | 
			
		||||
            {
 | 
			
		||||
              isAdminUser && (
 | 
			
		||||
                <Form.Input fluid label={'用户名称'} width={2} value={username}
 | 
			
		||||
                            placeholder={'可选值'} name='username'
 | 
			
		||||
                            onChange={handleInputChange} />
 | 
			
		||||
              )
 | 
			
		||||
            }
 | 
			
		||||
            <Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name}
 | 
			
		||||
            <Form.Input fluid label={'令牌名称'} width={3} value={token_name}
 | 
			
		||||
                        placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
 | 
			
		||||
            <Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值'
 | 
			
		||||
            <Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
 | 
			
		||||
                        name='model_name'
 | 
			
		||||
                        onChange={handleInputChange} />
 | 
			
		||||
            <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
 | 
			
		||||
@@ -225,6 +219,19 @@ const LogsTable = () => {
 | 
			
		||||
                        onChange={handleInputChange} />
 | 
			
		||||
            <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          {
 | 
			
		||||
            isAdminUser && <>
 | 
			
		||||
              <Form.Group>
 | 
			
		||||
                <Form.Input fluid label={'渠道 ID'} width={3} value={channel}
 | 
			
		||||
                            placeholder='可选值' name='channel'
 | 
			
		||||
                            onChange={handleInputChange} />
 | 
			
		||||
                <Form.Input fluid label={'用户名称'} width={3} value={username}
 | 
			
		||||
                            placeholder={'可选值'} name='username'
 | 
			
		||||
                            onChange={handleInputChange} />
 | 
			
		||||
 | 
			
		||||
              </Form.Group>
 | 
			
		||||
            </>
 | 
			
		||||
          }
 | 
			
		||||
        </Form>
 | 
			
		||||
        <Table basic compact size='small'>
 | 
			
		||||
          <Table.Header>
 | 
			
		||||
@@ -238,6 +245,17 @@ const LogsTable = () => {
 | 
			
		||||
              >
 | 
			
		||||
                时间
 | 
			
		||||
              </Table.HeaderCell>
 | 
			
		||||
              {
 | 
			
		||||
                isAdminUser && <Table.HeaderCell
 | 
			
		||||
                  style={{ cursor: 'pointer' }}
 | 
			
		||||
                  onClick={() => {
 | 
			
		||||
                    sortLog('channel');
 | 
			
		||||
                  }}
 | 
			
		||||
                  width={1}
 | 
			
		||||
                >
 | 
			
		||||
                  渠道
 | 
			
		||||
                </Table.HeaderCell>
 | 
			
		||||
              }
 | 
			
		||||
              {
 | 
			
		||||
                isAdminUser && <Table.HeaderCell
 | 
			
		||||
                  style={{ cursor: 'pointer' }}
 | 
			
		||||
@@ -299,16 +317,16 @@ const LogsTable = () => {
 | 
			
		||||
                onClick={() => {
 | 
			
		||||
                  sortLog('quota');
 | 
			
		||||
                }}
 | 
			
		||||
                width={2}
 | 
			
		||||
                width={1}
 | 
			
		||||
              >
 | 
			
		||||
                消耗额度
 | 
			
		||||
                额度
 | 
			
		||||
              </Table.HeaderCell>
 | 
			
		||||
              <Table.HeaderCell
 | 
			
		||||
                style={{ cursor: 'pointer' }}
 | 
			
		||||
                onClick={() => {
 | 
			
		||||
                  sortLog('content');
 | 
			
		||||
                }}
 | 
			
		||||
                width={isAdminUser ? 4 : 5}
 | 
			
		||||
                width={isAdminUser ? 4 : 6}
 | 
			
		||||
              >
 | 
			
		||||
                详情
 | 
			
		||||
              </Table.HeaderCell>
 | 
			
		||||
@@ -326,6 +344,11 @@ const LogsTable = () => {
 | 
			
		||||
                return (
 | 
			
		||||
                  <Table.Row key={log.id}>
 | 
			
		||||
                    <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
 | 
			
		||||
                    {
 | 
			
		||||
                      isAdminUser && (
 | 
			
		||||
                        <Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
 | 
			
		||||
                      )
 | 
			
		||||
                    }
 | 
			
		||||
                    {
 | 
			
		||||
                      isAdminUser && (
 | 
			
		||||
                        <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
 | 
			
		||||
@@ -345,7 +368,7 @@ const LogsTable = () => {
 | 
			
		||||
 | 
			
		||||
          <Table.Footer>
 | 
			
		||||
            <Table.Row>
 | 
			
		||||
              <Table.HeaderCell colSpan={'9'}>
 | 
			
		||||
              <Table.HeaderCell colSpan={'10'}>
 | 
			
		||||
                <Select
 | 
			
		||||
                  placeholder='选择明细分类'
 | 
			
		||||
                  options={LOG_OPTIONS}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,9 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
 | 
			
		||||
import { API, showError, verifyJSON } from '../helpers';
 | 
			
		||||
import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers';
 | 
			
		||||
 | 
			
		||||
const OperationSetting = () => {
 | 
			
		||||
  let now = new Date();
 | 
			
		||||
  let [inputs, setInputs] = useState({
 | 
			
		||||
    QuotaForNewUser: 0,
 | 
			
		||||
    QuotaForInviter: 0,
 | 
			
		||||
@@ -20,10 +21,11 @@ const OperationSetting = () => {
 | 
			
		||||
    DisplayInCurrencyEnabled: '',
 | 
			
		||||
    DisplayTokenStatEnabled: '',
 | 
			
		||||
    ApproximateTokenEnabled: '',
 | 
			
		||||
    RetryTimes: 0,
 | 
			
		||||
    RetryTimes: 0
 | 
			
		||||
  });
 | 
			
		||||
  const [originInputs, setOriginInputs] = useState({});
 | 
			
		||||
  let [loading, setLoading] = useState(false);
 | 
			
		||||
  let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
 | 
			
		||||
 | 
			
		||||
  const getOptions = async () => {
 | 
			
		||||
    const res = await API.get('/api/option/');
 | 
			
		||||
@@ -130,6 +132,17 @@ const OperationSetting = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const deleteHistoryLogs = async () => {
 | 
			
		||||
    console.log(inputs);
 | 
			
		||||
    const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
 | 
			
		||||
    const { success, message, data } = res.data;
 | 
			
		||||
    if (success) {
 | 
			
		||||
      showSuccess(`${data} 条日志已清理!`);
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    showError('日志清理失败:' + message);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <Grid columns={1}>
 | 
			
		||||
      <Grid.Column>
 | 
			
		||||
@@ -179,12 +192,6 @@ const OperationSetting = () => {
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Group inline>
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.LogConsumeEnabled === 'true'}
 | 
			
		||||
              label='启用额度消费日志记录'
 | 
			
		||||
              name='LogConsumeEnabled'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.DisplayInCurrencyEnabled === 'true'}
 | 
			
		||||
              label='以货币形式显示额度'
 | 
			
		||||
@@ -208,6 +215,28 @@ const OperationSetting = () => {
 | 
			
		||||
            submitConfig('general').then();
 | 
			
		||||
          }}>保存通用设置</Form.Button>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as='h3'>
 | 
			
		||||
            日志设置
 | 
			
		||||
          </Header>
 | 
			
		||||
          <Form.Group inline>
 | 
			
		||||
            <Form.Checkbox
 | 
			
		||||
              checked={inputs.LogConsumeEnabled === 'true'}
 | 
			
		||||
              label='启用额度消费日志记录'
 | 
			
		||||
              name='LogConsumeEnabled'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Group widths={4}>
 | 
			
		||||
            <Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
 | 
			
		||||
                        name='history_timestamp'
 | 
			
		||||
                        onChange={(e, { name, value }) => {
 | 
			
		||||
                          setHistoryTimestamp(value);
 | 
			
		||||
                        }} />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Button onClick={() => {
 | 
			
		||||
            deleteHistoryLogs().then();
 | 
			
		||||
          }}>清理历史日志</Form.Button>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as='h3'>
 | 
			
		||||
            监控设置
 | 
			
		||||
          </Header>
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import { Link, useNavigate } from 'react-router-dom';
 | 
			
		||||
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
 | 
			
		||||
import Turnstile from 'react-turnstile';
 | 
			
		||||
import { UserContext } from '../context/User';
 | 
			
		||||
import { onGitHubOAuthClicked } from './utils';
 | 
			
		||||
 | 
			
		||||
const PersonalSetting = () => {
 | 
			
		||||
  const [userState, userDispatch] = useContext(UserContext);
 | 
			
		||||
@@ -130,12 +131,6 @@ const PersonalSetting = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const openGitHubOAuth = () => {
 | 
			
		||||
    window.open(
 | 
			
		||||
      `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
 | 
			
		||||
    );
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const sendVerificationCode = async () => {
 | 
			
		||||
    setDisableButton(true);
 | 
			
		||||
    if (inputs.email === '') return;
 | 
			
		||||
@@ -249,7 +244,7 @@ const PersonalSetting = () => {
 | 
			
		||||
      </Modal>
 | 
			
		||||
      {
 | 
			
		||||
        status.github_oauth && (
 | 
			
		||||
          <Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
 | 
			
		||||
          <Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button>
 | 
			
		||||
        )
 | 
			
		||||
      }
 | 
			
		||||
      <Button
 | 
			
		||||
 
 | 
			
		||||
@@ -96,7 +96,7 @@ const TokensTable = () => {
 | 
			
		||||
    let nextUrl;
 | 
			
		||||
  
 | 
			
		||||
    if (nextLink) {
 | 
			
		||||
      nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
 | 
			
		||||
      nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
 | 
			
		||||
    } else {
 | 
			
		||||
      nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
 | 
			
		||||
    }
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										20
									
								
								web/src/components/utils.js
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								web/src/components/utils.js
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
			
		||||
import { API, showError } from '../helpers';
 | 
			
		||||
 | 
			
		||||
export async function getOAuthState() {
 | 
			
		||||
  const res = await API.get('/api/oauth/state');
 | 
			
		||||
  const { success, message, data } = res.data;
 | 
			
		||||
  if (success) {
 | 
			
		||||
    return data;
 | 
			
		||||
  } else {
 | 
			
		||||
    showError(message);
 | 
			
		||||
    return '';
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export async function onGitHubOAuthClicked(github_client_id) {
 | 
			
		||||
  const state = await getOAuthState();
 | 
			
		||||
  if (!state) return;
 | 
			
		||||
  window.open(
 | 
			
		||||
    `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
@@ -174,7 +174,7 @@ const EditChannel = () => {
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
    let localInputs = inputs;
 | 
			
		||||
    if (localInputs.base_url.endsWith('/')) {
 | 
			
		||||
    if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
 | 
			
		||||
      localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
 | 
			
		||||
    }
 | 
			
		||||
    if (localInputs.type === 3 && localInputs.other === '') {
 | 
			
		||||
@@ -183,9 +183,6 @@ const EditChannel = () => {
 | 
			
		||||
    if (localInputs.type === 18 && localInputs.other === '') {
 | 
			
		||||
      localInputs.other = 'v2.1';
 | 
			
		||||
    }
 | 
			
		||||
    if (localInputs.model_mapping === '') {
 | 
			
		||||
      localInputs.model_mapping = '{}';
 | 
			
		||||
    }
 | 
			
		||||
    let res;
 | 
			
		||||
    localInputs.models = localInputs.models.join(',');
 | 
			
		||||
    localInputs.group = localInputs.groups.join(',');
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user