mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 07:43:41 +08:00 
			
		
		
		
	Compare commits
	
		
			19 Commits
		
	
	
		
			v0.6.2-alp
			...
			v0.6.2
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					b204f6d82b | ||
| 
						 | 
					752639560f | ||
| 
						 | 
					996f4d99dd | ||
| 
						 | 
					ebfee3b46c | ||
| 
						 | 
					3e2e805d61 | ||
| 
						 | 
					3edf7247c4 | ||
| 
						 | 
					0926b6206b | ||
| 
						 | 
					7cd57f3125 | ||
| 
						 | 
					66efabd5ae | ||
| 
						 | 
					8ede66a896 | ||
| 
						 | 
					b169173860 | ||
| 
						 | 
					f33555ae78 | ||
| 
						 | 
					c28ec10795 | ||
| 
						 | 
					e3767cbb07 | ||
| 
						 | 
					be9eb59fbb | ||
| 
						 | 
					89e111ac69 | ||
| 
						 | 
					2dcef85285 | ||
| 
						 | 
					79d0cd378a | ||
| 
						 | 
					e99150bdb9 | 
							
								
								
									
										7
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-amd64-en.yml
									
									
									
									
										vendored
									
									
								
							@@ -20,6 +20,13 @@ jobs:
 | 
			
		||||
      - name: Check out the repo
 | 
			
		||||
        uses: actions/checkout@v3
 | 
			
		||||
 | 
			
		||||
      - name: Check repository URL
 | 
			
		||||
        run: |
 | 
			
		||||
          REPO_URL=$(git config --get remote.origin.url)
 | 
			
		||||
          if [[ $REPO_URL == *"pro" ]]; then
 | 
			
		||||
            exit 0
 | 
			
		||||
          fi      
 | 
			
		||||
 | 
			
		||||
      - name: Save version info
 | 
			
		||||
        run: |
 | 
			
		||||
          git describe --tags > VERSION 
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										7
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-amd64.yml
									
									
									
									
										vendored
									
									
								
							@@ -20,6 +20,13 @@ jobs:
 | 
			
		||||
      - name: Check out the repo
 | 
			
		||||
        uses: actions/checkout@v3
 | 
			
		||||
 | 
			
		||||
      - name: Check repository URL
 | 
			
		||||
        run: |
 | 
			
		||||
          REPO_URL=$(git config --get remote.origin.url)
 | 
			
		||||
          if [[ $REPO_URL == *"pro" ]]; then
 | 
			
		||||
            exit 0
 | 
			
		||||
          fi        
 | 
			
		||||
 | 
			
		||||
      - name: Save version info
 | 
			
		||||
        run: |
 | 
			
		||||
          git describe --tags > VERSION 
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										7
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/workflows/docker-image-arm64.yml
									
									
									
									
										vendored
									
									
								
							@@ -21,6 +21,13 @@ jobs:
 | 
			
		||||
      - name: Check out the repo
 | 
			
		||||
        uses: actions/checkout@v3
 | 
			
		||||
 | 
			
		||||
      - name: Check repository URL
 | 
			
		||||
        run: |
 | 
			
		||||
          REPO_URL=$(git config --get remote.origin.url)
 | 
			
		||||
          if [[ $REPO_URL == *"pro" ]]; then
 | 
			
		||||
            exit 0
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
      - name: Save version info
 | 
			
		||||
        run: |
 | 
			
		||||
          git describe --tags > VERSION 
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/linux-release.yml
									
									
									
									
										vendored
									
									
								
							@@ -20,6 +20,12 @@ jobs:
 | 
			
		||||
        uses: actions/checkout@v3
 | 
			
		||||
        with:
 | 
			
		||||
          fetch-depth: 0
 | 
			
		||||
      - name: Check repository URL
 | 
			
		||||
        run: |
 | 
			
		||||
          REPO_URL=$(git config --get remote.origin.url)
 | 
			
		||||
          if [[ $REPO_URL == *"pro" ]]; then
 | 
			
		||||
            exit 0
 | 
			
		||||
          fi
 | 
			
		||||
      - uses: actions/setup-node@v3
 | 
			
		||||
        with:
 | 
			
		||||
          node-version: 16
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/macos-release.yml
									
									
									
									
										vendored
									
									
								
							@@ -20,6 +20,12 @@ jobs:
 | 
			
		||||
        uses: actions/checkout@v3
 | 
			
		||||
        with:
 | 
			
		||||
          fetch-depth: 0
 | 
			
		||||
      - name: Check repository URL
 | 
			
		||||
        run: |
 | 
			
		||||
          REPO_URL=$(git config --get remote.origin.url)
 | 
			
		||||
          if [[ $REPO_URL == *"pro" ]]; then
 | 
			
		||||
            exit 0
 | 
			
		||||
          fi
 | 
			
		||||
      - uses: actions/setup-node@v3
 | 
			
		||||
        with:
 | 
			
		||||
          node-version: 16
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/windows-release.yml
									
									
									
									
										vendored
									
									
								
							@@ -23,6 +23,12 @@ jobs:
 | 
			
		||||
        uses: actions/checkout@v3
 | 
			
		||||
        with:
 | 
			
		||||
          fetch-depth: 0
 | 
			
		||||
      - name: Check repository URL
 | 
			
		||||
        run: |
 | 
			
		||||
          REPO_URL=$(git config --get remote.origin.url)
 | 
			
		||||
          if [[ $REPO_URL == *"pro" ]]; then
 | 
			
		||||
            exit 0
 | 
			
		||||
          fi
 | 
			
		||||
      - uses: actions/setup-node@v3
 | 
			
		||||
        with:
 | 
			
		||||
          node-version: 16
 | 
			
		||||
 
 | 
			
		||||
@@ -79,6 +79,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
   + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP)
 | 
			
		||||
   + [x] [MINIMAX](https://api.minimax.chat/)
 | 
			
		||||
   + [x] [Groq](https://wow.groq.com/)
 | 
			
		||||
   + [x] [Ollama](https://github.com/ollama/ollama)
 | 
			
		||||
   + [x] [零一万物](https://platform.lingyiwanwu.com/)
 | 
			
		||||
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
 | 
			
		||||
3. 支持通过**负载均衡**的方式访问多个渠道。
 | 
			
		||||
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
package config
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/env"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync"
 | 
			
		||||
@@ -76,14 +76,14 @@ var MessagePusherToken = ""
 | 
			
		||||
var TurnstileSiteKey = ""
 | 
			
		||||
var TurnstileSecretKey = ""
 | 
			
		||||
 | 
			
		||||
var QuotaForNewUser = 0
 | 
			
		||||
var QuotaForInviter = 0
 | 
			
		||||
var QuotaForInvitee = 0
 | 
			
		||||
var QuotaForNewUser int64 = 0
 | 
			
		||||
var QuotaForInviter int64 = 0
 | 
			
		||||
var QuotaForInvitee int64 = 0
 | 
			
		||||
var ChannelDisableThreshold = 5.0
 | 
			
		||||
var AutomaticDisableChannelEnabled = false
 | 
			
		||||
var AutomaticEnableChannelEnabled = false
 | 
			
		||||
var QuotaRemindThreshold = 1000
 | 
			
		||||
var PreConsumedQuota = 500
 | 
			
		||||
var QuotaRemindThreshold int64 = 1000
 | 
			
		||||
var PreConsumedQuota int64 = 500
 | 
			
		||||
var ApproximateTokenEnabled = false
 | 
			
		||||
var RetryTimes = 0
 | 
			
		||||
 | 
			
		||||
@@ -94,16 +94,16 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
 | 
			
		||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
 | 
			
		||||
var RequestInterval = time.Duration(requestInterval) * time.Second
 | 
			
		||||
 | 
			
		||||
var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
 | 
			
		||||
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
 | 
			
		||||
 | 
			
		||||
var BatchUpdateEnabled = false
 | 
			
		||||
var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
 | 
			
		||||
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
 | 
			
		||||
 | 
			
		||||
var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
 | 
			
		||||
var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
 | 
			
		||||
 | 
			
		||||
var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
 | 
			
		||||
var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
 | 
			
		||||
 | 
			
		||||
var Theme = helper.GetOrDefaultEnvString("THEME", "default")
 | 
			
		||||
var Theme = env.String("THEME", "default")
 | 
			
		||||
var ValidThemes = map[string]bool{
 | 
			
		||||
	"default": true,
 | 
			
		||||
	"berry":   true,
 | 
			
		||||
@@ -112,10 +112,10 @@ var ValidThemes = map[string]bool{
 | 
			
		||||
// All duration's unit is seconds
 | 
			
		||||
// Shouldn't larger then RateLimitKeyExpirationDuration
 | 
			
		||||
var (
 | 
			
		||||
	GlobalApiRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
 | 
			
		||||
	GlobalApiRateLimitNum            = env.Int("GLOBAL_API_RATE_LIMIT", 180)
 | 
			
		||||
	GlobalApiRateLimitDuration int64 = 3 * 60
 | 
			
		||||
 | 
			
		||||
	GlobalWebRateLimitNum            = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
 | 
			
		||||
	GlobalWebRateLimitNum            = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
 | 
			
		||||
	GlobalWebRateLimitDuration int64 = 3 * 60
 | 
			
		||||
 | 
			
		||||
	UploadRateLimitNum            = 10
 | 
			
		||||
@@ -130,8 +130,8 @@ var (
 | 
			
		||||
 | 
			
		||||
var RateLimitKeyExpirationDuration = 20 * time.Minute
 | 
			
		||||
 | 
			
		||||
var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false)
 | 
			
		||||
var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10)
 | 
			
		||||
var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
 | 
			
		||||
var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024)
 | 
			
		||||
var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128)
 | 
			
		||||
var EnableMetric = env.Bool("ENABLE_METRIC", false)
 | 
			
		||||
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
 | 
			
		||||
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
 | 
			
		||||
var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
 | 
			
		||||
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
 | 
			
		||||
 
 | 
			
		||||
@@ -69,6 +69,8 @@ const (
 | 
			
		||||
	ChannelTypeMinimax
 | 
			
		||||
	ChannelTypeMistral
 | 
			
		||||
	ChannelTypeGroq
 | 
			
		||||
	ChannelTypeOllama
 | 
			
		||||
	ChannelTypeLingYiWanWu
 | 
			
		||||
 | 
			
		||||
	ChannelTypeDummy
 | 
			
		||||
)
 | 
			
		||||
@@ -104,6 +106,8 @@ var ChannelBaseURLs = []string{
 | 
			
		||||
	"https://api.minimax.chat",                  // 27
 | 
			
		||||
	"https://api.mistral.ai",                    // 28
 | 
			
		||||
	"https://api.groq.com/openai",               // 29
 | 
			
		||||
	"http://localhost:11434",                    // 30
 | 
			
		||||
	"https://api.lingyiwanwu.com",               // 31
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,12 @@
 | 
			
		||||
package common
 | 
			
		||||
 | 
			
		||||
import "github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/env"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var UsingSQLite = false
 | 
			
		||||
var UsingPostgreSQL = false
 | 
			
		||||
var UsingMySQL = false
 | 
			
		||||
 | 
			
		||||
var SQLitePath = "one-api.db"
 | 
			
		||||
var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000)
 | 
			
		||||
var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										42
									
								
								common/env/helper.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								common/env/helper.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,42 @@
 | 
			
		||||
package env
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Bool(env string, defaultValue bool) bool {
 | 
			
		||||
	if env == "" || os.Getenv(env) == "" {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	return os.Getenv(env) == "true"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Int(env string, defaultValue int) int {
 | 
			
		||||
	if env == "" || os.Getenv(env) == "" {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	num, err := strconv.Atoi(os.Getenv(env))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	return num
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Float64(env string, defaultValue float64) float64 {
 | 
			
		||||
	if env == "" || os.Getenv(env) == "" {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	num, err := strconv.ParseFloat(os.Getenv(env), 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	return num
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func String(env string, defaultValue string) string {
 | 
			
		||||
	if env == "" || os.Getenv(env) == "" {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	return os.Getenv(env)
 | 
			
		||||
}
 | 
			
		||||
@@ -3,12 +3,10 @@ package helper
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"log"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -187,6 +185,10 @@ func GetTimeString() string {
 | 
			
		||||
	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GenRequestID() string {
 | 
			
		||||
	return GetTimeString() + GetRandomNumberString(8)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Max(a int, b int) int {
 | 
			
		||||
	if a >= b {
 | 
			
		||||
		return a
 | 
			
		||||
@@ -195,44 +197,6 @@ func Max(a int, b int) int {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetOrDefaultEnvBool(env string, defaultValue bool) bool {
 | 
			
		||||
	if env == "" || os.Getenv(env) == "" {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	return os.Getenv(env) == "true"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetOrDefaultEnvInt(env string, defaultValue int) int {
 | 
			
		||||
	if env == "" || os.Getenv(env) == "" {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	num, err := strconv.Atoi(os.Getenv(env))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	return num
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 {
 | 
			
		||||
	if env == "" || os.Getenv(env) == "" {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	num, err := strconv.ParseFloat(os.Getenv(env), 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue))
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	return num
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetOrDefaultEnvString(env string, defaultValue string) string {
 | 
			
		||||
	if env == "" || os.Getenv(env) == "" {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	return os.Getenv(env)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func AssignOrDefault(value string, defaultValue string) string {
 | 
			
		||||
	if len(value) != 0 {
 | 
			
		||||
		return value
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,8 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
@@ -54,7 +56,9 @@ func SysError(s string) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Debug(ctx context.Context, msg string) {
 | 
			
		||||
	logHelper(ctx, loggerDEBUG, msg)
 | 
			
		||||
	if config.DebugEnabled {
 | 
			
		||||
		logHelper(ctx, loggerDEBUG, msg)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Info(ctx context.Context, msg string) {
 | 
			
		||||
@@ -91,6 +95,9 @@ func logHelper(ctx context.Context, level string, msg string) {
 | 
			
		||||
		writer = gin.DefaultWriter
 | 
			
		||||
	}
 | 
			
		||||
	id := ctx.Value(RequestIdKey)
 | 
			
		||||
	if id == nil {
 | 
			
		||||
		id = helper.GenRequestID()
 | 
			
		||||
	}
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
 | 
			
		||||
	if !setupLogWorking {
 | 
			
		||||
 
 | 
			
		||||
@@ -69,7 +69,7 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"claude-instant-1.2":       0.8 / 1000 * USD,
 | 
			
		||||
	"claude-2.0":               8.0 / 1000 * USD,
 | 
			
		||||
	"claude-2.1":               8.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-haiku-20240229":  0.25 / 1000 * USD,
 | 
			
		||||
	"claude-3-haiku-20240307":  0.25 / 1000 * USD,
 | 
			
		||||
	"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
 | 
			
		||||
	"claude-3-opus-20240229":   15.0 / 1000 * USD,
 | 
			
		||||
	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
 | 
			
		||||
@@ -78,6 +78,9 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"ERNIE-Bot-4":       0.12 * RMB, // ¥0.12 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-8k":      0.024 * RMB,
 | 
			
		||||
	"Embedding-V1":      0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"bge-large-zh":      0.002 * RMB,
 | 
			
		||||
	"bge-large-en":      0.002 * RMB,
 | 
			
		||||
	"bge-large-8k":      0.002 * RMB,
 | 
			
		||||
	"PaLM-2":            1,
 | 
			
		||||
	"gemini-pro":        1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 | 
			
		||||
	"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 | 
			
		||||
@@ -130,6 +133,10 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"llama2-7b-2048":     0.1 / 1000 * USD,
 | 
			
		||||
	"mixtral-8x7b-32768": 0.27 / 1000 * USD,
 | 
			
		||||
	"gemma-7b-it":        0.1 / 1000 * USD,
 | 
			
		||||
	// https://platform.lingyiwanwu.com/docs#-计费单元
 | 
			
		||||
	"yi-34b-chat-0205": 2.5 / 1000000 * RMB,
 | 
			
		||||
	"yi-34b-chat-200k": 12.0 / 1000000 * RMB,
 | 
			
		||||
	"yi-vl-plus":       6.0 / 1000000 * RMB,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var CompletionRatio = map[string]float64{}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,7 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func LogQuota(quota int) string {
 | 
			
		||||
func LogQuota(quota int64) string {
 | 
			
		||||
	if config.DisplayInCurrencyEnabled {
 | 
			
		||||
		return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -8,8 +8,8 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetSubscription(c *gin.Context) {
 | 
			
		||||
	var remainQuota int
 | 
			
		||||
	var usedQuota int
 | 
			
		||||
	var remainQuota int64
 | 
			
		||||
	var usedQuota int64
 | 
			
		||||
	var err error
 | 
			
		||||
	var token *model.Token
 | 
			
		||||
	var expiredTime int64
 | 
			
		||||
@@ -60,7 +60,7 @@ func GetSubscription(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUsage(c *gin.Context) {
 | 
			
		||||
	var quota int
 | 
			
		||||
	var quota int64
 | 
			
		||||
	var err error
 | 
			
		||||
	var token *model.Token
 | 
			
		||||
	if config.DisplayTokenStatEnabled {
 | 
			
		||||
 
 | 
			
		||||
@@ -30,7 +30,7 @@ import (
 | 
			
		||||
 | 
			
		||||
func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
 | 
			
		||||
	testRequest := &relaymodel.GeneralOpenAIRequest{
 | 
			
		||||
		MaxTokens: 1,
 | 
			
		||||
		MaxTokens: 2,
 | 
			
		||||
		Stream:    false,
 | 
			
		||||
		Model:     "gpt-3.5-turbo",
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,7 @@ version: '3.4'
 | 
			
		||||
 | 
			
		||||
services:
 | 
			
		||||
  one-api:
 | 
			
		||||
    image: justsong/one-api:latest
 | 
			
		||||
    image: "${REGISTRY:-docker.io}/justsong/one-api:latest"
 | 
			
		||||
    container_name: one-api
 | 
			
		||||
    restart: always
 | 
			
		||||
    command: --log-dir /app/logs
 | 
			
		||||
@@ -29,12 +29,12 @@ services:
 | 
			
		||||
      retries: 3
 | 
			
		||||
 | 
			
		||||
  redis:
 | 
			
		||||
    image: redis:latest
 | 
			
		||||
    image: "${REGISTRY:-docker.io}/redis:latest"
 | 
			
		||||
    container_name: redis
 | 
			
		||||
    restart: always
 | 
			
		||||
 | 
			
		||||
  db:
 | 
			
		||||
    image: mysql:8.2.0
 | 
			
		||||
    image: "${REGISTRY:-docker.io}/mysql:8.2.0"
 | 
			
		||||
    restart: always
 | 
			
		||||
    container_name: mysql
 | 
			
		||||
    volumes:
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@@ -60,6 +60,6 @@ require (
 | 
			
		||||
	golang.org/x/net v0.17.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.15.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.14.0 // indirect
 | 
			
		||||
	google.golang.org/protobuf v1.30.0 // indirect
 | 
			
		||||
	google.golang.org/protobuf v1.33.0 // indirect
 | 
			
		||||
	gopkg.in/yaml.v3 v3.0.1 // indirect
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							@@ -177,8 +177,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IV
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 | 
			
		||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 | 
			
		||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
 | 
			
		||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
 | 
			
		||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
 | 
			
		||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
 | 
			
		||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
 | 
			
		||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 | 
			
		||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 | 
			
		||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										16
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								main.go
									
									
									
									
									
								
							@@ -30,11 +30,25 @@ func main() {
 | 
			
		||||
	if config.DebugEnabled {
 | 
			
		||||
		logger.SysLog("running in debug mode")
 | 
			
		||||
	}
 | 
			
		||||
	var err error
 | 
			
		||||
	// Initialize SQL Database
 | 
			
		||||
	err := model.InitDB()
 | 
			
		||||
	model.DB, err = model.InitDB("SQL_DSN")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.FatalLog("failed to initialize database: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	if os.Getenv("LOG_SQL_DSN") != "" {
 | 
			
		||||
		logger.SysLog("using secondary database for table logs")
 | 
			
		||||
		model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.FatalLog("failed to initialize secondary database: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		model.LOG_DB = model.DB
 | 
			
		||||
	}
 | 
			
		||||
	err = model.CreateRootAccountIfNeed()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.FatalLog("database init error: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		err := model.CloseDB()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ import (
 | 
			
		||||
 | 
			
		||||
func RequestId() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		id := helper.GetTimeString() + helper.GetRandomNumberString(8)
 | 
			
		||||
		id := helper.GenRequestID()
 | 
			
		||||
		c.Set(logger.RequestIdKey, id)
 | 
			
		||||
		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
 | 
			
		||||
		c.Request = c.Request.WithContext(ctx)
 | 
			
		||||
 
 | 
			
		||||
@@ -71,7 +71,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
 | 
			
		||||
	return group, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int, err error) {
 | 
			
		||||
func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) {
 | 
			
		||||
	quota, err = GetUserQuota(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
@@ -83,7 +83,7 @@ func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int, err error)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CacheGetUserQuota(ctx context.Context, id int) (quota int, err error) {
 | 
			
		||||
func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) {
 | 
			
		||||
	if !common.RedisEnabled {
 | 
			
		||||
		return GetUserQuota(id)
 | 
			
		||||
	}
 | 
			
		||||
@@ -91,7 +91,7 @@ func CacheGetUserQuota(ctx context.Context, id int) (quota int, err error) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fetchAndUpdateUserQuota(ctx, id)
 | 
			
		||||
	}
 | 
			
		||||
	quota, err = strconv.Atoi(quotaString)
 | 
			
		||||
	quota, err = strconv.ParseInt(quotaString, 10, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, nil
 | 
			
		||||
	}
 | 
			
		||||
@@ -114,7 +114,7 @@ func CacheUpdateUserQuota(ctx context.Context, id int) error {
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CacheDecreaseUserQuota(id int, quota int) error {
 | 
			
		||||
func CacheDecreaseUserQuota(id int, quota int64) error {
 | 
			
		||||
	if !common.RedisEnabled {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -178,7 +178,7 @@ func UpdateChannelStatusById(id int, status int) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateChannelUsedQuota(id int, quota int) {
 | 
			
		||||
func UpdateChannelUsedQuota(id int, quota int64) {
 | 
			
		||||
	if config.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
 | 
			
		||||
		return
 | 
			
		||||
@@ -186,7 +186,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
 | 
			
		||||
	updateChannelUsedQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelUsedQuota(id int, quota int) {
 | 
			
		||||
func updateChannelUsedQuota(id int, quota int64) {
 | 
			
		||||
	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("failed to update channel used quota: " + err.Error())
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										30
									
								
								model/log.go
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								model/log.go
									
									
									
									
									
								
							@@ -45,13 +45,13 @@ func RecordLog(userId int, logType int, content string) {
 | 
			
		||||
		Type:      logType,
 | 
			
		||||
		Content:   content,
 | 
			
		||||
	}
 | 
			
		||||
	err := DB.Create(log).Error
 | 
			
		||||
	err := LOG_DB.Create(log).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("failed to record log: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
 | 
			
		||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
 | 
			
		||||
	logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 | 
			
		||||
	if !config.LogConsumeEnabled {
 | 
			
		||||
		return
 | 
			
		||||
@@ -66,10 +66,10 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
 | 
			
		||||
		CompletionTokens: completionTokens,
 | 
			
		||||
		TokenName:        tokenName,
 | 
			
		||||
		ModelName:        modelName,
 | 
			
		||||
		Quota:            quota,
 | 
			
		||||
		Quota:            int(quota),
 | 
			
		||||
		ChannelId:        channelId,
 | 
			
		||||
	}
 | 
			
		||||
	err := DB.Create(log).Error
 | 
			
		||||
	err := LOG_DB.Create(log).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(ctx, "failed to record log: "+err.Error())
 | 
			
		||||
	}
 | 
			
		||||
@@ -78,9 +78,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
 | 
			
		||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
 | 
			
		||||
	var tx *gorm.DB
 | 
			
		||||
	if logType == LogTypeUnknown {
 | 
			
		||||
		tx = DB
 | 
			
		||||
		tx = LOG_DB
 | 
			
		||||
	} else {
 | 
			
		||||
		tx = DB.Where("type = ?", logType)
 | 
			
		||||
		tx = LOG_DB.Where("type = ?", logType)
 | 
			
		||||
	}
 | 
			
		||||
	if modelName != "" {
 | 
			
		||||
		tx = tx.Where("model_name = ?", modelName)
 | 
			
		||||
@@ -107,9 +107,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 | 
			
		||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
 | 
			
		||||
	var tx *gorm.DB
 | 
			
		||||
	if logType == LogTypeUnknown {
 | 
			
		||||
		tx = DB.Where("user_id = ?", userId)
 | 
			
		||||
		tx = LOG_DB.Where("user_id = ?", userId)
 | 
			
		||||
	} else {
 | 
			
		||||
		tx = DB.Where("user_id = ? and type = ?", userId, logType)
 | 
			
		||||
		tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
 | 
			
		||||
	}
 | 
			
		||||
	if modelName != "" {
 | 
			
		||||
		tx = tx.Where("model_name = ?", modelName)
 | 
			
		||||
@@ -128,17 +128,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchAllLogs(keyword string) (logs []*Log, err error) {
 | 
			
		||||
	err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
 | 
			
		||||
	err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
 | 
			
		||||
	return logs, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
 | 
			
		||||
	err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
 | 
			
		||||
	err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
 | 
			
		||||
	return logs, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
 | 
			
		||||
	tx := DB.Table("logs").Select("ifnull(sum(quota),0)")
 | 
			
		||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
 | 
			
		||||
	tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
 | 
			
		||||
	if username != "" {
 | 
			
		||||
		tx = tx.Where("username = ?", username)
 | 
			
		||||
	}
 | 
			
		||||
@@ -162,7 +162,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
 | 
			
		||||
	tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
 | 
			
		||||
	tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
 | 
			
		||||
	if username != "" {
 | 
			
		||||
		tx = tx.Where("username = ?", username)
 | 
			
		||||
	}
 | 
			
		||||
@@ -183,7 +183,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteOldLog(targetTimestamp int64) (int64, error) {
 | 
			
		||||
	result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
 | 
			
		||||
	result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
 | 
			
		||||
	return result.RowsAffected, result.Error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -207,7 +207,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis
 | 
			
		||||
		groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = DB.Raw(`
 | 
			
		||||
	err = LOG_DB.Raw(`
 | 
			
		||||
		SELECT `+groupSelect+`,
 | 
			
		||||
		model_name, count(1) as request_count,
 | 
			
		||||
		sum(quota) as quota,
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/env"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"gorm.io/driver/mysql"
 | 
			
		||||
@@ -16,8 +17,9 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var DB *gorm.DB
 | 
			
		||||
var LOG_DB *gorm.DB
 | 
			
		||||
 | 
			
		||||
func createRootAccountIfNeed() error {
 | 
			
		||||
func CreateRootAccountIfNeed() error {
 | 
			
		||||
	var user User
 | 
			
		||||
	//if user.Status != util.UserStatusEnabled {
 | 
			
		||||
	if err := DB.First(&user).Error; err != nil {
 | 
			
		||||
@@ -40,9 +42,9 @@ func createRootAccountIfNeed() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func chooseDB() (*gorm.DB, error) {
 | 
			
		||||
	if os.Getenv("SQL_DSN") != "" {
 | 
			
		||||
		dsn := os.Getenv("SQL_DSN")
 | 
			
		||||
func chooseDB(envName string) (*gorm.DB, error) {
 | 
			
		||||
	if os.Getenv(envName) != "" {
 | 
			
		||||
		dsn := os.Getenv(envName)
 | 
			
		||||
		if strings.HasPrefix(dsn, "postgres://") {
 | 
			
		||||
			// Use PostgreSQL
 | 
			
		||||
			logger.SysLog("using PostgreSQL as database")
 | 
			
		||||
@@ -70,23 +72,22 @@ func chooseDB() (*gorm.DB, error) {
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitDB() (err error) {
 | 
			
		||||
	db, err := chooseDB()
 | 
			
		||||
func InitDB(envName string) (db *gorm.DB, err error) {
 | 
			
		||||
	db, err = chooseDB(envName)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		if config.DebugSQLEnabled {
 | 
			
		||||
			db = db.Debug()
 | 
			
		||||
		}
 | 
			
		||||
		DB = db
 | 
			
		||||
		sqlDB, err := DB.DB()
 | 
			
		||||
		sqlDB, err := db.DB()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
 | 
			
		||||
		sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
 | 
			
		||||
		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
 | 
			
		||||
		sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
 | 
			
		||||
		sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
 | 
			
		||||
		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
 | 
			
		||||
 | 
			
		||||
		if !config.IsMasterNode {
 | 
			
		||||
			return nil
 | 
			
		||||
			return db, err
 | 
			
		||||
		}
 | 
			
		||||
		if common.UsingMySQL {
 | 
			
		||||
			_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
 | 
			
		||||
@@ -94,46 +95,55 @@ func InitDB() (err error) {
 | 
			
		||||
		logger.SysLog("database migration started")
 | 
			
		||||
		err = db.AutoMigrate(&Channel{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Token{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&User{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Option{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Redemption{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Ability{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = db.AutoMigrate(&Log{})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		logger.SysLog("database migrated")
 | 
			
		||||
		err = createRootAccountIfNeed()
 | 
			
		||||
		return err
 | 
			
		||||
		return db, err
 | 
			
		||||
	} else {
 | 
			
		||||
		logger.FatalLog(err)
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
	return db, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CloseDB() error {
 | 
			
		||||
	sqlDB, err := DB.DB()
 | 
			
		||||
func closeDB(db *gorm.DB) error {
 | 
			
		||||
	sqlDB, err := db.DB()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	err = sqlDB.Close()
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CloseDB() error {
 | 
			
		||||
	if LOG_DB != DB {
 | 
			
		||||
		err := closeDB(LOG_DB)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return closeDB(DB)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -61,11 +61,11 @@ func InitOptionMap() {
 | 
			
		||||
	config.OptionMap["MessagePusherToken"] = ""
 | 
			
		||||
	config.OptionMap["TurnstileSiteKey"] = ""
 | 
			
		||||
	config.OptionMap["TurnstileSecretKey"] = ""
 | 
			
		||||
	config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser)
 | 
			
		||||
	config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter)
 | 
			
		||||
	config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee)
 | 
			
		||||
	config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold)
 | 
			
		||||
	config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota)
 | 
			
		||||
	config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10)
 | 
			
		||||
	config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10)
 | 
			
		||||
	config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
 | 
			
		||||
	config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
 | 
			
		||||
	config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
 | 
			
		||||
	config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
 | 
			
		||||
	config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
 | 
			
		||||
	config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
 | 
			
		||||
@@ -193,15 +193,15 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
	case "TurnstileSecretKey":
 | 
			
		||||
		config.TurnstileSecretKey = value
 | 
			
		||||
	case "QuotaForNewUser":
 | 
			
		||||
		config.QuotaForNewUser, _ = strconv.Atoi(value)
 | 
			
		||||
		config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64)
 | 
			
		||||
	case "QuotaForInviter":
 | 
			
		||||
		config.QuotaForInviter, _ = strconv.Atoi(value)
 | 
			
		||||
		config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64)
 | 
			
		||||
	case "QuotaForInvitee":
 | 
			
		||||
		config.QuotaForInvitee, _ = strconv.Atoi(value)
 | 
			
		||||
		config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64)
 | 
			
		||||
	case "QuotaRemindThreshold":
 | 
			
		||||
		config.QuotaRemindThreshold, _ = strconv.Atoi(value)
 | 
			
		||||
		config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64)
 | 
			
		||||
	case "PreConsumedQuota":
 | 
			
		||||
		config.PreConsumedQuota, _ = strconv.Atoi(value)
 | 
			
		||||
		config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64)
 | 
			
		||||
	case "RetryTimes":
 | 
			
		||||
		config.RetryTimes, _ = strconv.Atoi(value)
 | 
			
		||||
	case "ModelRatio":
 | 
			
		||||
 
 | 
			
		||||
@@ -14,7 +14,7 @@ type Redemption struct {
 | 
			
		||||
	Key          string `json:"key" gorm:"type:char(32);uniqueIndex"`
 | 
			
		||||
	Status       int    `json:"status" gorm:"default:1"`
 | 
			
		||||
	Name         string `json:"name" gorm:"index"`
 | 
			
		||||
	Quota        int    `json:"quota" gorm:"default:100"`
 | 
			
		||||
	Quota        int64  `json:"quota" gorm:"default:100"`
 | 
			
		||||
	CreatedTime  int64  `json:"created_time" gorm:"bigint"`
 | 
			
		||||
	RedeemedTime int64  `json:"redeemed_time" gorm:"bigint"`
 | 
			
		||||
	Count        int    `json:"count" gorm:"-:all"` // only for api request
 | 
			
		||||
@@ -42,7 +42,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
 | 
			
		||||
	return &redemption, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Redeem(key string, userId int) (quota int, err error) {
 | 
			
		||||
func Redeem(key string, userId int) (quota int64, err error) {
 | 
			
		||||
	if key == "" {
 | 
			
		||||
		return 0, errors.New("未提供兑换码")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -20,9 +20,9 @@ type Token struct {
 | 
			
		||||
	CreatedTime    int64  `json:"created_time" gorm:"bigint"`
 | 
			
		||||
	AccessedTime   int64  `json:"accessed_time" gorm:"bigint"`
 | 
			
		||||
	ExpiredTime    int64  `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
 | 
			
		||||
	RemainQuota    int    `json:"remain_quota" gorm:"default:0"`
 | 
			
		||||
	RemainQuota    int64  `json:"remain_quota" gorm:"default:0"`
 | 
			
		||||
	UnlimitedQuota bool   `json:"unlimited_quota" gorm:"default:false"`
 | 
			
		||||
	UsedQuota      int    `json:"used_quota" gorm:"default:0"` // used quota
 | 
			
		||||
	UsedQuota      int64  `json:"used_quota" gorm:"default:0"` // used quota
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
 | 
			
		||||
@@ -138,7 +138,7 @@ func DeleteTokenById(id int, userId int) (err error) {
 | 
			
		||||
	return token.Delete()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IncreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
func IncreaseTokenQuota(id int, quota int64) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
@@ -149,7 +149,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	return increaseTokenQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func increaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
func increaseTokenQuota(id int, quota int64) (err error) {
 | 
			
		||||
	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 | 
			
		||||
		map[string]interface{}{
 | 
			
		||||
			"remain_quota":  gorm.Expr("remain_quota + ?", quota),
 | 
			
		||||
@@ -160,7 +160,7 @@ func increaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DecreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
func DecreaseTokenQuota(id int, quota int64) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
@@ -171,7 +171,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	return decreaseTokenQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func decreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
func decreaseTokenQuota(id int, quota int64) (err error) {
 | 
			
		||||
	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 | 
			
		||||
		map[string]interface{}{
 | 
			
		||||
			"remain_quota":  gorm.Expr("remain_quota - ?", quota),
 | 
			
		||||
@@ -182,7 +182,7 @@ func decreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
 | 
			
		||||
func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
@@ -232,7 +232,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
 | 
			
		||||
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
 | 
			
		||||
	token, err := GetTokenById(tokenId)
 | 
			
		||||
	if quota > 0 {
 | 
			
		||||
		err = DecreaseUserQuota(token.UserId, quota)
 | 
			
		||||
 
 | 
			
		||||
@@ -26,8 +26,8 @@ type User struct {
 | 
			
		||||
	WeChatId         string `json:"wechat_id" gorm:"column:wechat_id;index"`
 | 
			
		||||
	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database!
 | 
			
		||||
	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
 | 
			
		||||
	Quota            int    `json:"quota" gorm:"type:int;default:0"`
 | 
			
		||||
	UsedQuota        int    `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
 | 
			
		||||
	Quota            int64  `json:"quota" gorm:"type:int;default:0"`
 | 
			
		||||
	UsedQuota        int64  `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
 | 
			
		||||
	RequestCount     int    `json:"request_count" gorm:"type:int;default:0;"`               // request number
 | 
			
		||||
	Group            string `json:"group" gorm:"type:varchar(32);default:'default'"`
 | 
			
		||||
	AffCode          string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
 | 
			
		||||
@@ -274,12 +274,12 @@ func ValidateAccessToken(token string) (user *User) {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserQuota(id int) (quota int, err error) {
 | 
			
		||||
func GetUserQuota(id int) (quota int64, err error) {
 | 
			
		||||
	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
 | 
			
		||||
	return quota, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserUsedQuota(id int) (quota int, err error) {
 | 
			
		||||
func GetUserUsedQuota(id int) (quota int64, err error) {
 | 
			
		||||
	err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error
 | 
			
		||||
	return quota, err
 | 
			
		||||
}
 | 
			
		||||
@@ -299,7 +299,7 @@ func GetUserGroup(id int) (group string, err error) {
 | 
			
		||||
	return group, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IncreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
func IncreaseUserQuota(id int, quota int64) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
@@ -310,12 +310,12 @@ func IncreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
	return increaseUserQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func increaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
func increaseUserQuota(id int, quota int64) (err error) {
 | 
			
		||||
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DecreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
func DecreaseUserQuota(id int, quota int64) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
@@ -326,7 +326,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
	return decreaseUserQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func decreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
func decreaseUserQuota(id int, quota int64) (err error) {
 | 
			
		||||
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
@@ -336,7 +336,7 @@ func GetRootUserEmail() (email string) {
 | 
			
		||||
	return email
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
 | 
			
		||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) {
 | 
			
		||||
	if config.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
 | 
			
		||||
		addNewRecord(BatchUpdateTypeRequestCount, id, 1)
 | 
			
		||||
@@ -345,7 +345,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
 | 
			
		||||
	updateUserUsedQuotaAndRequestCount(id, quota, 1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
 | 
			
		||||
func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) {
 | 
			
		||||
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
 | 
			
		||||
		map[string]interface{}{
 | 
			
		||||
			"used_quota":    gorm.Expr("used_quota + ?", quota),
 | 
			
		||||
@@ -357,7 +357,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateUserUsedQuota(id int, quota int) {
 | 
			
		||||
func updateUserUsedQuota(id int, quota int64) {
 | 
			
		||||
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
 | 
			
		||||
		map[string]interface{}{
 | 
			
		||||
			"used_quota": gorm.Expr("used_quota + ?", quota),
 | 
			
		||||
 
 | 
			
		||||
@@ -16,12 +16,12 @@ const (
 | 
			
		||||
	BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var batchUpdateStores []map[int]int
 | 
			
		||||
var batchUpdateStores []map[int]int64
 | 
			
		||||
var batchUpdateLocks []sync.Mutex
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	for i := 0; i < BatchUpdateTypeCount; i++ {
 | 
			
		||||
		batchUpdateStores = append(batchUpdateStores, make(map[int]int))
 | 
			
		||||
		batchUpdateStores = append(batchUpdateStores, make(map[int]int64))
 | 
			
		||||
		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -35,7 +35,7 @@ func InitBatchUpdater() {
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func addNewRecord(type_ int, id int, value int) {
 | 
			
		||||
func addNewRecord(type_ int, id int, value int64) {
 | 
			
		||||
	batchUpdateLocks[type_].Lock()
 | 
			
		||||
	defer batchUpdateLocks[type_].Unlock()
 | 
			
		||||
	if _, ok := batchUpdateStores[type_][id]; !ok {
 | 
			
		||||
@@ -50,7 +50,7 @@ func batchUpdate() {
 | 
			
		||||
	for i := 0; i < BatchUpdateTypeCount; i++ {
 | 
			
		||||
		batchUpdateLocks[i].Lock()
 | 
			
		||||
		store := batchUpdateStores[i]
 | 
			
		||||
		batchUpdateStores[i] = make(map[int]int)
 | 
			
		||||
		batchUpdateStores[i] = make(map[int]int64)
 | 
			
		||||
		batchUpdateLocks[i].Unlock()
 | 
			
		||||
		// TODO: maybe we can combine updates with same key?
 | 
			
		||||
		for key, value := range store {
 | 
			
		||||
@@ -68,7 +68,7 @@ func batchUpdate() {
 | 
			
		||||
			case BatchUpdateTypeUsedQuota:
 | 
			
		||||
				updateUserUsedQuota(key, value)
 | 
			
		||||
			case BatchUpdateTypeRequestCount:
 | 
			
		||||
				updateUserRequestCount(key, value)
 | 
			
		||||
				updateUserRequestCount(key, int(value))
 | 
			
		||||
			case BatchUpdateTypeChannelUsedQuota:
 | 
			
		||||
				updateChannelUsedQuota(key, value)
 | 
			
		||||
			}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,6 +32,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
 | 
			
		||||
	channel.SetupCommonRequestHeader(c, req, meta)
 | 
			
		||||
	if meta.IsStream {
 | 
			
		||||
		req.Header.Set("Accept", "text/event-stream")
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
 | 
			
		||||
	if meta.IsStream {
 | 
			
		||||
		req.Header.Set("X-DashScope-SSE", "enable")
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,7 @@ package anthropic
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"claude-instant-1.2", "claude-2.0", "claude-2.1",
 | 
			
		||||
	"claude-3-haiku-20240229",
 | 
			
		||||
	"claude-3-haiku-20240307",
 | 
			
		||||
	"claude-3-sonnet-20240229",
 | 
			
		||||
	"claude-3-opus-20240229",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,14 +3,15 @@ package baidu
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/util"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
@@ -23,7 +24,13 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
 | 
			
		||||
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 | 
			
		||||
	// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
 | 
			
		||||
	suffix := "chat/"
 | 
			
		||||
	if strings.HasPrefix("Embedding", meta.ActualModelName) {
 | 
			
		||||
	if strings.HasPrefix(meta.ActualModelName, "Embedding") {
 | 
			
		||||
		suffix = "embeddings/"
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(meta.ActualModelName, "bge-large") {
 | 
			
		||||
		suffix = "embeddings/"
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(meta.ActualModelName, "tao-8k") {
 | 
			
		||||
		suffix = "embeddings/"
 | 
			
		||||
	}
 | 
			
		||||
	switch meta.ActualModelName {
 | 
			
		||||
@@ -45,6 +52,12 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 | 
			
		||||
		suffix += "bloomz_7b1"
 | 
			
		||||
	case "Embedding-V1":
 | 
			
		||||
		suffix += "embedding-v1"
 | 
			
		||||
	case "bge-large-zh":
 | 
			
		||||
		suffix += "bge_large_zh"
 | 
			
		||||
	case "bge-large-en":
 | 
			
		||||
		suffix += "bge_large_en"
 | 
			
		||||
	case "tao-8k":
 | 
			
		||||
		suffix += "tao_8k"
 | 
			
		||||
	default:
 | 
			
		||||
		suffix += meta.ActualModelName
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,4 +7,7 @@ var ModelList = []string{
 | 
			
		||||
	"ERNIE-Speed",
 | 
			
		||||
	"ERNIE-Bot-turbo",
 | 
			
		||||
	"Embedding-V1",
 | 
			
		||||
	"bge-large-zh",
 | 
			
		||||
	"bge-large-en",
 | 
			
		||||
	"tao-8k",
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,9 +32,16 @@ type Message struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatRequest struct {
 | 
			
		||||
	Messages []Message `json:"messages"`
 | 
			
		||||
	Stream   bool      `json:"stream"`
 | 
			
		||||
	UserId   string    `json:"user_id,omitempty"`
 | 
			
		||||
	Messages        []Message `json:"messages"`
 | 
			
		||||
	Temperature     float64   `json:"temperature,omitempty"`
 | 
			
		||||
	TopP            float64   `json:"top_p,omitempty"`
 | 
			
		||||
	PenaltyScore    float64   `json:"penalty_score,omitempty"`
 | 
			
		||||
	Stream          bool      `json:"stream,omitempty"`
 | 
			
		||||
	System          string    `json:"system,omitempty"`
 | 
			
		||||
	DisableSearch   bool      `json:"disable_search,omitempty"`
 | 
			
		||||
	EnableCitation  bool      `json:"enable_citation,omitempty"`
 | 
			
		||||
	MaxOutputTokens int       `json:"max_output_tokens,omitempty"`
 | 
			
		||||
	UserId          string    `json:"user_id,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Error struct {
 | 
			
		||||
@@ -45,28 +52,28 @@ type Error struct {
 | 
			
		||||
var baiduTokenStore sync.Map
 | 
			
		||||
 | 
			
		||||
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 | 
			
		||||
	messages := make([]Message, 0, len(request.Messages))
 | 
			
		||||
	baiduRequest := ChatRequest{
 | 
			
		||||
		Messages:        make([]Message, 0, len(request.Messages)),
 | 
			
		||||
		Temperature:     request.Temperature,
 | 
			
		||||
		TopP:            request.TopP,
 | 
			
		||||
		PenaltyScore:    request.FrequencyPenalty,
 | 
			
		||||
		Stream:          request.Stream,
 | 
			
		||||
		DisableSearch:   false,
 | 
			
		||||
		EnableCitation:  false,
 | 
			
		||||
		MaxOutputTokens: request.MaxTokens,
 | 
			
		||||
		UserId:          request.User,
 | 
			
		||||
	}
 | 
			
		||||
	for _, message := range request.Messages {
 | 
			
		||||
		if message.Role == "system" {
 | 
			
		||||
			messages = append(messages, Message{
 | 
			
		||||
				Role:    "user",
 | 
			
		||||
				Content: message.StringContent(),
 | 
			
		||||
			})
 | 
			
		||||
			messages = append(messages, Message{
 | 
			
		||||
				Role:    "assistant",
 | 
			
		||||
				Content: "Okay",
 | 
			
		||||
			})
 | 
			
		||||
			baiduRequest.System = message.StringContent()
 | 
			
		||||
		} else {
 | 
			
		||||
			messages = append(messages, Message{
 | 
			
		||||
			baiduRequest.Messages = append(baiduRequest.Messages, Message{
 | 
			
		||||
				Role:    message.Role,
 | 
			
		||||
				Content: message.StringContent(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return &ChatRequest{
 | 
			
		||||
		Messages: messages,
 | 
			
		||||
		Stream:   request.Stream,
 | 
			
		||||
	}
 | 
			
		||||
	return &baiduRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										9
									
								
								relay/channel/lingyiwanwu/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								relay/channel/lingyiwanwu/constants.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,9 @@
 | 
			
		||||
package lingyiwanwu
 | 
			
		||||
 | 
			
		||||
// https://platform.lingyiwanwu.com/docs
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"yi-34b-chat-0205",
 | 
			
		||||
	"yi-34b-chat-200k",
 | 
			
		||||
	"yi-vl-plus",
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										65
									
								
								relay/channel/ollama/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								relay/channel/ollama/adaptor.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,65 @@
 | 
			
		||||
package ollama
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/util"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Adaptor struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) Init(meta *util.RelayMeta) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 | 
			
		||||
	// https://github.com/ollama/ollama/blob/main/docs/api.md
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
 | 
			
		||||
	return fullRequestURL, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
 | 
			
		||||
	channel.SetupCommonRequestHeader(c, req, meta)
 | 
			
		||||
	req.Header.Set("Authorization", "Bearer "+meta.APIKey)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case constant.RelayModeEmbeddings:
 | 
			
		||||
		return nil, errors.New("not supported")
 | 
			
		||||
	default:
 | 
			
		||||
		return ConvertRequest(*request), nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
	return channel.DoRequestHelper(a, c, meta, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 | 
			
		||||
	if meta.IsStream {
 | 
			
		||||
		err, usage = StreamHandler(c, resp)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = Handler(c, resp)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetModelList() []string {
 | 
			
		||||
	return ModelList
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetChannelName() string {
 | 
			
		||||
	return "ollama"
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										5
									
								
								relay/channel/ollama/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								relay/channel/ollama/constants.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
			
		||||
package ollama
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"qwen:0.5b-chat",
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										178
									
								
								relay/channel/ollama/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								relay/channel/ollama/main.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,178 @@
 | 
			
		||||
package ollama
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 | 
			
		||||
	ollamaRequest := ChatRequest{
 | 
			
		||||
		Model: request.Model,
 | 
			
		||||
		Options: &Options{
 | 
			
		||||
			Seed:             int(request.Seed),
 | 
			
		||||
			Temperature:      request.Temperature,
 | 
			
		||||
			TopP:             request.TopP,
 | 
			
		||||
			FrequencyPenalty: request.FrequencyPenalty,
 | 
			
		||||
			PresencePenalty:  request.PresencePenalty,
 | 
			
		||||
		},
 | 
			
		||||
		Stream: request.Stream,
 | 
			
		||||
	}
 | 
			
		||||
	for _, message := range request.Messages {
 | 
			
		||||
		ollamaRequest.Messages = append(ollamaRequest.Messages, Message{
 | 
			
		||||
			Role:    message.Role,
 | 
			
		||||
			Content: message.StringContent(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return &ollamaRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
 | 
			
		||||
	choice := openai.TextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
		Message: model.Message{
 | 
			
		||||
			Role:    response.Message.Role,
 | 
			
		||||
			Content: response.Message.Content,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	if response.Done {
 | 
			
		||||
		choice.FinishReason = "stop"
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := openai.TextResponse{
 | 
			
		||||
		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
		Created: helper.GetTimestamp(),
 | 
			
		||||
		Choices: []openai.TextResponseChoice{choice},
 | 
			
		||||
		Usage: model.Usage{
 | 
			
		||||
			PromptTokens:     response.PromptEvalCount,
 | 
			
		||||
			CompletionTokens: response.EvalCount,
 | 
			
		||||
			TotalTokens:      response.PromptEvalCount + response.EvalCount,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	return &fullTextResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
 | 
			
		||||
	var choice openai.ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Role = ollamaResponse.Message.Role
 | 
			
		||||
	choice.Delta.Content = ollamaResponse.Message.Content
 | 
			
		||||
	if ollamaResponse.Done {
 | 
			
		||||
		choice.FinishReason = &constant.StopFinishReason
 | 
			
		||||
	}
 | 
			
		||||
	response := openai.ChatCompletionsStreamResponse{
 | 
			
		||||
		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
 | 
			
		||||
		Object:  "chat.completion.chunk",
 | 
			
		||||
		Created: helper.GetTimestamp(),
 | 
			
		||||
		Model:   ollamaResponse.Model,
 | 
			
		||||
		Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
	return &response
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	var usage model.Usage
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "}\n"); i >= 0 {
 | 
			
		||||
			return i + 2, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := strings.TrimPrefix(scanner.Text(), "}")
 | 
			
		||||
			dataChan <- data + "}"
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var ollamaResponse ChatResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &ollamaResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if ollamaResponse.EvalCount != 0 {
 | 
			
		||||
				usage.PromptTokens = ollamaResponse.PromptEvalCount
 | 
			
		||||
				usage.CompletionTokens = ollamaResponse.EvalCount
 | 
			
		||||
				usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseOllama2OpenAI(&ollamaResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	ctx := context.TODO()
 | 
			
		||||
	var ollamaResponse ChatResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	logger.Debugf(ctx, "ollama response: %s", string(responseBody))
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &ollamaResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if ollamaResponse.Error != "" {
 | 
			
		||||
		return &model.ErrorWithStatusCode{
 | 
			
		||||
			Error: model.Error{
 | 
			
		||||
				Message: ollamaResponse.Error,
 | 
			
		||||
				Type:    "ollama_error",
 | 
			
		||||
				Param:   "",
 | 
			
		||||
				Code:    "ollama_error",
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseOllama2OpenAI(&ollamaResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										37
									
								
								relay/channel/ollama/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								relay/channel/ollama/model.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,37 @@
 | 
			
		||||
package ollama
 | 
			
		||||
 | 
			
		||||
type Options struct {
 | 
			
		||||
	Seed             int     `json:"seed,omitempty"`
 | 
			
		||||
	Temperature      float64 `json:"temperature,omitempty"`
 | 
			
		||||
	TopK             int     `json:"top_k,omitempty"`
 | 
			
		||||
	TopP             float64 `json:"top_p,omitempty"`
 | 
			
		||||
	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
 | 
			
		||||
	PresencePenalty  float64 `json:"presence_penalty,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
	Role    string   `json:"role,omitempty"`
 | 
			
		||||
	Content string   `json:"content,omitempty"`
 | 
			
		||||
	Images  []string `json:"images,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatRequest struct {
 | 
			
		||||
	Model    string    `json:"model,omitempty"`
 | 
			
		||||
	Messages []Message `json:"messages,omitempty"`
 | 
			
		||||
	Stream   bool      `json:"stream"`
 | 
			
		||||
	Options  *Options  `json:"options,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatResponse struct {
 | 
			
		||||
	Model           string  `json:"model,omitempty"`
 | 
			
		||||
	CreatedAt       string  `json:"created_at,omitempty"`
 | 
			
		||||
	Message         Message `json:"message,omitempty"`
 | 
			
		||||
	Response        string  `json:"response,omitempty"` // for stream response
 | 
			
		||||
	Done            bool    `json:"done,omitempty"`
 | 
			
		||||
	TotalDuration   int     `json:"total_duration,omitempty"`
 | 
			
		||||
	LoadDuration    int     `json:"load_duration,omitempty"`
 | 
			
		||||
	PromptEvalCount int     `json:"prompt_eval_count,omitempty"`
 | 
			
		||||
	EvalCount       int     `json:"eval_count,omitempty"`
 | 
			
		||||
	EvalDuration    int     `json:"eval_duration,omitempty"`
 | 
			
		||||
	Error           string  `json:"error,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/ai360"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/baichuan"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/groq"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/lingyiwanwu"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/minimax"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/mistral"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/moonshot"
 | 
			
		||||
@@ -18,6 +19,7 @@ var CompatibleChannels = []int{
 | 
			
		||||
	common.ChannelTypeMinimax,
 | 
			
		||||
	common.ChannelTypeMistral,
 | 
			
		||||
	common.ChannelTypeGroq,
 | 
			
		||||
	common.ChannelTypeLingYiWanWu,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetCompatibleChannelMeta(channelType int) (string, []string) {
 | 
			
		||||
@@ -36,6 +38,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
 | 
			
		||||
		return "mistralai", mistral.ModelList
 | 
			
		||||
	case common.ChannelTypeGroq:
 | 
			
		||||
		return "groq", groq.ModelList
 | 
			
		||||
	case common.ChannelTypeLingYiWanWu:
 | 
			
		||||
		return "lingyiwanwu", lingyiwanwu.ModelList
 | 
			
		||||
	default:
 | 
			
		||||
		return "openai", ModelList
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -15,6 +15,7 @@ const (
 | 
			
		||||
	APITypeAIProxyLibrary
 | 
			
		||||
	APITypeTencent
 | 
			
		||||
	APITypeGemini
 | 
			
		||||
	APITypeOllama
 | 
			
		||||
 | 
			
		||||
	APITypeDummy // this one is only for count, do not add any channel after this
 | 
			
		||||
)
 | 
			
		||||
@@ -40,6 +41,8 @@ func ChannelType2APIType(channelType int) int {
 | 
			
		||||
		apiType = APITypeTencent
 | 
			
		||||
	case common.ChannelTypeGemini:
 | 
			
		||||
		apiType = APITypeGemini
 | 
			
		||||
	case common.ChannelTypeOllama:
 | 
			
		||||
		apiType = APITypeOllama
 | 
			
		||||
	}
 | 
			
		||||
	return apiType
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -50,14 +50,14 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
	modelRatio := common.GetModelRatio(audioModel)
 | 
			
		||||
	groupRatio := common.GetGroupRatio(group)
 | 
			
		||||
	ratio := modelRatio * groupRatio
 | 
			
		||||
	var quota int
 | 
			
		||||
	var preConsumedQuota int
 | 
			
		||||
	var quota int64
 | 
			
		||||
	var preConsumedQuota int64
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case constant.RelayModeAudioSpeech:
 | 
			
		||||
		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
 | 
			
		||||
		preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
 | 
			
		||||
		quota = preConsumedQuota
 | 
			
		||||
	default:
 | 
			
		||||
		preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio)
 | 
			
		||||
		preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio)
 | 
			
		||||
	}
 | 
			
		||||
	userQuota, err := model.CacheGetUserQuota(ctx, userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -184,7 +184,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		quota = openai.CountTokenText(text, audioModel)
 | 
			
		||||
		quota = int64(openai.CountTokenText(text, audioModel))
 | 
			
		||||
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
	}
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
 
 | 
			
		||||
@@ -107,15 +107,15 @@ func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int
 | 
			
		||||
	return 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
 | 
			
		||||
func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 {
 | 
			
		||||
	preConsumedTokens := config.PreConsumedQuota
 | 
			
		||||
	if textRequest.MaxTokens != 0 {
 | 
			
		||||
		preConsumedTokens = promptTokens + textRequest.MaxTokens
 | 
			
		||||
		preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens)
 | 
			
		||||
	}
 | 
			
		||||
	return int(float64(preConsumedTokens) * ratio)
 | 
			
		||||
	return int64(float64(preConsumedTokens) * ratio)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) {
 | 
			
		||||
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) {
 | 
			
		||||
	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
 | 
			
		||||
 | 
			
		||||
	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
 | 
			
		||||
@@ -144,16 +144,16 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
 | 
			
		||||
	return preConsumedQuota, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
 | 
			
		||||
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
 | 
			
		||||
	if usage == nil {
 | 
			
		||||
		logger.Error(ctx, "usage is nil, which is unexpected")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	quota := 0
 | 
			
		||||
	var quota int64
 | 
			
		||||
	completionRatio := common.GetCompletionRatio(textRequest.Model)
 | 
			
		||||
	promptTokens := usage.PromptTokens
 | 
			
		||||
	completionTokens := usage.CompletionTokens
 | 
			
		||||
	quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
 | 
			
		||||
	quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
 | 
			
		||||
	if ratio != 0 && quota <= 0 {
 | 
			
		||||
		quota = 1
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -81,7 +81,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
	ratio := modelRatio * groupRatio
 | 
			
		||||
	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
 | 
			
		||||
 | 
			
		||||
	quota := int(ratio*imageCostRatio*1000) * imageRequest.N
 | 
			
		||||
	quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
 | 
			
		||||
 | 
			
		||||
	if userQuota-quota < 0 {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 | 
			
		||||
 
 | 
			
		||||
@@ -74,6 +74,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonData)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/anthropic"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/baidu"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/gemini"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/ollama"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/palm"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channel/tencent"
 | 
			
		||||
@@ -37,6 +38,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 | 
			
		||||
		return &xunfei.Adaptor{}
 | 
			
		||||
	case constant.APITypeZhipu:
 | 
			
		||||
		return &zhipu.Adaptor{}
 | 
			
		||||
	case constant.APITypeOllama:
 | 
			
		||||
		return &ollama.Adaptor{}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,7 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) {
 | 
			
		||||
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
 | 
			
		||||
	if preConsumedQuota != 0 {
 | 
			
		||||
		go func(ctx context.Context) {
 | 
			
		||||
			// return pre-consumed quota
 | 
			
		||||
 
 | 
			
		||||
@@ -155,7 +155,7 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
 | 
			
		||||
	return fullRequestURL
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
 | 
			
		||||
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
 | 
			
		||||
	// quotaDelta is remaining quota to be consumed
 | 
			
		||||
	err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -168,7 +168,7 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo
 | 
			
		||||
	// totalQuota is total quota consumed
 | 
			
		||||
	if totalQuota != 0 {
 | 
			
		||||
		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
 | 
			
		||||
		model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
 | 
			
		||||
		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
 | 
			
		||||
		model.UpdateChannelUsedQuota(channelId, totalQuota)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -9,6 +9,7 @@ import (
 | 
			
		||||
 | 
			
		||||
func SetDashboardRouter(router *gin.Engine) {
 | 
			
		||||
	apiRouter := router.Group("/")
 | 
			
		||||
	apiRouter.Use(middleware.CORS())
 | 
			
		||||
	apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
			
		||||
	apiRouter.Use(middleware.GlobalAPIRateLimit())
 | 
			
		||||
	apiRouter.Use(middleware.TokenAuth())
 | 
			
		||||
 
 | 
			
		||||
@@ -95,6 +95,18 @@ export const CHANNEL_OPTIONS = {
 | 
			
		||||
    value: 29,
 | 
			
		||||
    color: 'default'
 | 
			
		||||
  },
 | 
			
		||||
  30: {
 | 
			
		||||
    key: 30,
 | 
			
		||||
    text: 'Ollama',
 | 
			
		||||
    value: 30,
 | 
			
		||||
    color: 'default'
 | 
			
		||||
  },
 | 
			
		||||
  31: {
 | 
			
		||||
    key: 31,
 | 
			
		||||
    text: '零一万物',
 | 
			
		||||
    value: 31,
 | 
			
		||||
    color: 'default'
 | 
			
		||||
  },
 | 
			
		||||
  8: {
 | 
			
		||||
    key: 8,
 | 
			
		||||
    text: '自定义渠道',
 | 
			
		||||
 
 | 
			
		||||
@@ -166,6 +166,12 @@ const typeConfig = {
 | 
			
		||||
  29: {
 | 
			
		||||
    modelGroup: "groq",
 | 
			
		||||
  },
 | 
			
		||||
  30: {
 | 
			
		||||
    modelGroup: "ollama",
 | 
			
		||||
  },
 | 
			
		||||
  31: {
 | 
			
		||||
    modelGroup: "lingyiwanwu",
 | 
			
		||||
  },
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
export { defaultConfig, typeConfig };
 | 
			
		||||
 
 | 
			
		||||
@@ -265,7 +265,7 @@ const OtherSetting = () => {
 | 
			
		||||
                  multiline
 | 
			
		||||
                  maxRows={15}
 | 
			
		||||
                  id="Footer"
 | 
			
		||||
                  label="公告"
 | 
			
		||||
                  label="页脚"
 | 
			
		||||
                  value={inputs.Footer}
 | 
			
		||||
                  name="Footer"
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
 
 | 
			
		||||
@@ -31,7 +31,7 @@ const COPY_OPTIONS = [
 | 
			
		||||
    url: 'https://chat.oneapi.pro/#/?settings={"key":"sk-{key}","url":"{serverAddress}"}',
 | 
			
		||||
    encode: false
 | 
			
		||||
  },
 | 
			
		||||
  { key: 'ama', text: 'AMA 问天', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true },
 | 
			
		||||
  { key: 'ama', text: 'BotGem', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true },
 | 
			
		||||
  { key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true }
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -8,12 +8,12 @@ import { renderQuota } from '../helpers/render';
 | 
			
		||||
 | 
			
		||||
const COPY_OPTIONS = [
 | 
			
		||||
  { key: 'next', text: 'ChatGPT Next Web', value: 'next' },
 | 
			
		||||
  { key: 'ama', text: 'AMA 问天', value: 'ama' },
 | 
			
		||||
  { key: 'ama', text: 'BotGem', value: 'ama' },
 | 
			
		||||
  { key: 'opencat', text: 'OpenCat', value: 'opencat' },
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
const OPEN_LINK_OPTIONS = [
 | 
			
		||||
  { key: 'ama', text: 'AMA 问天', value: 'ama' },
 | 
			
		||||
  { key: 'ama', text: 'BotGem', value: 'ama' },
 | 
			
		||||
  { key: 'opencat', text: 'OpenCat', value: 'opencat' },
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -15,6 +15,8 @@ export const CHANNEL_OPTIONS = [
 | 
			
		||||
  { key: 26, text: '百川大模型', value: 26, color: 'orange' },
 | 
			
		||||
  { key: 27, text: 'MiniMax', value: 27, color: 'red' },
 | 
			
		||||
  { key: 29, text: 'Groq', value: 29, color: 'orange' },
 | 
			
		||||
  { key: 30, text: 'Ollama', value: 30, color: 'black' },
 | 
			
		||||
  { key: 31, text: '零一万物', value: 31, color: 'green' },
 | 
			
		||||
  { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
 | 
			
		||||
  { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
 | 
			
		||||
  { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user