mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 22:03:41 +08:00 
			
		
		
		
	Compare commits
	
		
			28 Commits
		
	
	
		
			v0.6.2-alp
			...
			v0.6.2-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 7cd57f3125 | ||
|  | 66efabd5ae | ||
|  | 8ede66a896 | ||
|  | b169173860 | ||
|  | f33555ae78 | ||
|  | c28ec10795 | ||
|  | e3767cbb07 | ||
|  | be9eb59fbb | ||
|  | 89e111ac69 | ||
|  | 2dcef85285 | ||
|  | 79d0cd378a | ||
|  | e99150bdb9 | ||
|  | a72e5fcc9e | ||
|  | 0710f8cd66 | ||
|  | 49cad7d4a5 | ||
|  | a90161cf00 | ||
|  | a45fc7d736 | ||
|  | 45940dcb12 | ||
|  | 969042b001 | ||
|  | 7e7369dbc4 | ||
|  | e54e647170 | ||
|  | 358920c858 | ||
|  | 1ea598c773 | ||
|  | 796be42487 | ||
|  | 5b50eb94e5 | ||
|  | 71c61365eb | ||
|  | b09f979b80 | ||
|  | 12440874b0 | 
							
								
								
									
										7
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,6 +20,13 @@ jobs: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|  | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi       | ||||
|  | ||||
|       - name: Save version info | ||||
|         run: | | ||||
|           git describe --tags > VERSION  | ||||
|   | ||||
							
								
								
									
										7
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,6 +20,13 @@ jobs: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|  | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi         | ||||
|  | ||||
|       - name: Save version info | ||||
|         run: | | ||||
|           git describe --tags > VERSION  | ||||
|   | ||||
							
								
								
									
										7
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							| @@ -21,6 +21,13 @@ jobs: | ||||
|       - name: Check out the repo | ||||
|         uses: actions/checkout@v3 | ||||
|  | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi | ||||
|  | ||||
|       - name: Save version info | ||||
|         run: | | ||||
|           git describe --tags > VERSION  | ||||
|   | ||||
							
								
								
									
										6
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,6 +20,12 @@ jobs: | ||||
|         uses: actions/checkout@v3 | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi | ||||
|       - uses: actions/setup-node@v3 | ||||
|         with: | ||||
|           node-version: 16 | ||||
|   | ||||
							
								
								
									
										6
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -20,6 +20,12 @@ jobs: | ||||
|         uses: actions/checkout@v3 | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi | ||||
|       - uses: actions/setup-node@v3 | ||||
|         with: | ||||
|           node-version: 16 | ||||
|   | ||||
							
								
								
									
										6
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -23,6 +23,12 @@ jobs: | ||||
|         uses: actions/checkout@v3 | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|       - name: Check repository URL | ||||
|         run: | | ||||
|           REPO_URL=$(git config --get remote.origin.url) | ||||
|           if [[ $REPO_URL == *"pro" ]]; then | ||||
|             exit 1 | ||||
|           fi | ||||
|       - uses: actions/setup-node@v3 | ||||
|         with: | ||||
|           node-version: 16 | ||||
|   | ||||
| @@ -79,6 +79,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) | ||||
|    + [x] [MINIMAX](https://api.minimax.chat/) | ||||
|    + [x] [Groq](https://wow.groq.com/) | ||||
|    + [x] [Ollama](https://github.com/ollama/ollama) | ||||
|    + [x] [零一万物](https://platform.lingyiwanwu.com/) | ||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
| @@ -106,6 +108,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|     + [GitHub 开放授权](https://github.com/settings/applications/new)。 | ||||
|     + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 | ||||
| 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 | ||||
| 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 | ||||
|  | ||||
| ## 部署 | ||||
| ### 基于 Docker 进行部署 | ||||
| @@ -375,6 +378,9 @@ graph LR | ||||
| 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 | ||||
| 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 | ||||
| 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 | ||||
| 19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 | ||||
| 20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 | ||||
| 21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 | ||||
|  | ||||
| ### 命令行参数 | ||||
| 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| package config | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/env" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| @@ -70,17 +70,20 @@ var WeChatServerAddress = "" | ||||
| var WeChatServerToken = "" | ||||
| var WeChatAccountQRCodeImageURL = "" | ||||
|  | ||||
| var MessagePusherAddress = "" | ||||
| var MessagePusherToken = "" | ||||
|  | ||||
| var TurnstileSiteKey = "" | ||||
| var TurnstileSecretKey = "" | ||||
|  | ||||
| var QuotaForNewUser = 0 | ||||
| var QuotaForInviter = 0 | ||||
| var QuotaForInvitee = 0 | ||||
| var QuotaForNewUser int64 = 0 | ||||
| var QuotaForInviter int64 = 0 | ||||
| var QuotaForInvitee int64 = 0 | ||||
| var ChannelDisableThreshold = 5.0 | ||||
| var AutomaticDisableChannelEnabled = false | ||||
| var AutomaticEnableChannelEnabled = false | ||||
| var QuotaRemindThreshold = 1000 | ||||
| var PreConsumedQuota = 500 | ||||
| var QuotaRemindThreshold int64 = 1000 | ||||
| var PreConsumedQuota int64 = 500 | ||||
| var ApproximateTokenEnabled = false | ||||
| var RetryTimes = 0 | ||||
|  | ||||
| @@ -91,16 +94,16 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" | ||||
| var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) | ||||
| var RequestInterval = time.Duration(requestInterval) * time.Second | ||||
|  | ||||
| var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second | ||||
| var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second | ||||
|  | ||||
| var BatchUpdateEnabled = false | ||||
| var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) | ||||
| var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5) | ||||
|  | ||||
| var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second | ||||
| var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second | ||||
|  | ||||
| var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") | ||||
| var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE") | ||||
|  | ||||
| var Theme = helper.GetOrDefaultEnvString("THEME", "default") | ||||
| var Theme = env.String("THEME", "default") | ||||
| var ValidThemes = map[string]bool{ | ||||
| 	"default": true, | ||||
| 	"berry":   true, | ||||
| @@ -109,10 +112,10 @@ var ValidThemes = map[string]bool{ | ||||
| // All duration's unit is seconds | ||||
| // Shouldn't larger then RateLimitKeyExpirationDuration | ||||
| var ( | ||||
| 	GlobalApiRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) | ||||
| 	GlobalApiRateLimitNum            = env.Int("GLOBAL_API_RATE_LIMIT", 180) | ||||
| 	GlobalApiRateLimitDuration int64 = 3 * 60 | ||||
|  | ||||
| 	GlobalWebRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) | ||||
| 	GlobalWebRateLimitNum            = env.Int("GLOBAL_WEB_RATE_LIMIT", 60) | ||||
| 	GlobalWebRateLimitDuration int64 = 3 * 60 | ||||
|  | ||||
| 	UploadRateLimitNum            = 10 | ||||
| @@ -126,3 +129,9 @@ var ( | ||||
| ) | ||||
|  | ||||
| var RateLimitKeyExpirationDuration = 20 * time.Minute | ||||
|  | ||||
| var EnableMetric = env.Bool("ENABLE_METRIC", false) | ||||
| var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) | ||||
| var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) | ||||
| var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024) | ||||
| var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) | ||||
|   | ||||
| @@ -69,6 +69,8 @@ const ( | ||||
| 	ChannelTypeMinimax | ||||
| 	ChannelTypeMistral | ||||
| 	ChannelTypeGroq | ||||
| 	ChannelTypeOllama | ||||
| 	ChannelTypeLingYiWanWu | ||||
|  | ||||
| 	ChannelTypeDummy | ||||
| ) | ||||
| @@ -104,6 +106,8 @@ var ChannelBaseURLs = []string{ | ||||
| 	"https://api.minimax.chat",                  // 27 | ||||
| 	"https://api.mistral.ai",                    // 28 | ||||
| 	"https://api.groq.com/openai",               // 29 | ||||
| 	"http://localhost:11434",                    // 30 | ||||
| 	"https://api.lingyiwanwu.com",               // 31 | ||||
| } | ||||
|  | ||||
| const ( | ||||
|   | ||||
| @@ -1,9 +1,12 @@ | ||||
| package common | ||||
|  | ||||
| import "github.com/songquanpeng/one-api/common/helper" | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/env" | ||||
| ) | ||||
|  | ||||
| var UsingSQLite = false | ||||
| var UsingPostgreSQL = false | ||||
| var UsingMySQL = false | ||||
|  | ||||
| var SQLitePath = "one-api.db" | ||||
| var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) | ||||
| var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000) | ||||
|   | ||||
							
								
								
									
										42
									
								
								common/env/helper.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								common/env/helper.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | ||||
| package env | ||||
|  | ||||
| import ( | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| ) | ||||
|  | ||||
| func Bool(env string, defaultValue bool) bool { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return os.Getenv(env) == "true" | ||||
| } | ||||
|  | ||||
| func Int(env string, defaultValue int) int { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	num, err := strconv.Atoi(os.Getenv(env)) | ||||
| 	if err != nil { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|  | ||||
| func Float64(env string, defaultValue float64) float64 { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	num, err := strconv.ParseFloat(os.Getenv(env), 64) | ||||
| 	if err != nil { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|  | ||||
| func String(env string, defaultValue string) string { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return os.Getenv(env) | ||||
| } | ||||
| @@ -3,12 +3,10 @@ package helper | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"html/template" | ||||
| 	"log" | ||||
| 	"math/rand" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"os/exec" | ||||
| 	"runtime" | ||||
| 	"strconv" | ||||
| @@ -187,6 +185,10 @@ func GetTimeString() string { | ||||
| 	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) | ||||
| } | ||||
|  | ||||
| func GenRequestID() string { | ||||
| 	return GetTimeString() + GetRandomNumberString(8) | ||||
| } | ||||
|  | ||||
| func Max(a int, b int) int { | ||||
| 	if a >= b { | ||||
| 		return a | ||||
| @@ -195,25 +197,6 @@ func Max(a int, b int) int { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetOrDefaultEnvInt(env string, defaultValue int) int { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	num, err := strconv.Atoi(os.Getenv(env)) | ||||
| 	if err != nil { | ||||
| 		logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return num | ||||
| } | ||||
|  | ||||
| func GetOrDefaultEnvString(env string, defaultValue string) string { | ||||
| 	if env == "" || os.Getenv(env) == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return os.Getenv(env) | ||||
| } | ||||
|  | ||||
| func AssignOrDefault(value string, defaultValue string) string { | ||||
| 	if len(value) != 0 { | ||||
| 		return value | ||||
|   | ||||
| @@ -4,6 +4,8 @@ import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"os" | ||||
| @@ -19,9 +21,6 @@ const ( | ||||
| 	loggerError = "ERR" | ||||
| ) | ||||
|  | ||||
| const maxLogCount = 1000000 | ||||
|  | ||||
| var logCount int | ||||
| var setupLogLock sync.Mutex | ||||
| var setupLogWorking bool | ||||
|  | ||||
| @@ -57,7 +56,9 @@ func SysError(s string) { | ||||
| } | ||||
|  | ||||
| func Debug(ctx context.Context, msg string) { | ||||
| 	logHelper(ctx, loggerDEBUG, msg) | ||||
| 	if config.DebugEnabled { | ||||
| 		logHelper(ctx, loggerDEBUG, msg) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Info(ctx context.Context, msg string) { | ||||
| @@ -94,11 +95,12 @@ func logHelper(ctx context.Context, level string, msg string) { | ||||
| 		writer = gin.DefaultWriter | ||||
| 	} | ||||
| 	id := ctx.Value(RequestIdKey) | ||||
| 	if id == nil { | ||||
| 		id = helper.GenRequestID() | ||||
| 	} | ||||
| 	now := time.Now() | ||||
| 	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) | ||||
| 	logCount++ // we don't need accurate count, so no lock here | ||||
| 	if logCount > maxLogCount && !setupLogWorking { | ||||
| 		logCount = 0 | ||||
| 	if !setupLogWorking { | ||||
| 		setupLogWorking = true | ||||
| 		go func() { | ||||
| 			SetupLogger() | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| package common | ||||
| package message | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| @@ -12,6 +12,9 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| func SendEmail(subject string, receiver string, content string) error { | ||||
| 	if receiver == "" { | ||||
| 		return fmt.Errorf("receiver is empty") | ||||
| 	} | ||||
| 	if config.SMTPFrom == "" { // for compatibility | ||||
| 		config.SMTPFrom = config.SMTPAccount | ||||
| 	} | ||||
							
								
								
									
										22
									
								
								common/message/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								common/message/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package message | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	ByAll           = "all" | ||||
| 	ByEmail         = "email" | ||||
| 	ByMessagePusher = "message_pusher" | ||||
| ) | ||||
|  | ||||
| func Notify(by string, title string, description string, content string) error { | ||||
| 	if by == ByEmail { | ||||
| 		return SendEmail(title, config.RootUserEmail, content) | ||||
| 	} | ||||
| 	if by == ByMessagePusher { | ||||
| 		return SendMessage(title, description, content) | ||||
| 	} | ||||
| 	return fmt.Errorf("unknown notify method: %s", by) | ||||
| } | ||||
							
								
								
									
										53
									
								
								common/message/message-pusher.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								common/message/message-pusher.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | ||||
| package message | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| type request struct { | ||||
| 	Title       string `json:"title"` | ||||
| 	Description string `json:"description"` | ||||
| 	Content     string `json:"content"` | ||||
| 	URL         string `json:"url"` | ||||
| 	Channel     string `json:"channel"` | ||||
| 	Token       string `json:"token"` | ||||
| } | ||||
|  | ||||
| type response struct { | ||||
| 	Success bool   `json:"success"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| func SendMessage(title string, description string, content string) error { | ||||
| 	if config.MessagePusherAddress == "" { | ||||
| 		return errors.New("message pusher address is not set") | ||||
| 	} | ||||
| 	req := request{ | ||||
| 		Title:       title, | ||||
| 		Description: description, | ||||
| 		Content:     content, | ||||
| 		Token:       config.MessagePusherToken, | ||||
| 	} | ||||
| 	data, err := json.Marshal(req) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	resp, err := http.Post(config.MessagePusherAddress, | ||||
| 		"application/json", bytes.NewBuffer(data)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	var res response | ||||
| 	err = json.NewDecoder(resp.Body).Decode(&res) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if !res.Success { | ||||
| 		return errors.New(res.Message) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| @@ -69,7 +69,7 @@ var ModelRatio = map[string]float64{ | ||||
| 	"claude-instant-1.2":       0.8 / 1000 * USD, | ||||
| 	"claude-2.0":               8.0 / 1000 * USD, | ||||
| 	"claude-2.1":               8.0 / 1000 * USD, | ||||
| 	"claude-3-haiku-20240229":  0.25 / 1000 * USD, | ||||
| 	"claude-3-haiku-20240307":  0.25 / 1000 * USD, | ||||
| 	"claude-3-sonnet-20240229": 3.0 / 1000 * USD, | ||||
| 	"claude-3-opus-20240229":   15.0 / 1000 * USD, | ||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 | ||||
| @@ -78,6 +78,9 @@ var ModelRatio = map[string]float64{ | ||||
| 	"ERNIE-Bot-4":       0.12 * RMB, // ¥0.12 / 1k tokens | ||||
| 	"ERNIE-Bot-8k":      0.024 * RMB, | ||||
| 	"Embedding-V1":      0.1429, // ¥0.002 / 1k tokens | ||||
| 	"bge-large-zh":      0.002 * RMB, | ||||
| 	"bge-large-en":      0.002 * RMB, | ||||
| 	"bge-large-8k":      0.002 * RMB, | ||||
| 	"PaLM-2":            1, | ||||
| 	"gemini-pro":        1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||
| 	"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens | ||||
| @@ -130,6 +133,10 @@ var ModelRatio = map[string]float64{ | ||||
| 	"llama2-7b-2048":     0.1 / 1000 * USD, | ||||
| 	"mixtral-8x7b-32768": 0.27 / 1000 * USD, | ||||
| 	"gemma-7b-it":        0.1 / 1000 * USD, | ||||
| 	// https://platform.lingyiwanwu.com/docs#-计费单元 | ||||
| 	"yi-34b-chat-0205": 2.5 / 1000000 * RMB, | ||||
| 	"yi-34b-chat-200k": 12.0 / 1000000 * RMB, | ||||
| 	"yi-vl-plus":       6.0 / 1000000 * RMB, | ||||
| } | ||||
|  | ||||
| var CompletionRatio = map[string]float64{} | ||||
| @@ -148,6 +155,26 @@ func init() { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func AddNewMissingRatio(oldRatio string) string { | ||||
| 	newRatio := make(map[string]float64) | ||||
| 	err := json.Unmarshal([]byte(oldRatio), &newRatio) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("error unmarshalling old ratio: " + err.Error()) | ||||
| 		return oldRatio | ||||
| 	} | ||||
| 	for k, v := range DefaultModelRatio { | ||||
| 		if _, ok := newRatio[k]; !ok { | ||||
| 			newRatio[k] = v | ||||
| 		} | ||||
| 	} | ||||
| 	jsonBytes, err := json.Marshal(newRatio) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("error marshalling new ratio: " + err.Error()) | ||||
| 		return oldRatio | ||||
| 	} | ||||
| 	return string(jsonBytes) | ||||
| } | ||||
|  | ||||
| func ModelRatio2JSONString() string { | ||||
| 	jsonBytes, err := json.Marshal(ModelRatio) | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -5,7 +5,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| ) | ||||
|  | ||||
| func LogQuota(quota int) string { | ||||
| func LogQuota(quota int64) string { | ||||
| 	if config.DisplayInCurrencyEnabled { | ||||
| 		return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) | ||||
| 	} else { | ||||
|   | ||||
| @@ -8,8 +8,8 @@ import ( | ||||
| ) | ||||
|  | ||||
| func GetSubscription(c *gin.Context) { | ||||
| 	var remainQuota int | ||||
| 	var usedQuota int | ||||
| 	var remainQuota int64 | ||||
| 	var usedQuota int64 | ||||
| 	var err error | ||||
| 	var token *model.Token | ||||
| 	var expiredTime int64 | ||||
| @@ -60,7 +60,7 @@ func GetSubscription(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func GetUsage(c *gin.Context) { | ||||
| 	var quota int | ||||
| 	var quota int64 | ||||
| 	var err error | ||||
| 	var token *model.Token | ||||
| 	if config.DisplayTokenStatEnabled { | ||||
|   | ||||
| @@ -8,6 +8,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/monitor" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| @@ -295,7 +296,7 @@ func UpdateChannelBalance(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func updateAllChannelsBalance() error { | ||||
| 	channels, err := model.GetAllChannels(0, 0, true) | ||||
| 	channels, err := model.GetAllChannels(0, 0, "all") | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -313,7 +314,7 @@ func updateAllChannelsBalance() error { | ||||
| 		} else { | ||||
| 			// err is nil & balance <= 0 means quota is used up | ||||
| 			if balance <= 0 { | ||||
| 				disableChannel(channel.Id, channel.Name, "余额不足") | ||||
| 				monitor.DisableChannel(channel.Id, channel.Name, "余额不足") | ||||
| 			} | ||||
| 		} | ||||
| 		time.Sleep(config.RequestInterval) | ||||
| @@ -322,15 +323,14 @@ func updateAllChannelsBalance() error { | ||||
| } | ||||
|  | ||||
| func UpdateAllChannelsBalance(c *gin.Context) { | ||||
| 	// TODO: make it async | ||||
| 	err := updateAllChannelsBalance() | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| 			"message": err.Error(), | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	//err := updateAllChannelsBalance() | ||||
| 	//if err != nil { | ||||
| 	//	c.JSON(http.StatusOK, gin.H{ | ||||
| 	//		"success": false, | ||||
| 	//		"message": err.Error(), | ||||
| 	//	}) | ||||
| 	//	return | ||||
| 	//} | ||||
| 	c.JSON(http.StatusOK, gin.H{ | ||||
| 		"success": true, | ||||
| 		"message": "", | ||||
|   | ||||
| @@ -8,8 +8,10 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/message" | ||||
| 	"github.com/songquanpeng/one-api/middleware" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/monitor" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/helper" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| @@ -28,7 +30,7 @@ import ( | ||||
|  | ||||
| func buildTestRequest() *relaymodel.GeneralOpenAIRequest { | ||||
| 	testRequest := &relaymodel.GeneralOpenAIRequest{ | ||||
| 		MaxTokens: 1, | ||||
| 		MaxTokens: 2, | ||||
| 		Stream:    false, | ||||
| 		Model:     "gpt-3.5-turbo", | ||||
| 	} | ||||
| @@ -148,33 +150,7 @@ func TestChannel(c *gin.Context) { | ||||
| var testAllChannelsLock sync.Mutex | ||||
| var testAllChannelsRunning bool = false | ||||
|  | ||||
| func notifyRootUser(subject string, content string) { | ||||
| 	if config.RootUserEmail == "" { | ||||
| 		config.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| 	err := common.SendEmail(subject, config.RootUserEmail, content) | ||||
| 	if err != nil { | ||||
| 		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // disable & notify | ||||
| func disableChannel(channelId int, channelName string, reason string) { | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| // enable & notify | ||||
| func enableChannel(channelId int, channelName string) { | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| func testAllChannels(notify bool) error { | ||||
| func testChannels(notify bool, scope string) error { | ||||
| 	if config.RootUserEmail == "" { | ||||
| 		config.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| @@ -185,7 +161,7 @@ func testAllChannels(notify bool) error { | ||||
| 	} | ||||
| 	testAllChannelsRunning = true | ||||
| 	testAllChannelsLock.Unlock() | ||||
| 	channels, err := model.GetAllChannels(0, 0, true) | ||||
| 	channels, err := model.GetAllChannels(0, 0, scope) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -202,13 +178,17 @@ func testAllChannels(notify bool) error { | ||||
| 			milliseconds := tok.Sub(tik).Milliseconds() | ||||
| 			if isChannelEnabled && milliseconds > disableThreshold { | ||||
| 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 				if config.AutomaticDisableChannelEnabled { | ||||
| 					monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 				} else { | ||||
| 					_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error()) | ||||
| 				} | ||||
| 			} | ||||
| 			if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { | ||||
| 				disableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 				monitor.DisableChannel(channel.Id, channel.Name, err.Error()) | ||||
| 			} | ||||
| 			if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { | ||||
| 				enableChannel(channel.Id, channel.Name) | ||||
| 				monitor.EnableChannel(channel.Id, channel.Name) | ||||
| 			} | ||||
| 			channel.UpdateResponseTime(milliseconds) | ||||
| 			time.Sleep(config.RequestInterval) | ||||
| @@ -217,7 +197,7 @@ func testAllChannels(notify bool) error { | ||||
| 		testAllChannelsRunning = false | ||||
| 		testAllChannelsLock.Unlock() | ||||
| 		if notify { | ||||
| 			err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") | ||||
| 			err := message.Notify(message.ByAll, "通道测试完成", "", "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") | ||||
| 			if err != nil { | ||||
| 				logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 			} | ||||
| @@ -226,8 +206,12 @@ func testAllChannels(notify bool) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func TestAllChannels(c *gin.Context) { | ||||
| 	err := testAllChannels(true) | ||||
| func TestChannels(c *gin.Context) { | ||||
| 	scope := c.Query("scope") | ||||
| 	if scope == "" { | ||||
| 		scope = "all" | ||||
| 	} | ||||
| 	err := testChannels(true, scope) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -246,7 +230,7 @@ func AutomaticallyTestChannels(frequency int) { | ||||
| 	for { | ||||
| 		time.Sleep(time.Duration(frequency) * time.Minute) | ||||
| 		logger.SysLog("testing all channels") | ||||
| 		_ = testAllChannels(false) | ||||
| 		_ = testChannels(false, "all") | ||||
| 		logger.SysLog("channel test finished") | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -15,7 +15,7 @@ func GetAllChannels(c *gin.Context) { | ||||
| 	if p < 0 { | ||||
| 		p = 0 | ||||
| 	} | ||||
| 	channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) | ||||
| 	channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited") | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/message" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| @@ -110,7 +111,7 @@ func SendEmailVerification(c *gin.Context) { | ||||
| 	content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ | ||||
| 		"<p>您的验证码为: <strong>%s</strong></p>"+ | ||||
| 		"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes) | ||||
| 	err := common.SendEmail(subject, email, content) | ||||
| 	err := message.SendEmail(subject, email, content) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -149,7 +150,7 @@ func SendPasswordResetEmail(c *gin.Context) { | ||||
| 		"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ | ||||
| 		"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ | ||||
| 		"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes) | ||||
| 	err := common.SendEmail(subject, email, content) | ||||
| 	err := message.SendEmail(subject, email, content) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
|   | ||||
| @@ -11,6 +11,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/middleware" | ||||
| 	dbmodel "github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/monitor" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/controller" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| @@ -45,11 +46,12 @@ func Relay(c *gin.Context) { | ||||
| 		requestBody, _ := common.GetRequestBody(c) | ||||
| 		logger.Debugf(ctx, "request body: %s", string(requestBody)) | ||||
| 	} | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	bizErr := relay(c, relayMode) | ||||
| 	if bizErr == nil { | ||||
| 		monitor.Emit(channelId, true) | ||||
| 		return | ||||
| 	} | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	lastFailedChannelId := channelId | ||||
| 	channelName := c.GetString("channel_name") | ||||
| 	group := c.GetString("group") | ||||
| @@ -117,7 +119,9 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st | ||||
| 	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) | ||||
| 	// https://platform.openai.com/docs/guides/error-codes/api-errors | ||||
| 	if util.ShouldDisableChannel(&err.Error, err.StatusCode) { | ||||
| 		disableChannel(channelId, channelName, err.Message) | ||||
| 		monitor.DisableChannel(channelId, channelName, err.Message) | ||||
| 	} else { | ||||
| 		monitor.Emit(channelId, false) | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
							
								
								
									
										10
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								main.go
									
									
									
									
									
								
							| @@ -64,13 +64,6 @@ func main() { | ||||
| 		go model.SyncOptions(config.SyncFrequency) | ||||
| 		go model.SyncChannelCache(config.SyncFrequency) | ||||
| 	} | ||||
| 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { | ||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) | ||||
| 		if err != nil { | ||||
| 			logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) | ||||
| 		} | ||||
| 		go controller.AutomaticallyUpdateChannels(frequency) | ||||
| 	} | ||||
| 	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { | ||||
| 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) | ||||
| 		if err != nil { | ||||
| @@ -83,6 +76,9 @@ func main() { | ||||
| 		logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") | ||||
| 		model.InitBatchUpdater() | ||||
| 	} | ||||
| 	if config.EnableMetric { | ||||
| 		logger.SysLog("metric enabled, will disable channel if too much request failed") | ||||
| 	} | ||||
| 	openai.InitTokenEncoders() | ||||
|  | ||||
| 	// Initialize HTTP server | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package middleware | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"net/http" | ||||
| 	"runtime/debug" | ||||
| @@ -12,11 +13,15 @@ func RelayPanicRecover() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		defer func() { | ||||
| 			if err := recover(); err != nil { | ||||
| 				logger.SysError(fmt.Sprintf("panic detected: %v", err)) | ||||
| 				logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) | ||||
| 				ctx := c.Request.Context() | ||||
| 				logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err)) | ||||
| 				logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) | ||||
| 				logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path)) | ||||
| 				body, _ := common.GetRequestBody(c) | ||||
| 				logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body))) | ||||
| 				c.JSON(http.StatusInternalServerError, gin.H{ | ||||
| 					"error": gin.H{ | ||||
| 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), | ||||
| 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err), | ||||
| 						"type":    "one_api_panic", | ||||
| 					}, | ||||
| 				}) | ||||
|   | ||||
| @@ -9,7 +9,7 @@ import ( | ||||
|  | ||||
| func RequestId() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		id := helper.GetTimeString() + helper.GetRandomNumberString(8) | ||||
| 		id := helper.GenRequestID() | ||||
| 		c.Set(logger.RequestIdKey, id) | ||||
| 		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) | ||||
| 		c.Request = c.Request.WithContext(ctx) | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| @@ -70,31 +71,42 @@ func CacheGetUserGroup(id int) (group string, err error) { | ||||
| 	return group, err | ||||
| } | ||||
|  | ||||
| func CacheGetUserQuota(id int) (quota int, err error) { | ||||
| func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) { | ||||
| 	quota, err = GetUserQuota(id) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "Redis set user quota error: "+err.Error()) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return GetUserQuota(id) | ||||
| 	} | ||||
| 	quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) | ||||
| 	if err != nil { | ||||
| 		quota, err = GetUserQuota(id) | ||||
| 		if err != nil { | ||||
| 			return 0, err | ||||
| 		} | ||||
| 		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("Redis set user quota error: " + err.Error()) | ||||
| 		} | ||||
| 		return quota, err | ||||
| 		return fetchAndUpdateUserQuota(ctx, id) | ||||
| 	} | ||||
| 	quota, err = strconv.Atoi(quotaString) | ||||
| 	return quota, err | ||||
| 	quota, err = strconv.ParseInt(quotaString, 10, 64) | ||||
| 	if err != nil { | ||||
| 		return 0, nil | ||||
| 	} | ||||
| 	if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db | ||||
| 		logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id) | ||||
| 		return fetchAndUpdateUserQuota(ctx, id) | ||||
| 	} | ||||
| 	return quota, nil | ||||
| } | ||||
|  | ||||
| func CacheUpdateUserQuota(id int) error { | ||||
| func CacheUpdateUserQuota(ctx context.Context, id int) error { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return nil | ||||
| 	} | ||||
| 	quota, err := CacheGetUserQuota(id) | ||||
| 	quota, err := CacheGetUserQuota(ctx, id) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -102,7 +114,7 @@ func CacheUpdateUserQuota(id int) error { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func CacheDecreaseUserQuota(id int, quota int) error { | ||||
| func CacheDecreaseUserQuota(id int, quota int64) error { | ||||
| 	if !common.RedisEnabled { | ||||
| 		return nil | ||||
| 	} | ||||
|   | ||||
| @@ -13,7 +13,7 @@ import ( | ||||
| type Channel struct { | ||||
| 	Id                 int     `json:"id"` | ||||
| 	Type               int     `json:"type" gorm:"default:0"` | ||||
| 	Key                string  `json:"key" gorm:"not null;index"` | ||||
| 	Key                string  `json:"key" gorm:"type:text"` | ||||
| 	Status             int     `json:"status" gorm:"default:1"` | ||||
| 	Name               string  `json:"name" gorm:"index"` | ||||
| 	Weight             *uint   `json:"weight" gorm:"default:0"` | ||||
| @@ -32,23 +32,22 @@ type Channel struct { | ||||
| 	Config             string  `json:"config"` | ||||
| } | ||||
|  | ||||
| func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||
| func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { | ||||
| 	var channels []*Channel | ||||
| 	var err error | ||||
| 	if selectAll { | ||||
| 	switch scope { | ||||
| 	case "all": | ||||
| 		err = DB.Order("id desc").Find(&channels).Error | ||||
| 	} else { | ||||
| 	case "disabled": | ||||
| 		err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error | ||||
| 	default: | ||||
| 		err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error | ||||
| 	} | ||||
| 	return channels, err | ||||
| } | ||||
|  | ||||
| func SearchChannels(keyword string) (channels []*Channel, err error) { | ||||
| 	keyCol := "`key`" | ||||
| 	if common.UsingPostgreSQL { | ||||
| 		keyCol = `"key"` | ||||
| 	} | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error | ||||
| 	err = DB.Omit("key").Where("id = ? or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error | ||||
| 	return channels, err | ||||
| } | ||||
|  | ||||
| @@ -179,7 +178,7 @@ func UpdateChannelStatusById(id int, status int) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func UpdateChannelUsedQuota(id int, quota int) { | ||||
| func UpdateChannelUsedQuota(id int, quota int64) { | ||||
| 	if config.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | ||||
| 		return | ||||
| @@ -187,7 +186,7 @@ func UpdateChannelUsedQuota(id int, quota int) { | ||||
| 	updateChannelUsedQuota(id, quota) | ||||
| } | ||||
|  | ||||
| func updateChannelUsedQuota(id int, quota int) { | ||||
| func updateChannelUsedQuota(id int, quota int64) { | ||||
| 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | ||||
| 	if err != nil { | ||||
| 		logger.SysError("failed to update channel used quota: " + err.Error()) | ||||
|   | ||||
| @@ -51,7 +51,7 @@ func RecordLog(userId int, logType int, content string) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { | ||||
| func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { | ||||
| 	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) | ||||
| 	if !config.LogConsumeEnabled { | ||||
| 		return | ||||
| @@ -66,7 +66,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke | ||||
| 		CompletionTokens: completionTokens, | ||||
| 		TokenName:        tokenName, | ||||
| 		ModelName:        modelName, | ||||
| 		Quota:            quota, | ||||
| 		Quota:            int(quota), | ||||
| 		ChannelId:        channelId, | ||||
| 	} | ||||
| 	err := DB.Create(log).Error | ||||
| @@ -137,7 +137,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { | ||||
| 	return logs, err | ||||
| } | ||||
|  | ||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { | ||||
| func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { | ||||
| 	tx := DB.Table("logs").Select("ifnull(sum(quota),0)") | ||||
| 	if username != "" { | ||||
| 		tx = tx.Where("username = ?", username) | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/env" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"gorm.io/driver/mysql" | ||||
| @@ -56,6 +57,7 @@ func chooseDB() (*gorm.DB, error) { | ||||
| 		} | ||||
| 		// Use MySQL | ||||
| 		logger.SysLog("using MySQL as database") | ||||
| 		common.UsingMySQL = true | ||||
| 		return gorm.Open(mysql.Open(dsn), &gorm.Config{ | ||||
| 			PrepareStmt: true, // precompile SQL | ||||
| 		}) | ||||
| @@ -80,13 +82,16 @@ func InitDB() (err error) { | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) | ||||
| 		sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) | ||||
| 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) | ||||
| 		sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) | ||||
| 		sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) | ||||
| 		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) | ||||
|  | ||||
| 		if !config.IsMasterNode { | ||||
| 			return nil | ||||
| 		} | ||||
| 		if common.UsingMySQL { | ||||
| 			_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded | ||||
| 		} | ||||
| 		logger.SysLog("database migration started") | ||||
| 		err = db.AutoMigrate(&Channel{}) | ||||
| 		if err != nil { | ||||
|   | ||||
| @@ -57,13 +57,15 @@ func InitOptionMap() { | ||||
| 	config.OptionMap["WeChatServerAddress"] = "" | ||||
| 	config.OptionMap["WeChatServerToken"] = "" | ||||
| 	config.OptionMap["WeChatAccountQRCodeImageURL"] = "" | ||||
| 	config.OptionMap["MessagePusherAddress"] = "" | ||||
| 	config.OptionMap["MessagePusherToken"] = "" | ||||
| 	config.OptionMap["TurnstileSiteKey"] = "" | ||||
| 	config.OptionMap["TurnstileSecretKey"] = "" | ||||
| 	config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) | ||||
| 	config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) | ||||
| 	config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) | ||||
| 	config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) | ||||
| 	config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) | ||||
| 	config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10) | ||||
| 	config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10) | ||||
| 	config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) | ||||
| 	config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) | ||||
| 	config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) | ||||
| 	config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() | ||||
| 	config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() | ||||
| 	config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() | ||||
| @@ -79,6 +81,9 @@ func InitOptionMap() { | ||||
| func loadOptionsFromDatabase() { | ||||
| 	options, _ := AllOption() | ||||
| 	for _, option := range options { | ||||
| 		if option.Key == "ModelRatio" { | ||||
| 			option.Value = common.AddNewMissingRatio(option.Value) | ||||
| 		} | ||||
| 		err := updateOptionMap(option.Key, option.Value) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("failed to update option map: " + err.Error()) | ||||
| @@ -179,20 +184,24 @@ func updateOptionMap(key string, value string) (err error) { | ||||
| 		config.WeChatServerToken = value | ||||
| 	case "WeChatAccountQRCodeImageURL": | ||||
| 		config.WeChatAccountQRCodeImageURL = value | ||||
| 	case "MessagePusherAddress": | ||||
| 		config.MessagePusherAddress = value | ||||
| 	case "MessagePusherToken": | ||||
| 		config.MessagePusherToken = value | ||||
| 	case "TurnstileSiteKey": | ||||
| 		config.TurnstileSiteKey = value | ||||
| 	case "TurnstileSecretKey": | ||||
| 		config.TurnstileSecretKey = value | ||||
| 	case "QuotaForNewUser": | ||||
| 		config.QuotaForNewUser, _ = strconv.Atoi(value) | ||||
| 		config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64) | ||||
| 	case "QuotaForInviter": | ||||
| 		config.QuotaForInviter, _ = strconv.Atoi(value) | ||||
| 		config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64) | ||||
| 	case "QuotaForInvitee": | ||||
| 		config.QuotaForInvitee, _ = strconv.Atoi(value) | ||||
| 		config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64) | ||||
| 	case "QuotaRemindThreshold": | ||||
| 		config.QuotaRemindThreshold, _ = strconv.Atoi(value) | ||||
| 		config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64) | ||||
| 	case "PreConsumedQuota": | ||||
| 		config.PreConsumedQuota, _ = strconv.Atoi(value) | ||||
| 		config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64) | ||||
| 	case "RetryTimes": | ||||
| 		config.RetryTimes, _ = strconv.Atoi(value) | ||||
| 	case "ModelRatio": | ||||
|   | ||||
| @@ -14,7 +14,7 @@ type Redemption struct { | ||||
| 	Key          string `json:"key" gorm:"type:char(32);uniqueIndex"` | ||||
| 	Status       int    `json:"status" gorm:"default:1"` | ||||
| 	Name         string `json:"name" gorm:"index"` | ||||
| 	Quota        int    `json:"quota" gorm:"default:100"` | ||||
| 	Quota        int64  `json:"quota" gorm:"default:100"` | ||||
| 	CreatedTime  int64  `json:"created_time" gorm:"bigint"` | ||||
| 	RedeemedTime int64  `json:"redeemed_time" gorm:"bigint"` | ||||
| 	Count        int    `json:"count" gorm:"-:all"` // only for api request | ||||
| @@ -42,7 +42,7 @@ func GetRedemptionById(id int) (*Redemption, error) { | ||||
| 	return &redemption, err | ||||
| } | ||||
|  | ||||
| func Redeem(key string, userId int) (quota int, err error) { | ||||
| func Redeem(key string, userId int) (quota int64, err error) { | ||||
| 	if key == "" { | ||||
| 		return 0, errors.New("未提供兑换码") | ||||
| 	} | ||||
|   | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/message" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| @@ -19,9 +20,9 @@ type Token struct { | ||||
| 	CreatedTime    int64  `json:"created_time" gorm:"bigint"` | ||||
| 	AccessedTime   int64  `json:"accessed_time" gorm:"bigint"` | ||||
| 	ExpiredTime    int64  `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired | ||||
| 	RemainQuota    int    `json:"remain_quota" gorm:"default:0"` | ||||
| 	RemainQuota    int64  `json:"remain_quota" gorm:"default:0"` | ||||
| 	UnlimitedQuota bool   `json:"unlimited_quota" gorm:"default:false"` | ||||
| 	UsedQuota      int    `json:"used_quota" gorm:"default:0"` // used quota | ||||
| 	UsedQuota      int64  `json:"used_quota" gorm:"default:0"` // used quota | ||||
| } | ||||
|  | ||||
| func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { | ||||
| @@ -137,7 +138,7 @@ func DeleteTokenById(id int, userId int) (err error) { | ||||
| 	return token.Delete() | ||||
| } | ||||
|  | ||||
| func IncreaseTokenQuota(id int, quota int) (err error) { | ||||
| func IncreaseTokenQuota(id int, quota int64) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| @@ -148,7 +149,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { | ||||
| 	return increaseTokenQuota(id, quota) | ||||
| } | ||||
|  | ||||
| func increaseTokenQuota(id int, quota int) (err error) { | ||||
| func increaseTokenQuota(id int, quota int64) (err error) { | ||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||
| 		map[string]interface{}{ | ||||
| 			"remain_quota":  gorm.Expr("remain_quota + ?", quota), | ||||
| @@ -159,7 +160,7 @@ func increaseTokenQuota(id int, quota int) (err error) { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func DecreaseTokenQuota(id int, quota int) (err error) { | ||||
| func DecreaseTokenQuota(id int, quota int64) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| @@ -170,7 +171,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { | ||||
| 	return decreaseTokenQuota(id, quota) | ||||
| } | ||||
|  | ||||
| func decreaseTokenQuota(id int, quota int) (err error) { | ||||
| func decreaseTokenQuota(id int, quota int64) (err error) { | ||||
| 	err = DB.Model(&Token{}).Where("id = ?", id).Updates( | ||||
| 		map[string]interface{}{ | ||||
| 			"remain_quota":  gorm.Expr("remain_quota - ?", quota), | ||||
| @@ -181,7 +182,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | ||||
| func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| @@ -213,7 +214,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | ||||
| 			} | ||||
| 			if email != "" { | ||||
| 				topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) | ||||
| 				err = common.SendEmail(prompt, email, | ||||
| 				err = message.SendEmail(prompt, email, | ||||
| 					fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) | ||||
| 				if err != nil { | ||||
| 					logger.SysError("failed to send email" + err.Error()) | ||||
| @@ -231,7 +232,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func PostConsumeTokenQuota(tokenId int, quota int) (err error) { | ||||
| func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { | ||||
| 	token, err := GetTokenById(tokenId) | ||||
| 	if quota > 0 { | ||||
| 		err = DecreaseUserQuota(token.UserId, quota) | ||||
|   | ||||
| @@ -26,8 +26,8 @@ type User struct { | ||||
| 	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"` | ||||
| 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database! | ||||
| 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management | ||||
| 	Quota            int    `json:"quota" gorm:"type:int;default:0"` | ||||
| 	UsedQuota        int    `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota | ||||
| 	Quota            int64  `json:"quota" gorm:"type:int;default:0"` | ||||
| 	UsedQuota        int64  `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota | ||||
| 	RequestCount     int    `json:"request_count" gorm:"type:int;default:0;"`               // request number | ||||
| 	Group            string `json:"group" gorm:"type:varchar(32);default:'default'"` | ||||
| 	AffCode          string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` | ||||
| @@ -274,12 +274,12 @@ func ValidateAccessToken(token string) (user *User) { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func GetUserQuota(id int) (quota int, err error) { | ||||
| func GetUserQuota(id int) (quota int64, err error) { | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error | ||||
| 	return quota, err | ||||
| } | ||||
|  | ||||
| func GetUserUsedQuota(id int) (quota int, err error) { | ||||
| func GetUserUsedQuota(id int) (quota int64, err error) { | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error | ||||
| 	return quota, err | ||||
| } | ||||
| @@ -299,7 +299,7 @@ func GetUserGroup(id int) (group string, err error) { | ||||
| 	return group, err | ||||
| } | ||||
|  | ||||
| func IncreaseUserQuota(id int, quota int) (err error) { | ||||
| func IncreaseUserQuota(id int, quota int64) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| @@ -310,12 +310,12 @@ func IncreaseUserQuota(id int, quota int) (err error) { | ||||
| 	return increaseUserQuota(id, quota) | ||||
| } | ||||
|  | ||||
| func increaseUserQuota(id int, quota int) (err error) { | ||||
| func increaseUserQuota(id int, quota int64) (err error) { | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func DecreaseUserQuota(id int, quota int) (err error) { | ||||
| func DecreaseUserQuota(id int, quota int64) (err error) { | ||||
| 	if quota < 0 { | ||||
| 		return errors.New("quota 不能为负数!") | ||||
| 	} | ||||
| @@ -326,7 +326,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { | ||||
| 	return decreaseUserQuota(id, quota) | ||||
| } | ||||
|  | ||||
| func decreaseUserQuota(id int, quota int) (err error) { | ||||
| func decreaseUserQuota(id int, quota int64) (err error) { | ||||
| 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error | ||||
| 	return err | ||||
| } | ||||
| @@ -336,7 +336,7 @@ func GetRootUserEmail() (email string) { | ||||
| 	return email | ||||
| } | ||||
|  | ||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||
| func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) { | ||||
| 	if config.BatchUpdateEnabled { | ||||
| 		addNewRecord(BatchUpdateTypeUsedQuota, id, quota) | ||||
| 		addNewRecord(BatchUpdateTypeRequestCount, id, 1) | ||||
| @@ -345,7 +345,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { | ||||
| 	updateUserUsedQuotaAndRequestCount(id, quota, 1) | ||||
| } | ||||
|  | ||||
| func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | ||||
| func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) { | ||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||
| 		map[string]interface{}{ | ||||
| 			"used_quota":    gorm.Expr("used_quota + ?", quota), | ||||
| @@ -357,7 +357,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func updateUserUsedQuota(id int, quota int) { | ||||
| func updateUserUsedQuota(id int, quota int64) { | ||||
| 	err := DB.Model(&User{}).Where("id = ?", id).Updates( | ||||
| 		map[string]interface{}{ | ||||
| 			"used_quota": gorm.Expr("used_quota + ?", quota), | ||||
|   | ||||
| @@ -16,12 +16,12 @@ const ( | ||||
| 	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock | ||||
| ) | ||||
|  | ||||
| var batchUpdateStores []map[int]int | ||||
| var batchUpdateStores []map[int]int64 | ||||
| var batchUpdateLocks []sync.Mutex | ||||
|  | ||||
| func init() { | ||||
| 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||
| 		batchUpdateStores = append(batchUpdateStores, make(map[int]int)) | ||||
| 		batchUpdateStores = append(batchUpdateStores, make(map[int]int64)) | ||||
| 		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) | ||||
| 	} | ||||
| } | ||||
| @@ -35,7 +35,7 @@ func InitBatchUpdater() { | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func addNewRecord(type_ int, id int, value int) { | ||||
| func addNewRecord(type_ int, id int, value int64) { | ||||
| 	batchUpdateLocks[type_].Lock() | ||||
| 	defer batchUpdateLocks[type_].Unlock() | ||||
| 	if _, ok := batchUpdateStores[type_][id]; !ok { | ||||
| @@ -50,7 +50,7 @@ func batchUpdate() { | ||||
| 	for i := 0; i < BatchUpdateTypeCount; i++ { | ||||
| 		batchUpdateLocks[i].Lock() | ||||
| 		store := batchUpdateStores[i] | ||||
| 		batchUpdateStores[i] = make(map[int]int) | ||||
| 		batchUpdateStores[i] = make(map[int]int64) | ||||
| 		batchUpdateLocks[i].Unlock() | ||||
| 		// TODO: maybe we can combine updates with same key? | ||||
| 		for key, value := range store { | ||||
| @@ -68,7 +68,7 @@ func batchUpdate() { | ||||
| 			case BatchUpdateTypeUsedQuota: | ||||
| 				updateUserUsedQuota(key, value) | ||||
| 			case BatchUpdateTypeRequestCount: | ||||
| 				updateUserRequestCount(key, value) | ||||
| 				updateUserRequestCount(key, int(value)) | ||||
| 			case BatchUpdateTypeChannelUsedQuota: | ||||
| 				updateChannelUsedQuota(key, value) | ||||
| 			} | ||||
|   | ||||
							
								
								
									
										55
									
								
								monitor/channel.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								monitor/channel.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | ||||
| package monitor | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/message" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| ) | ||||
|  | ||||
| func notifyRootUser(subject string, content string) { | ||||
| 	if config.MessagePusherAddress != "" { | ||||
| 		err := message.SendMessage(subject, content, content) | ||||
| 		if err != nil { | ||||
| 			logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error())) | ||||
| 		} else { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	if config.RootUserEmail == "" { | ||||
| 		config.RootUserEmail = model.GetRootUserEmail() | ||||
| 	} | ||||
| 	err := message.SendEmail(subject, config.RootUserEmail, content) | ||||
| 	if err != nil { | ||||
| 		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // DisableChannel disable & notify | ||||
| func DisableChannel(channelId int, channelName string, reason string) { | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||
| 	logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| func MetricDisableChannel(channelId int, successRate float64) { | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) | ||||
| 	logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) | ||||
| 	subject := fmt.Sprintf("通道 #%d 已被禁用", channelId) | ||||
| 	content := fmt.Sprintf("该渠道在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", | ||||
| 		config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
|  | ||||
| // EnableChannel enable & notify | ||||
| func EnableChannel(channelId int, channelName string) { | ||||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) | ||||
| 	logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) | ||||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) | ||||
| 	notifyRootUser(subject, content) | ||||
| } | ||||
							
								
								
									
										79
									
								
								monitor/metric.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								monitor/metric.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | ||||
| package monitor | ||||
|  | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| ) | ||||
|  | ||||
| var store = make(map[int][]bool) | ||||
| var metricSuccessChan = make(chan int, config.MetricSuccessChanSize) | ||||
| var metricFailChan = make(chan int, config.MetricFailChanSize) | ||||
|  | ||||
| func consumeSuccess(channelId int) { | ||||
| 	if len(store[channelId]) > config.MetricQueueSize { | ||||
| 		store[channelId] = store[channelId][1:] | ||||
| 	} | ||||
| 	store[channelId] = append(store[channelId], true) | ||||
| } | ||||
|  | ||||
| func consumeFail(channelId int) (bool, float64) { | ||||
| 	if len(store[channelId]) > config.MetricQueueSize { | ||||
| 		store[channelId] = store[channelId][1:] | ||||
| 	} | ||||
| 	store[channelId] = append(store[channelId], false) | ||||
| 	successCount := 0 | ||||
| 	for _, success := range store[channelId] { | ||||
| 		if success { | ||||
| 			successCount++ | ||||
| 		} | ||||
| 	} | ||||
| 	successRate := float64(successCount) / float64(len(store[channelId])) | ||||
| 	if len(store[channelId]) < config.MetricQueueSize { | ||||
| 		return false, successRate | ||||
| 	} | ||||
| 	if successRate < config.MetricSuccessRateThreshold { | ||||
| 		store[channelId] = make([]bool, 0) | ||||
| 		return true, successRate | ||||
| 	} | ||||
| 	return false, successRate | ||||
| } | ||||
|  | ||||
| func metricSuccessConsumer() { | ||||
| 	for { | ||||
| 		select { | ||||
| 		case channelId := <-metricSuccessChan: | ||||
| 			consumeSuccess(channelId) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func metricFailConsumer() { | ||||
| 	for { | ||||
| 		select { | ||||
| 		case channelId := <-metricFailChan: | ||||
| 			disable, successRate := consumeFail(channelId) | ||||
| 			if disable { | ||||
| 				go MetricDisableChannel(channelId, successRate) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func init() { | ||||
| 	if config.EnableMetric { | ||||
| 		go metricSuccessConsumer() | ||||
| 		go metricFailConsumer() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Emit(channelId int, success bool) { | ||||
| 	if !config.EnableMetric { | ||||
| 		return | ||||
| 	} | ||||
| 	go func() { | ||||
| 		if success { | ||||
| 			metricSuccessChan <- channelId | ||||
| 		} else { | ||||
| 			metricFailChan <- channelId | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
| @@ -32,6 +32,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||
| 	channel.SetupCommonRequestHeader(c, req, meta) | ||||
| 	if meta.IsStream { | ||||
| 		req.Header.Set("Accept", "text/event-stream") | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||
| 	if meta.IsStream { | ||||
| 		req.Header.Set("X-DashScope-SSE", "enable") | ||||
|   | ||||
| @@ -59,5 +59,5 @@ func (a *Adaptor) GetModelList() []string { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "authropic" | ||||
| 	return "anthropic" | ||||
| } | ||||
|   | ||||
| @@ -2,7 +2,7 @@ package anthropic | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"claude-instant-1.2", "claude-2.0", "claude-2.1", | ||||
| 	"claude-3-haiku-20240229", | ||||
| 	"claude-3-haiku-20240307", | ||||
| 	"claude-3-sonnet-20240229", | ||||
| 	"claude-3-opus-20240229", | ||||
| } | ||||
|   | ||||
| @@ -3,14 +3,15 @@ package baidu | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| @@ -23,7 +24,13 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { | ||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t | ||||
| 	suffix := "chat/" | ||||
| 	if strings.HasPrefix("Embedding", meta.ActualModelName) { | ||||
| 	if strings.HasPrefix(meta.ActualModelName, "Embedding") { | ||||
| 		suffix = "embeddings/" | ||||
| 	} | ||||
| 	if strings.HasPrefix(meta.ActualModelName, "bge-large") { | ||||
| 		suffix = "embeddings/" | ||||
| 	} | ||||
| 	if strings.HasPrefix(meta.ActualModelName, "tao-8k") { | ||||
| 		suffix = "embeddings/" | ||||
| 	} | ||||
| 	switch meta.ActualModelName { | ||||
| @@ -45,6 +52,12 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 		suffix += "bloomz_7b1" | ||||
| 	case "Embedding-V1": | ||||
| 		suffix += "embedding-v1" | ||||
| 	case "bge-large-zh": | ||||
| 		suffix += "bge_large_zh" | ||||
| 	case "bge-large-en": | ||||
| 		suffix += "bge_large_en" | ||||
| 	case "tao-8k": | ||||
| 		suffix += "tao_8k" | ||||
| 	default: | ||||
| 		suffix += meta.ActualModelName | ||||
| 	} | ||||
|   | ||||
| @@ -7,4 +7,7 @@ var ModelList = []string{ | ||||
| 	"ERNIE-Speed", | ||||
| 	"ERNIE-Bot-turbo", | ||||
| 	"Embedding-V1", | ||||
| 	"bge-large-zh", | ||||
| 	"bge-large-en", | ||||
| 	"tao-8k", | ||||
| } | ||||
|   | ||||
| @@ -32,9 +32,16 @@ type Message struct { | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Messages []Message `json:"messages"` | ||||
| 	Stream   bool      `json:"stream"` | ||||
| 	UserId   string    `json:"user_id,omitempty"` | ||||
| 	Messages        []Message `json:"messages"` | ||||
| 	Temperature     float64   `json:"temperature,omitempty"` | ||||
| 	TopP            float64   `json:"top_p,omitempty"` | ||||
| 	PenaltyScore    float64   `json:"penalty_score,omitempty"` | ||||
| 	Stream          bool      `json:"stream,omitempty"` | ||||
| 	System          string    `json:"system,omitempty"` | ||||
| 	DisableSearch   bool      `json:"disable_search,omitempty"` | ||||
| 	EnableCitation  bool      `json:"enable_citation,omitempty"` | ||||
| 	MaxOutputTokens int       `json:"max_output_tokens,omitempty"` | ||||
| 	UserId          string    `json:"user_id,omitempty"` | ||||
| } | ||||
|  | ||||
| type Error struct { | ||||
| @@ -45,28 +52,28 @@ type Error struct { | ||||
| var baiduTokenStore sync.Map | ||||
|  | ||||
| func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 	messages := make([]Message, 0, len(request.Messages)) | ||||
| 	baiduRequest := ChatRequest{ | ||||
| 		Messages:        make([]Message, 0, len(request.Messages)), | ||||
| 		Temperature:     request.Temperature, | ||||
| 		TopP:            request.TopP, | ||||
| 		PenaltyScore:    request.FrequencyPenalty, | ||||
| 		Stream:          request.Stream, | ||||
| 		DisableSearch:   false, | ||||
| 		EnableCitation:  false, | ||||
| 		MaxOutputTokens: request.MaxTokens, | ||||
| 		UserId:          request.User, | ||||
| 	} | ||||
| 	for _, message := range request.Messages { | ||||
| 		if message.Role == "system" { | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "user", | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 			messages = append(messages, Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "Okay", | ||||
| 			}) | ||||
| 			baiduRequest.System = message.StringContent() | ||||
| 		} else { | ||||
| 			messages = append(messages, Message{ | ||||
| 			baiduRequest.Messages = append(baiduRequest.Messages, Message{ | ||||
| 				Role:    message.Role, | ||||
| 				Content: message.StringContent(), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &ChatRequest{ | ||||
| 		Messages: messages, | ||||
| 		Stream:   request.Stream, | ||||
| 	} | ||||
| 	return &baiduRequest | ||||
| } | ||||
|  | ||||
| func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
|   | ||||
							
								
								
									
										9
									
								
								relay/channel/lingyiwanwu/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								relay/channel/lingyiwanwu/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| package lingyiwanwu | ||||
|  | ||||
| // https://platform.lingyiwanwu.com/docs | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"yi-34b-chat-0205", | ||||
| 	"yi-34b-chat-200k", | ||||
| 	"yi-vl-plus", | ||||
| } | ||||
							
								
								
									
										65
									
								
								relay/channel/ollama/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								relay/channel/ollama/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,65 @@ | ||||
| package ollama | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/util" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *util.RelayMeta) { | ||||
|  | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { | ||||
| 	// https://github.com/ollama/ollama/blob/main/docs/api.md | ||||
| 	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) | ||||
| 	return fullRequestURL, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { | ||||
| 	channel.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	switch relayMode { | ||||
| 	case constant.RelayModeEmbeddings: | ||||
| 		return nil, errors.New("not supported") | ||||
| 	default: | ||||
| 		return ConvertRequest(*request), nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return channel.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, resp) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, resp) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() []string { | ||||
| 	return ModelList | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "ollama" | ||||
| } | ||||
							
								
								
									
										5
									
								
								relay/channel/ollama/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								relay/channel/ollama/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| package ollama | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"qwen:0.5b-chat", | ||||
| } | ||||
							
								
								
									
										178
									
								
								relay/channel/ollama/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								relay/channel/ollama/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,178 @@ | ||||
| package ollama | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 	ollamaRequest := ChatRequest{ | ||||
| 		Model: request.Model, | ||||
| 		Options: &Options{ | ||||
| 			Seed:             int(request.Seed), | ||||
| 			Temperature:      request.Temperature, | ||||
| 			TopP:             request.TopP, | ||||
| 			FrequencyPenalty: request.FrequencyPenalty, | ||||
| 			PresencePenalty:  request.PresencePenalty, | ||||
| 		}, | ||||
| 		Stream: request.Stream, | ||||
| 	} | ||||
| 	for _, message := range request.Messages { | ||||
| 		ollamaRequest.Messages = append(ollamaRequest.Messages, Message{ | ||||
| 			Role:    message.Role, | ||||
| 			Content: message.StringContent(), | ||||
| 		}) | ||||
| 	} | ||||
| 	return &ollamaRequest | ||||
| } | ||||
|  | ||||
| func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    response.Message.Role, | ||||
| 			Content: response.Message.Content, | ||||
| 		}, | ||||
| 	} | ||||
| 	if response.Done { | ||||
| 		choice.FinishReason = "stop" | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 		Usage: model.Usage{ | ||||
| 			PromptTokens:     response.PromptEvalCount, | ||||
| 			CompletionTokens: response.EvalCount, | ||||
| 			TotalTokens:      response.PromptEvalCount + response.EvalCount, | ||||
| 		}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Role = ollamaResponse.Message.Role | ||||
| 	choice.Delta.Content = ollamaResponse.Message.Content | ||||
| 	if ollamaResponse.Done { | ||||
| 		choice.FinishReason = &constant.StopFinishReason | ||||
| 	} | ||||
| 	response := openai.ChatCompletionsStreamResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Model:   ollamaResponse.Model, | ||||
| 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &response | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	var usage model.Usage | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "}\n"); i >= 0 { | ||||
| 			return i + 2, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := strings.TrimPrefix(scanner.Text(), "}") | ||||
| 			dataChan <- data + "}" | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			var ollamaResponse ChatResponse | ||||
| 			err := json.Unmarshal([]byte(data), &ollamaResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			if ollamaResponse.EvalCount != 0 { | ||||
| 				usage.PromptTokens = ollamaResponse.PromptEvalCount | ||||
| 				usage.CompletionTokens = ollamaResponse.EvalCount | ||||
| 				usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount | ||||
| 			} | ||||
| 			response := streamResponseOllama2OpenAI(&ollamaResponse) | ||||
| 			jsonResponse, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	err := resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	ctx := context.TODO() | ||||
| 	var ollamaResponse ChatResponse | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	logger.Debugf(ctx, "ollama response: %s", string(responseBody)) | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = json.Unmarshal(responseBody, &ollamaResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if ollamaResponse.Error != "" { | ||||
| 		return &model.ErrorWithStatusCode{ | ||||
| 			Error: model.Error{ | ||||
| 				Message: ollamaResponse.Error, | ||||
| 				Type:    "ollama_error", | ||||
| 				Param:   "", | ||||
| 				Code:    "ollama_error", | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := responseOllama2OpenAI(&ollamaResponse) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &fullTextResponse.Usage | ||||
| } | ||||
							
								
								
									
										37
									
								
								relay/channel/ollama/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								relay/channel/ollama/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| package ollama | ||||
|  | ||||
| type Options struct { | ||||
| 	Seed             int     `json:"seed,omitempty"` | ||||
| 	Temperature      float64 `json:"temperature,omitempty"` | ||||
| 	TopK             int     `json:"top_k,omitempty"` | ||||
| 	TopP             float64 `json:"top_p,omitempty"` | ||||
| 	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` | ||||
| 	PresencePenalty  float64 `json:"presence_penalty,omitempty"` | ||||
| } | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string   `json:"role,omitempty"` | ||||
| 	Content string   `json:"content,omitempty"` | ||||
| 	Images  []string `json:"images,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatRequest struct { | ||||
| 	Model    string    `json:"model,omitempty"` | ||||
| 	Messages []Message `json:"messages,omitempty"` | ||||
| 	Stream   bool      `json:"stream"` | ||||
| 	Options  *Options  `json:"options,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatResponse struct { | ||||
| 	Model           string  `json:"model,omitempty"` | ||||
| 	CreatedAt       string  `json:"created_at,omitempty"` | ||||
| 	Message         Message `json:"message,omitempty"` | ||||
| 	Response        string  `json:"response,omitempty"` // for stream response | ||||
| 	Done            bool    `json:"done,omitempty"` | ||||
| 	TotalDuration   int     `json:"total_duration,omitempty"` | ||||
| 	LoadDuration    int     `json:"load_duration,omitempty"` | ||||
| 	PromptEvalCount int     `json:"prompt_eval_count,omitempty"` | ||||
| 	EvalCount       int     `json:"eval_count,omitempty"` | ||||
| 	EvalDuration    int     `json:"eval_duration,omitempty"` | ||||
| 	Error           string  `json:"error,omitempty"` | ||||
| } | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/ai360" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/baichuan" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/groq" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/lingyiwanwu" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/minimax" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/mistral" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/moonshot" | ||||
| @@ -18,6 +19,7 @@ var CompatibleChannels = []int{ | ||||
| 	common.ChannelTypeMinimax, | ||||
| 	common.ChannelTypeMistral, | ||||
| 	common.ChannelTypeGroq, | ||||
| 	common.ChannelTypeLingYiWanWu, | ||||
| } | ||||
|  | ||||
| func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| @@ -36,6 +38,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| 		return "mistralai", mistral.ModelList | ||||
| 	case common.ChannelTypeGroq: | ||||
| 		return "groq", groq.ModelList | ||||
| 	case common.ChannelTypeLingYiWanWu: | ||||
| 		return "lingyiwanwu", lingyiwanwu.ModelList | ||||
| 	default: | ||||
| 		return "openai", ModelList | ||||
| 	} | ||||
|   | ||||
| @@ -15,6 +15,7 @@ const ( | ||||
| 	APITypeAIProxyLibrary | ||||
| 	APITypeTencent | ||||
| 	APITypeGemini | ||||
| 	APITypeOllama | ||||
|  | ||||
| 	APITypeDummy // this one is only for count, do not add any channel after this | ||||
| ) | ||||
| @@ -40,6 +41,8 @@ func ChannelType2APIType(channelType int) int { | ||||
| 		apiType = APITypeTencent | ||||
| 	case common.ChannelTypeGemini: | ||||
| 		apiType = APITypeGemini | ||||
| 	case common.ChannelTypeOllama: | ||||
| 		apiType = APITypeOllama | ||||
| 	} | ||||
| 	return apiType | ||||
| } | ||||
|   | ||||
| @@ -22,6 +22,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | ||||
| 	ctx := c.Request.Context() | ||||
| 	audioModel := "whisper-1" | ||||
|  | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| @@ -49,16 +50,16 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	modelRatio := common.GetModelRatio(audioModel) | ||||
| 	groupRatio := common.GetGroupRatio(group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	var quota int | ||||
| 	var preConsumedQuota int | ||||
| 	var quota int64 | ||||
| 	var preConsumedQuota int64 | ||||
| 	switch relayMode { | ||||
| 	case constant.RelayModeAudioSpeech: | ||||
| 		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) | ||||
| 		preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) | ||||
| 		quota = preConsumedQuota | ||||
| 	default: | ||||
| 		preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio) | ||||
| 		preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) | ||||
| 	} | ||||
| 	userQuota, err := model.CacheGetUserQuota(userId) | ||||
| 	userQuota, err := model.CacheGetUserQuota(ctx, userId) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| @@ -183,7 +184,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		quota = openai.CountTokenText(text, audioModel) | ||||
| 		quota = int64(openai.CountTokenText(text, audioModel)) | ||||
| 		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||||
| 	} | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
|   | ||||
| @@ -107,18 +107,18 @@ func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int | ||||
| 	return 0 | ||||
| } | ||||
|  | ||||
| func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int { | ||||
| func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 { | ||||
| 	preConsumedTokens := config.PreConsumedQuota | ||||
| 	if textRequest.MaxTokens != 0 { | ||||
| 		preConsumedTokens = promptTokens + textRequest.MaxTokens | ||||
| 		preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens) | ||||
| 	} | ||||
| 	return int(float64(preConsumedTokens) * ratio) | ||||
| 	return int64(float64(preConsumedTokens) * ratio) | ||||
| } | ||||
|  | ||||
| func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) { | ||||
| func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) { | ||||
| 	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) | ||||
|  | ||||
| 	userQuota, err := model.CacheGetUserQuota(meta.UserId) | ||||
| 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) | ||||
| 	if err != nil { | ||||
| 		return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) | ||||
| 	} | ||||
| @@ -144,16 +144,16 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR | ||||
| 	return preConsumedQuota, nil | ||||
| } | ||||
|  | ||||
| func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) { | ||||
| func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { | ||||
| 	if usage == nil { | ||||
| 		logger.Error(ctx, "usage is nil, which is unexpected") | ||||
| 		return | ||||
| 	} | ||||
| 	quota := 0 | ||||
| 	var quota int64 | ||||
| 	completionRatio := common.GetCompletionRatio(textRequest.Model) | ||||
| 	promptTokens := usage.PromptTokens | ||||
| 	completionTokens := usage.CompletionTokens | ||||
| 	quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) | ||||
| 	quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) | ||||
| 	if ratio != 0 && quota <= 0 { | ||||
| 		quota = 1 | ||||
| 	} | ||||
| @@ -168,7 +168,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "error consuming token remain quota: "+err.Error()) | ||||
| 	} | ||||
| 	err = model.CacheUpdateUserQuota(meta.UserId) | ||||
| 	err = model.CacheUpdateUserQuota(ctx, meta.UserId) | ||||
| 	if err != nil { | ||||
| 		logger.Error(ctx, "error update user quota cache: "+err.Error()) | ||||
| 	} | ||||
|   | ||||
| @@ -79,9 +79,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	modelRatio := common.GetModelRatio(imageRequest.Model) | ||||
| 	groupRatio := common.GetGroupRatio(meta.Group) | ||||
| 	ratio := modelRatio * groupRatio | ||||
| 	userQuota, err := model.CacheGetUserQuota(meta.UserId) | ||||
| 	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) | ||||
|  | ||||
| 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N | ||||
| 	quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) | ||||
|  | ||||
| 	if userQuota-quota < 0 { | ||||
| 		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) | ||||
| @@ -125,7 +125,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 		} | ||||
| 		err = model.CacheUpdateUserQuota(meta.UserId) | ||||
| 		err = model.CacheUpdateUserQuota(ctx, meta.UserId) | ||||
| 		if err != nil { | ||||
| 			logger.SysError("error update user quota cache: " + err.Error()) | ||||
| 		} | ||||
|   | ||||
| @@ -74,6 +74,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | ||||
| 		if err != nil { | ||||
| 			return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) | ||||
| 		} | ||||
| 		logger.Debugf(ctx, "converted request: \n%s", string(jsonData)) | ||||
| 		requestBody = bytes.NewBuffer(jsonData) | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/anthropic" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/baidu" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/gemini" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/ollama" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/palm" | ||||
| 	"github.com/songquanpeng/one-api/relay/channel/tencent" | ||||
| @@ -37,6 +38,8 @@ func GetAdaptor(apiType int) channel.Adaptor { | ||||
| 		return &xunfei.Adaptor{} | ||||
| 	case constant.APITypeZhipu: | ||||
| 		return &zhipu.Adaptor{} | ||||
| 	case constant.APITypeOllama: | ||||
| 		return &ollama.Adaptor{} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -6,7 +6,7 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| ) | ||||
|  | ||||
| func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) { | ||||
| func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { | ||||
| 	if preConsumedQuota != 0 { | ||||
| 		go func(ctx context.Context) { | ||||
| 			// return pre-consumed quota | ||||
|   | ||||
| @@ -27,7 +27,23 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { | ||||
| 	if statusCode == http.StatusUnauthorized { | ||||
| 		return true | ||||
| 	} | ||||
| 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||
| 	switch err.Type { | ||||
| 	case "insufficient_quota": | ||||
| 		return true | ||||
| 	// https://docs.anthropic.com/claude/reference/errors | ||||
| 	case "authentication_error": | ||||
| 		return true | ||||
| 	case "permission_error": | ||||
| 		return true | ||||
| 	case "forbidden": | ||||
| 		return true | ||||
| 	} | ||||
| 	if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { | ||||
| 		return true | ||||
| 	} | ||||
| 	if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic | ||||
| 		return true | ||||
| 	} else if strings.HasPrefix(err.Message, "This organization has been disabled.") { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
| @@ -101,6 +117,9 @@ func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.Err | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	if config.DebugEnabled { | ||||
| 		logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody))) | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| @@ -136,20 +155,20 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin | ||||
| 	return fullRequestURL | ||||
| } | ||||
|  | ||||
| func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { | ||||
| func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { | ||||
| 	// quotaDelta is remaining quota to be consumed | ||||
| 	err := model.PostConsumeTokenQuota(tokenId, quotaDelta) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("error consuming token remain quota: " + err.Error()) | ||||
| 	} | ||||
| 	err = model.CacheUpdateUserQuota(userId) | ||||
| 	err = model.CacheUpdateUserQuota(ctx, userId) | ||||
| 	if err != nil { | ||||
| 		logger.SysError("error update user quota cache: " + err.Error()) | ||||
| 	} | ||||
| 	// totalQuota is total quota consumed | ||||
| 	if totalQuota != 0 { | ||||
| 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) | ||||
| 		model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) | ||||
| 		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) | ||||
| 		model.UpdateChannelUsedQuota(channelId, totalQuota) | ||||
| 	} | ||||
|   | ||||
| @@ -70,7 +70,7 @@ func SetApiRouter(router *gin.Engine) { | ||||
| 			channelRoute.GET("/search", controller.SearchChannels) | ||||
| 			channelRoute.GET("/models", controller.ListModels) | ||||
| 			channelRoute.GET("/:id", controller.GetChannel) | ||||
| 			channelRoute.GET("/test", controller.TestAllChannels) | ||||
| 			channelRoute.GET("/test", controller.TestChannels) | ||||
| 			channelRoute.GET("/test/:id", controller.TestChannel) | ||||
| 			channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) | ||||
| 			channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) | ||||
|   | ||||
| @@ -9,6 +9,7 @@ import ( | ||||
|  | ||||
| func SetDashboardRouter(router *gin.Engine) { | ||||
| 	apiRouter := router.Group("/") | ||||
| 	apiRouter.Use(middleware.CORS()) | ||||
| 	apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) | ||||
| 	apiRouter.Use(middleware.GlobalAPIRateLimit()) | ||||
| 	apiRouter.Use(middleware.TokenAuth()) | ||||
|   | ||||
| @@ -95,6 +95,18 @@ export const CHANNEL_OPTIONS = { | ||||
|     value: 29, | ||||
|     color: 'default' | ||||
|   }, | ||||
|   30: { | ||||
|     key: 30, | ||||
|     text: 'Ollama', | ||||
|     value: 30, | ||||
|     color: 'default' | ||||
|   }, | ||||
|   31: { | ||||
|     key: 31, | ||||
|     text: '零一万物', | ||||
|     value: 31, | ||||
|     color: 'default' | ||||
|   }, | ||||
|   8: { | ||||
|     key: 8, | ||||
|     text: '自定义渠道', | ||||
|   | ||||
| @@ -166,6 +166,12 @@ const typeConfig = { | ||||
|   29: { | ||||
|     modelGroup: "groq", | ||||
|   }, | ||||
|   30: { | ||||
|     modelGroup: "ollama", | ||||
|   }, | ||||
|   31: { | ||||
|     modelGroup: "lingyiwanwu", | ||||
|   }, | ||||
| }; | ||||
|  | ||||
| export { defaultConfig, typeConfig }; | ||||
|   | ||||
| @@ -240,11 +240,11 @@ const ChannelsTable = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const testAllChannels = async () => { | ||||
|     const res = await API.get(`/api/channel/test`); | ||||
|   const testChannels = async (scope) => { | ||||
|     const res = await API.get(`/api/channel/test?scope=${scope}`); | ||||
|     const { success, message } = res.data; | ||||
|     if (success) { | ||||
|       showInfo('已成功开始测试所有通道,请刷新页面查看结果。'); | ||||
|       showInfo('已成功开始测试通道,请刷新页面查看结果。'); | ||||
|     } else { | ||||
|       showError(message); | ||||
|     } | ||||
| @@ -529,9 +529,12 @@ const ChannelsTable = () => { | ||||
|               <Button size='small' as={Link} to='/channel/add' loading={loading}> | ||||
|                 添加新的渠道 | ||||
|               </Button> | ||||
|               <Button size='small' loading={loading} onClick={testAllChannels}> | ||||
|               <Button size='small' loading={loading} onClick={()=>{testChannels("all")}}> | ||||
|                 测试所有渠道 | ||||
|               </Button> | ||||
|               <Button size='small' loading={loading} onClick={()=>{testChannels("disabled")}}> | ||||
|                 测试禁用渠道 | ||||
|               </Button> | ||||
|               {/*<Button size='small' onClick={updateAllChannelsBalance}*/} | ||||
|               {/*        loading={loading || updatingBalance}>更新已启用渠道余额</Button>*/} | ||||
|               <Popup | ||||
|   | ||||
| @@ -16,6 +16,17 @@ const PasswordResetForm = () => { | ||||
|   const [disableButton, setDisableButton] = useState(false); | ||||
|   const [countdown, setCountdown] = useState(30); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let status = localStorage.getItem('status'); | ||||
|     if (status) { | ||||
|       status = JSON.parse(status); | ||||
|       if (status.turnstile_check) { | ||||
|         setTurnstileEnabled(true); | ||||
|         setTurnstileSiteKey(status.turnstile_site_key); | ||||
|       } | ||||
|     } | ||||
|   }, []); | ||||
|  | ||||
|   useEffect(() => { | ||||
|     let countdownInterval = null; | ||||
|     if (disableButton && countdown > 0) { | ||||
|   | ||||
| @@ -22,6 +22,8 @@ const SystemSetting = () => { | ||||
|     WeChatServerAddress: '', | ||||
|     WeChatServerToken: '', | ||||
|     WeChatAccountQRCodeImageURL: '', | ||||
|     MessagePusherAddress: '', | ||||
|     MessagePusherToken: '', | ||||
|     TurnstileCheckEnabled: '', | ||||
|     TurnstileSiteKey: '', | ||||
|     TurnstileSecretKey: '', | ||||
| @@ -183,6 +185,21 @@ const SystemSetting = () => { | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submitMessagePusher = async () => { | ||||
|     if (originInputs['MessagePusherAddress'] !== inputs.MessagePusherAddress) { | ||||
|       await updateOption( | ||||
|         'MessagePusherAddress', | ||||
|         removeTrailingSlash(inputs.MessagePusherAddress) | ||||
|       ); | ||||
|     } | ||||
|     if ( | ||||
|       originInputs['MessagePusherToken'] !== inputs.MessagePusherToken && | ||||
|       inputs.MessagePusherToken !== '' | ||||
|     ) { | ||||
|       await updateOption('MessagePusherToken', inputs.MessagePusherToken); | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   const submitGitHubOAuth = async () => { | ||||
|     if (originInputs['GitHubClientId'] !== inputs.GitHubClientId) { | ||||
|       await updateOption('GitHubClientId', inputs.GitHubClientId); | ||||
| @@ -496,6 +513,42 @@ const SystemSetting = () => { | ||||
|             保存 WeChat Server 设置 | ||||
|           </Form.Button> | ||||
|           <Divider /> | ||||
|           <Header as='h3'> | ||||
|             配置 Message Pusher | ||||
|             <Header.Subheader> | ||||
|               用以推送报警信息, | ||||
|               <a | ||||
|                 href='https://github.com/songquanpeng/message-pusher' | ||||
|                 target='_blank' | ||||
|               > | ||||
|                 点击此处 | ||||
|               </a> | ||||
|               了解 Message Pusher | ||||
|             </Header.Subheader> | ||||
|           </Header> | ||||
|           <Form.Group widths={3}> | ||||
|             <Form.Input | ||||
|               label='Message Pusher 推送地址' | ||||
|               name='MessagePusherAddress' | ||||
|               placeholder='例如:https://msgpusher.com/push/your_username' | ||||
|               onChange={handleInputChange} | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.MessagePusherAddress} | ||||
|             /> | ||||
|             <Form.Input | ||||
|               label='Message Pusher 访问凭证' | ||||
|               name='MessagePusherToken' | ||||
|               type='password' | ||||
|               onChange={handleInputChange} | ||||
|               autoComplete='new-password' | ||||
|               value={inputs.MessagePusherToken} | ||||
|               placeholder='敏感信息不会发送到前端显示' | ||||
|             /> | ||||
|           </Form.Group> | ||||
|           <Form.Button onClick={submitMessagePusher}> | ||||
|             保存 Message Pusher 设置 | ||||
|           </Form.Button> | ||||
|           <Divider /> | ||||
|           <Header as='h3'> | ||||
|             配置 Turnstile | ||||
|             <Header.Subheader> | ||||
|   | ||||
| @@ -15,6 +15,8 @@ export const CHANNEL_OPTIONS = [ | ||||
|   { key: 26, text: '百川大模型', value: 26, color: 'orange' }, | ||||
|   { key: 27, text: 'MiniMax', value: 27, color: 'red' }, | ||||
|   { key: 29, text: 'Groq', value: 29, color: 'orange' }, | ||||
|   { key: 30, text: 'Ollama', value: 30, color: 'black' }, | ||||
|   { key: 31, text: '零一万物', value: 31, color: 'green' }, | ||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||
|   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||
|   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user