mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-19 10:06:37 +08:00
Merge remote-tracking branch 'origin/upstream/main'
This commit is contained in:
commit
41afad713e
6
.github/workflows/linux-release.yml
vendored
6
.github/workflows/linux-release.yml
vendored
@ -20,6 +20,12 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Check repository URL
|
||||||
|
run: |
|
||||||
|
REPO_URL=$(git config --get remote.origin.url)
|
||||||
|
if [[ $REPO_URL == *"pro" ]]; then
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
- uses: actions/setup-node@v3
|
- uses: actions/setup-node@v3
|
||||||
with:
|
with:
|
||||||
node-version: 16
|
node-version: 16
|
||||||
|
6
.github/workflows/macos-release.yml
vendored
6
.github/workflows/macos-release.yml
vendored
@ -20,6 +20,12 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Check repository URL
|
||||||
|
run: |
|
||||||
|
REPO_URL=$(git config --get remote.origin.url)
|
||||||
|
if [[ $REPO_URL == *"pro" ]]; then
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
- uses: actions/setup-node@v3
|
- uses: actions/setup-node@v3
|
||||||
with:
|
with:
|
||||||
node-version: 16
|
node-version: 16
|
||||||
|
6
.github/workflows/windows-release.yml
vendored
6
.github/workflows/windows-release.yml
vendored
@ -23,6 +23,12 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Check repository URL
|
||||||
|
run: |
|
||||||
|
REPO_URL=$(git config --get remote.origin.url)
|
||||||
|
if [[ $REPO_URL == *"pro" ]]; then
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
- uses: actions/setup-node@v3
|
- uses: actions/setup-node@v3
|
||||||
with:
|
with:
|
||||||
node-version: 16
|
node-version: 16
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/env"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@ -90,14 +90,14 @@ var MessagePusherToken = ""
|
|||||||
var TurnstileSiteKey = ""
|
var TurnstileSiteKey = ""
|
||||||
var TurnstileSecretKey = ""
|
var TurnstileSecretKey = ""
|
||||||
|
|
||||||
var QuotaForNewUser = 0
|
var QuotaForNewUser int64 = 0
|
||||||
var QuotaForInviter = 0
|
var QuotaForInviter int64 = 0
|
||||||
var QuotaForInvitee = 0
|
var QuotaForInvitee int64 = 0
|
||||||
var ChannelDisableThreshold = 5.0
|
var ChannelDisableThreshold = 5.0
|
||||||
var AutomaticDisableChannelEnabled = false
|
var AutomaticDisableChannelEnabled = false
|
||||||
var AutomaticEnableChannelEnabled = false
|
var AutomaticEnableChannelEnabled = false
|
||||||
var QuotaRemindThreshold = 1000
|
var QuotaRemindThreshold int64 = 1000
|
||||||
var PreConsumedQuota = 500
|
var PreConsumedQuota int64 = 500
|
||||||
var ApproximateTokenEnabled = false
|
var ApproximateTokenEnabled = false
|
||||||
var RetryTimes = 0
|
var RetryTimes = 0
|
||||||
|
|
||||||
@ -108,17 +108,17 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
|||||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||||
|
|
||||||
var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
|
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
|
||||||
|
|
||||||
var BatchUpdateEnabled = false
|
var BatchUpdateEnabled = false
|
||||||
var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
|
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
|
var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
|
||||||
var IdleTimeout = helper.GetOrDefaultEnvInt("IDLE_TIMEOUT", 30) // unit is second
|
var IdleTimeout = env.Int("IDLE_TIMEOUT", 30) // unit is second
|
||||||
|
|
||||||
var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||||
|
|
||||||
var Theme = helper.GetOrDefaultEnvString("THEME", "default")
|
var Theme = env.String("THEME", "default")
|
||||||
var ValidThemes = map[string]bool{
|
var ValidThemes = map[string]bool{
|
||||||
"default": true,
|
"default": true,
|
||||||
"berry": true,
|
"berry": true,
|
||||||
@ -127,10 +127,10 @@ var ValidThemes = map[string]bool{
|
|||||||
// All duration's unit is seconds
|
// All duration's unit is seconds
|
||||||
// Shouldn't larger then RateLimitKeyExpirationDuration
|
// Shouldn't larger then RateLimitKeyExpirationDuration
|
||||||
var (
|
var (
|
||||||
GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
|
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180)
|
||||||
GlobalApiRateLimitDuration int64 = 3 * 60
|
GlobalApiRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
|
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||||
GlobalWebRateLimitDuration int64 = 3 * 60
|
GlobalWebRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
UploadRateLimitNum = 10
|
UploadRateLimitNum = 10
|
||||||
@ -145,8 +145,8 @@ var (
|
|||||||
|
|
||||||
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
||||||
|
|
||||||
var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false)
|
var EnableMetric = env.Bool("ENABLE_METRIC", false)
|
||||||
var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10)
|
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
|
||||||
var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
|
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
|
||||||
var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024)
|
var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
|
||||||
var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128)
|
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
|
||||||
|
@ -69,6 +69,8 @@ const (
|
|||||||
ChannelTypeMinimax
|
ChannelTypeMinimax
|
||||||
ChannelTypeMistral
|
ChannelTypeMistral
|
||||||
ChannelTypeGroq
|
ChannelTypeGroq
|
||||||
|
ChannelTypeOllama
|
||||||
|
ChannelTypeLingYiWanWu
|
||||||
|
|
||||||
ChannelTypeDummy
|
ChannelTypeDummy
|
||||||
)
|
)
|
||||||
@ -104,6 +106,8 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://api.minimax.chat", // 27
|
"https://api.minimax.chat", // 27
|
||||||
"https://api.mistral.ai", // 28
|
"https://api.mistral.ai", // 28
|
||||||
"https://api.groq.com/openai", // 29
|
"https://api.groq.com/openai", // 29
|
||||||
|
"http://localhost:11434", // 30
|
||||||
|
"https://api.lingyiwanwu.com", // 31
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import "github.com/songquanpeng/one-api/common/helper"
|
import (
|
||||||
|
"github.com/songquanpeng/one-api/common/env"
|
||||||
|
)
|
||||||
|
|
||||||
var UsingSQLite = false
|
var UsingSQLite = false
|
||||||
var UsingPostgreSQL = false
|
var UsingPostgreSQL = false
|
||||||
var UsingMySQL = false
|
var UsingMySQL = false
|
||||||
|
|
||||||
var SQLitePath = "one-api.db"
|
var SQLitePath = "one-api.db"
|
||||||
var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000)
|
var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)
|
||||||
|
42
common/env/helper.go
vendored
Normal file
42
common/env/helper.go
vendored
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Bool(env string, defaultValue bool) bool {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return os.Getenv(env) == "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
func Int(env string, defaultValue int) int {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
num, err := strconv.Atoi(os.Getenv(env))
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return num
|
||||||
|
}
|
||||||
|
|
||||||
|
func Float64(env string, defaultValue float64) float64 {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
num, err := strconv.ParseFloat(os.Getenv(env), 64)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return num
|
||||||
|
}
|
||||||
|
|
||||||
|
func String(env string, defaultValue string) string {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return os.Getenv(env)
|
||||||
|
}
|
@ -3,12 +3,10 @@ package helper
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
|
||||||
"html/template"
|
"html/template"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -187,6 +185,10 @@ func GetTimeString() string {
|
|||||||
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
|
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenRequestID() string {
|
||||||
|
return GetTimeString() + GetRandomNumberString(8)
|
||||||
|
}
|
||||||
|
|
||||||
func Max(a int, b int) int {
|
func Max(a int, b int) int {
|
||||||
if a >= b {
|
if a >= b {
|
||||||
return a
|
return a
|
||||||
@ -195,44 +197,6 @@ func Max(a int, b int) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetOrDefaultEnvBool(env string, defaultValue bool) bool {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return os.Getenv(env) == "true"
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetOrDefaultEnvInt(env string, defaultValue int) int {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
num, err := strconv.Atoi(os.Getenv(env))
|
|
||||||
if err != nil {
|
|
||||||
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return num
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
num, err := strconv.ParseFloat(os.Getenv(env), 64)
|
|
||||||
if err != nil {
|
|
||||||
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue))
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return num
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetOrDefaultEnvString(env string, defaultValue string) string {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return os.Getenv(env)
|
|
||||||
}
|
|
||||||
|
|
||||||
func AssignOrDefault(value string, defaultValue string) string {
|
func AssignOrDefault(value string, defaultValue string) string {
|
||||||
if len(value) != 0 {
|
if len(value) != 0 {
|
||||||
return value
|
return value
|
||||||
|
@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
@ -54,7 +56,9 @@ func SysError(s string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Debug(ctx context.Context, msg string) {
|
func Debug(ctx context.Context, msg string) {
|
||||||
logHelper(ctx, loggerDEBUG, msg)
|
if config.DebugEnabled {
|
||||||
|
logHelper(ctx, loggerDEBUG, msg)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Info(ctx context.Context, msg string) {
|
func Info(ctx context.Context, msg string) {
|
||||||
@ -91,6 +95,9 @@ func logHelper(ctx context.Context, level string, msg string) {
|
|||||||
writer = gin.DefaultWriter
|
writer = gin.DefaultWriter
|
||||||
}
|
}
|
||||||
id := ctx.Value(RequestIdKey)
|
id := ctx.Value(RequestIdKey)
|
||||||
|
if id == nil {
|
||||||
|
id = helper.GenRequestID()
|
||||||
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||||
if !setupLogWorking {
|
if !setupLogWorking {
|
||||||
|
@ -69,7 +69,7 @@ var ModelRatio = map[string]float64{
|
|||||||
"claude-instant-1.2": 0.8 / 1000 * USD,
|
"claude-instant-1.2": 0.8 / 1000 * USD,
|
||||||
"claude-2.0": 8.0 / 1000 * USD,
|
"claude-2.0": 8.0 / 1000 * USD,
|
||||||
"claude-2.1": 8.0 / 1000 * USD,
|
"claude-2.1": 8.0 / 1000 * USD,
|
||||||
"claude-3-haiku-20240229": 0.25 / 1000 * USD,
|
"claude-3-haiku-20240307": 0.25 / 1000 * USD,
|
||||||
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
|
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
|
||||||
"claude-3-opus-20240229": 15.0 / 1000 * USD,
|
"claude-3-opus-20240229": 15.0 / 1000 * USD,
|
||||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
|
||||||
@ -78,6 +78,9 @@ var ModelRatio = map[string]float64{
|
|||||||
"ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
|
"ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
|
||||||
"ERNIE-Bot-8k": 0.024 * RMB,
|
"ERNIE-Bot-8k": 0.024 * RMB,
|
||||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||||
|
"bge-large-zh": 0.002 * RMB,
|
||||||
|
"bge-large-en": 0.002 * RMB,
|
||||||
|
"bge-large-8k": 0.002 * RMB,
|
||||||
"PaLM-2": 1,
|
"PaLM-2": 1,
|
||||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||||
@ -130,6 +133,10 @@ var ModelRatio = map[string]float64{
|
|||||||
"llama2-7b-2048": 0.1 / 1000 * USD,
|
"llama2-7b-2048": 0.1 / 1000 * USD,
|
||||||
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
|
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
|
||||||
"gemma-7b-it": 0.1 / 1000 * USD,
|
"gemma-7b-it": 0.1 / 1000 * USD,
|
||||||
|
// https://platform.lingyiwanwu.com/docs#-计费单元
|
||||||
|
"yi-34b-chat-0205": 2.5 / 1000000 * RMB,
|
||||||
|
"yi-34b-chat-200k": 12.0 / 1000000 * RMB,
|
||||||
|
"yi-vl-plus": 6.0 / 1000000 * RMB,
|
||||||
}
|
}
|
||||||
|
|
||||||
var CompletionRatio = map[string]float64{}
|
var CompletionRatio = map[string]float64{}
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/common/config"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func LogQuota(quota int) string {
|
func LogQuota(quota int64) string {
|
||||||
if config.DisplayInCurrencyEnabled {
|
if config.DisplayInCurrencyEnabled {
|
||||||
return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit)
|
return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit)
|
||||||
} else {
|
} else {
|
||||||
|
@ -8,8 +8,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func GetSubscription(c *gin.Context) {
|
func GetSubscription(c *gin.Context) {
|
||||||
var remainQuota int
|
var remainQuota int64
|
||||||
var usedQuota int
|
var usedQuota int64
|
||||||
var err error
|
var err error
|
||||||
var token *model.Token
|
var token *model.Token
|
||||||
var expiredTime int64
|
var expiredTime int64
|
||||||
@ -60,7 +60,7 @@ func GetSubscription(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetUsage(c *gin.Context) {
|
func GetUsage(c *gin.Context) {
|
||||||
var quota int
|
var quota int64
|
||||||
var err error
|
var err error
|
||||||
var token *model.Token
|
var token *model.Token
|
||||||
if config.DisplayTokenStatEnabled {
|
if config.DisplayTokenStatEnabled {
|
||||||
|
@ -30,7 +30,7 @@ import (
|
|||||||
|
|
||||||
func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
|
func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
|
||||||
testRequest := &relaymodel.GeneralOpenAIRequest{
|
testRequest := &relaymodel.GeneralOpenAIRequest{
|
||||||
MaxTokens: 1,
|
MaxTokens: 2,
|
||||||
Stream: false,
|
Stream: false,
|
||||||
Model: "gpt-3.5-turbo",
|
Model: "gpt-3.5-turbo",
|
||||||
}
|
}
|
||||||
|
@ -234,18 +234,18 @@ func UpdateToken(c *gin.Context) {
|
|||||||
tokenInDB.ExpiredTime = *tokenPatch.ExpiredTime
|
tokenInDB.ExpiredTime = *tokenPatch.ExpiredTime
|
||||||
}
|
}
|
||||||
if tokenPatch.RemainQuota != nil {
|
if tokenPatch.RemainQuota != nil {
|
||||||
tokenInDB.RemainQuota = *tokenPatch.RemainQuota
|
tokenInDB.RemainQuota = int64(*tokenPatch.RemainQuota)
|
||||||
}
|
}
|
||||||
if tokenPatch.UnlimitedQuota != nil {
|
if tokenPatch.UnlimitedQuota != nil {
|
||||||
tokenInDB.UnlimitedQuota = *tokenPatch.UnlimitedQuota
|
tokenInDB.UnlimitedQuota = *tokenPatch.UnlimitedQuota
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenInDB.RemainQuota -= tokenPatch.AddUsedQuota
|
tokenInDB.RemainQuota -= int64(tokenPatch.AddUsedQuota)
|
||||||
tokenInDB.UsedQuota += tokenPatch.AddUsedQuota
|
tokenInDB.UsedQuota += int64(tokenPatch.AddUsedQuota)
|
||||||
|
|
||||||
if tokenPatch.AddUsedQuota != 0 {
|
if tokenPatch.AddUsedQuota != 0 {
|
||||||
model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("外部(%s)消耗 %s", tokenPatch.AddReason, common.LogQuota(tokenPatch.AddUsedQuota)))
|
model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("外部(%s)消耗 %s", tokenPatch.AddReason, common.LogQuota(int64(tokenPatch.AddUsedQuota))))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = tokenInDB.Update(); err != nil {
|
if err = tokenInDB.Update(); err != nil {
|
||||||
|
@ -2,7 +2,7 @@ version: '3.4'
|
|||||||
|
|
||||||
services:
|
services:
|
||||||
one-api:
|
one-api:
|
||||||
image: justsong/one-api:latest
|
image: "${REGISTRY:-docker.io}/justsong/one-api:latest"
|
||||||
container_name: one-api
|
container_name: one-api
|
||||||
restart: always
|
restart: always
|
||||||
command: --log-dir /app/logs
|
command: --log-dir /app/logs
|
||||||
@ -29,12 +29,12 @@ services:
|
|||||||
retries: 3
|
retries: 3
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
image: redis:latest
|
image: "${REGISTRY:-docker.io}/redis:latest"
|
||||||
container_name: redis
|
container_name: redis
|
||||||
restart: always
|
restart: always
|
||||||
|
|
||||||
db:
|
db:
|
||||||
image: mysql:8.2.0
|
image: "${REGISTRY:-docker.io}/mysql:8.2.0"
|
||||||
restart: always
|
restart: always
|
||||||
container_name: mysql
|
container_name: mysql
|
||||||
volumes:
|
volumes:
|
||||||
|
16
main.go
16
main.go
@ -32,11 +32,25 @@ func main() {
|
|||||||
if config.DebugEnabled {
|
if config.DebugEnabled {
|
||||||
logger.SysLog("running in debug mode")
|
logger.SysLog("running in debug mode")
|
||||||
}
|
}
|
||||||
|
var err error
|
||||||
// Initialize SQL Database
|
// Initialize SQL Database
|
||||||
err := model.InitDB()
|
model.DB, err = model.InitDB("SQL_DSN")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.FatalLog("failed to initialize database: " + err.Error())
|
logger.FatalLog("failed to initialize database: " + err.Error())
|
||||||
}
|
}
|
||||||
|
if os.Getenv("LOG_SQL_DSN") != "" {
|
||||||
|
logger.SysLog("using secondary database for table logs")
|
||||||
|
model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
|
||||||
|
if err != nil {
|
||||||
|
logger.FatalLog("failed to initialize secondary database: " + err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
model.LOG_DB = model.DB
|
||||||
|
}
|
||||||
|
err = model.CreateRootAccountIfNeed()
|
||||||
|
if err != nil {
|
||||||
|
logger.FatalLog("database init error: " + err.Error())
|
||||||
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
err := model.CloseDB()
|
err := model.CloseDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
func RequestId() func(c *gin.Context) {
|
func RequestId() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
id := helper.GetTimeString() + helper.GetRandomNumberString(8)
|
id := helper.GenRequestID()
|
||||||
c.Set(logger.RequestIdKey, id)
|
c.Set(logger.RequestIdKey, id)
|
||||||
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
|
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/Laisky/errors/v2"
|
"github.com/Laisky/errors/v2"
|
||||||
@ -71,31 +72,42 @@ func CacheGetUserGroup(id int) (group string, err error) {
|
|||||||
return group, err
|
return group, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheGetUserQuota(id int) (quota int, err error) {
|
func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) {
|
||||||
|
quota, err = GetUserQuota(id)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(ctx, "Redis set user quota error: "+err.Error())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return GetUserQuota(id)
|
return GetUserQuota(id)
|
||||||
}
|
}
|
||||||
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
|
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
quota, err = GetUserQuota(id)
|
return fetchAndUpdateUserQuota(ctx, id)
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
|
||||||
if err != nil {
|
|
||||||
logger.SysError("Redis set user quota error: " + err.Error())
|
|
||||||
}
|
|
||||||
return quota, err
|
|
||||||
}
|
}
|
||||||
quota, err = strconv.Atoi(quotaString)
|
quota, err = strconv.ParseInt(quotaString, 10, 64)
|
||||||
return quota, err
|
if err != nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db
|
||||||
|
logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id)
|
||||||
|
return fetchAndUpdateUserQuota(ctx, id)
|
||||||
|
}
|
||||||
|
return quota, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheUpdateUserQuota(id int) error {
|
func CacheUpdateUserQuota(ctx context.Context, id int) error {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
quota, err := CacheGetUserQuota(id)
|
quota, err := CacheGetUserQuota(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -103,7 +115,7 @@ func CacheUpdateUserQuota(id int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheDecreaseUserQuota(id int, quota int) error {
|
func CacheDecreaseUserQuota(id int, quota int64) error {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -178,7 +178,7 @@ func UpdateChannelStatusById(id int, status int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannelUsedQuota(id int, quota int) {
|
func UpdateChannelUsedQuota(id int, quota int64) {
|
||||||
if config.BatchUpdateEnabled {
|
if config.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
|
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
|
||||||
return
|
return
|
||||||
@ -186,7 +186,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
|
|||||||
updateChannelUsedQuota(id, quota)
|
updateChannelUsedQuota(id, quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateChannelUsedQuota(id int, quota int) {
|
func updateChannelUsedQuota(id int, quota int64) {
|
||||||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to update channel used quota: " + err.Error())
|
logger.SysError("failed to update channel used quota: " + err.Error())
|
||||||
|
30
model/log.go
30
model/log.go
@ -45,13 +45,13 @@ func RecordLog(userId int, logType int, content string) {
|
|||||||
Type: logType,
|
Type: logType,
|
||||||
Content: content,
|
Content: content,
|
||||||
}
|
}
|
||||||
err := DB.Create(log).Error
|
err := LOG_DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("failed to record log: " + err.Error())
|
logger.SysError("failed to record log: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
|
||||||
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||||
if !config.LogConsumeEnabled {
|
if !config.LogConsumeEnabled {
|
||||||
return
|
return
|
||||||
@ -66,10 +66,10 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
|||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
TokenName: tokenName,
|
TokenName: tokenName,
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
Quota: quota,
|
Quota: int(quota),
|
||||||
ChannelId: channelId,
|
ChannelId: channelId,
|
||||||
}
|
}
|
||||||
err := DB.Create(log).Error
|
err := LOG_DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(ctx, "failed to record log: "+err.Error())
|
logger.Error(ctx, "failed to record log: "+err.Error())
|
||||||
}
|
}
|
||||||
@ -78,9 +78,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
|||||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
|
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
|
||||||
var tx *gorm.DB
|
var tx *gorm.DB
|
||||||
if logType == LogTypeUnknown {
|
if logType == LogTypeUnknown {
|
||||||
tx = DB
|
tx = LOG_DB
|
||||||
} else {
|
} else {
|
||||||
tx = DB.Where("type = ?", logType)
|
tx = LOG_DB.Where("type = ?", logType)
|
||||||
}
|
}
|
||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name = ?", modelName)
|
tx = tx.Where("model_name = ?", modelName)
|
||||||
@ -107,9 +107,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|||||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
|
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
|
||||||
var tx *gorm.DB
|
var tx *gorm.DB
|
||||||
if logType == LogTypeUnknown {
|
if logType == LogTypeUnknown {
|
||||||
tx = DB.Where("user_id = ?", userId)
|
tx = LOG_DB.Where("user_id = ?", userId)
|
||||||
} else {
|
} else {
|
||||||
tx = DB.Where("user_id = ? and type = ?", userId, logType)
|
tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
|
||||||
}
|
}
|
||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name = ?", modelName)
|
tx = tx.Where("model_name = ?", modelName)
|
||||||
@ -128,17 +128,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SearchAllLogs(keyword string) (logs []*Log, err error) {
|
func SearchAllLogs(keyword string) (logs []*Log, err error) {
|
||||||
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
|
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
|
||||||
return logs, err
|
return logs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
||||||
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
|
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
|
||||||
return logs, err
|
return logs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
|
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
|
||||||
tx := DB.Table("logs").Select("ifnull(sum(quota),0)")
|
tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
|
||||||
if username != "" {
|
if username != "" {
|
||||||
tx = tx.Where("username = ?", username)
|
tx = tx.Where("username = ?", username)
|
||||||
}
|
}
|
||||||
@ -162,7 +162,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
|
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
|
||||||
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
|
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
|
||||||
if username != "" {
|
if username != "" {
|
||||||
tx = tx.Where("username = ?", username)
|
tx = tx.Where("username = ?", username)
|
||||||
}
|
}
|
||||||
@ -183,7 +183,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
||||||
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
||||||
return result.RowsAffected, result.Error
|
return result.RowsAffected, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,7 +207,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis
|
|||||||
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
|
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DB.Raw(`
|
err = LOG_DB.Raw(`
|
||||||
SELECT `+groupSelect+`,
|
SELECT `+groupSelect+`,
|
||||||
model_name, count(1) as request_count,
|
model_name, count(1) as request_count,
|
||||||
sum(quota) as quota,
|
sum(quota) as quota,
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/env"
|
||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"gorm.io/driver/mysql"
|
"gorm.io/driver/mysql"
|
||||||
@ -18,8 +19,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var DB *gorm.DB
|
var DB *gorm.DB
|
||||||
|
var LOG_DB *gorm.DB
|
||||||
|
|
||||||
func createRootAccountIfNeed() error {
|
func CreateRootAccountIfNeed() error {
|
||||||
var user User
|
var user User
|
||||||
//if user.Status != util.UserStatusEnabled {
|
//if user.Status != util.UserStatusEnabled {
|
||||||
if err := DB.First(&user).Error; err != nil {
|
if err := DB.First(&user).Error; err != nil {
|
||||||
@ -42,9 +44,9 @@ func createRootAccountIfNeed() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func chooseDB() (*gorm.DB, error) {
|
func chooseDB(envName string) (*gorm.DB, error) {
|
||||||
if os.Getenv("SQL_DSN") != "" {
|
if os.Getenv(envName) != "" {
|
||||||
dsn := os.Getenv("SQL_DSN")
|
dsn := os.Getenv(envName)
|
||||||
if strings.HasPrefix(dsn, "postgres://") {
|
if strings.HasPrefix(dsn, "postgres://") {
|
||||||
// Use PostgreSQL
|
// Use PostgreSQL
|
||||||
logger.SysLog("using PostgreSQL as database")
|
logger.SysLog("using PostgreSQL as database")
|
||||||
@ -72,23 +74,22 @@ func chooseDB() (*gorm.DB, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitDB() (err error) {
|
func InitDB(envName string) (db *gorm.DB, err error) {
|
||||||
db, err := chooseDB()
|
db, err = chooseDB(envName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if config.DebugSQLEnabled {
|
if config.DebugSQLEnabled {
|
||||||
db = db.Debug()
|
db = db.Debug()
|
||||||
}
|
}
|
||||||
DB = db
|
sqlDB, err := db.DB()
|
||||||
sqlDB, err := DB.DB()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
|
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
|
||||||
sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
|
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
|
||||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
|
||||||
|
|
||||||
if !config.IsMasterNode {
|
if !config.IsMasterNode {
|
||||||
return nil
|
return db, err
|
||||||
}
|
}
|
||||||
if common.UsingMySQL {
|
if common.UsingMySQL {
|
||||||
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
||||||
@ -96,46 +97,55 @@ func InitDB() (err error) {
|
|||||||
logger.SysLog("database migration started")
|
logger.SysLog("database migration started")
|
||||||
err = db.AutoMigrate(&Channel{})
|
err = db.AutoMigrate(&Channel{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(&Token{})
|
err = db.AutoMigrate(&Token{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(&User{})
|
err = db.AutoMigrate(&User{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(&Option{})
|
err = db.AutoMigrate(&Option{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(&Redemption{})
|
err = db.AutoMigrate(&Redemption{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(&Ability{})
|
err = db.AutoMigrate(&Ability{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(&Log{})
|
err = db.AutoMigrate(&Log{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
logger.SysLog("database migrated")
|
logger.SysLog("database migrated")
|
||||||
err = createRootAccountIfNeed()
|
return db, err
|
||||||
return errors.WithStack(err)
|
|
||||||
} else {
|
} else {
|
||||||
logger.FatalLog(err)
|
logger.FatalLog(err)
|
||||||
}
|
}
|
||||||
return errors.WithStack(err)
|
return db, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CloseDB() error {
|
func closeDB(db *gorm.DB) error {
|
||||||
sqlDB, err := DB.DB()
|
sqlDB, err := db.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
err = sqlDB.Close()
|
err = sqlDB.Close()
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CloseDB() error {
|
||||||
|
if LOG_DB != DB {
|
||||||
|
err := closeDB(LOG_DB)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return closeDB(DB)
|
||||||
|
}
|
||||||
|
@ -61,11 +61,11 @@ func InitOptionMap() {
|
|||||||
config.OptionMap["MessagePusherToken"] = ""
|
config.OptionMap["MessagePusherToken"] = ""
|
||||||
config.OptionMap["TurnstileSiteKey"] = ""
|
config.OptionMap["TurnstileSiteKey"] = ""
|
||||||
config.OptionMap["TurnstileSecretKey"] = ""
|
config.OptionMap["TurnstileSecretKey"] = ""
|
||||||
config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser)
|
config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10)
|
||||||
config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter)
|
config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10)
|
||||||
config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee)
|
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
|
||||||
config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold)
|
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
|
||||||
config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota)
|
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
|
||||||
config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
||||||
config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
|
config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
|
||||||
config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
|
config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
|
||||||
@ -193,15 +193,15 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
case "TurnstileSecretKey":
|
case "TurnstileSecretKey":
|
||||||
config.TurnstileSecretKey = value
|
config.TurnstileSecretKey = value
|
||||||
case "QuotaForNewUser":
|
case "QuotaForNewUser":
|
||||||
config.QuotaForNewUser, _ = strconv.Atoi(value)
|
config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64)
|
||||||
case "QuotaForInviter":
|
case "QuotaForInviter":
|
||||||
config.QuotaForInviter, _ = strconv.Atoi(value)
|
config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64)
|
||||||
case "QuotaForInvitee":
|
case "QuotaForInvitee":
|
||||||
config.QuotaForInvitee, _ = strconv.Atoi(value)
|
config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64)
|
||||||
case "QuotaRemindThreshold":
|
case "QuotaRemindThreshold":
|
||||||
config.QuotaRemindThreshold, _ = strconv.Atoi(value)
|
config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64)
|
||||||
case "PreConsumedQuota":
|
case "PreConsumedQuota":
|
||||||
config.PreConsumedQuota, _ = strconv.Atoi(value)
|
config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64)
|
||||||
case "RetryTimes":
|
case "RetryTimes":
|
||||||
config.RetryTimes, _ = strconv.Atoi(value)
|
config.RetryTimes, _ = strconv.Atoi(value)
|
||||||
case "ModelRatio":
|
case "ModelRatio":
|
||||||
|
@ -14,7 +14,7 @@ type Redemption struct {
|
|||||||
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
|
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" gorm:"index"`
|
Name string `json:"name" gorm:"index"`
|
||||||
Quota int `json:"quota" gorm:"default:100"`
|
Quota int64 `json:"quota" gorm:"default:100"`
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
|
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
|
||||||
Count int `json:"count" gorm:"-:all"` // only for api request
|
Count int `json:"count" gorm:"-:all"` // only for api request
|
||||||
@ -42,7 +42,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
|
|||||||
return &redemption, err
|
return &redemption, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Redeem(key string, userId int) (quota int, err error) {
|
func Redeem(key string, userId int) (quota int64, err error) {
|
||||||
if key == "" {
|
if key == "" {
|
||||||
return 0, errors.New("未提供兑换码")
|
return 0, errors.New("未提供兑换码")
|
||||||
}
|
}
|
||||||
|
@ -21,9 +21,9 @@ type Token struct {
|
|||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
RemainQuota int64 `json:"remain_quota" gorm:"default:0"`
|
||||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
UsedQuota int64 `json:"used_quota" gorm:"default:0"` // used quota
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
||||||
@ -141,7 +141,7 @@ func DeleteTokenById(id int, userId int) (err error) {
|
|||||||
return token.Delete()
|
return token.Delete()
|
||||||
}
|
}
|
||||||
|
|
||||||
func IncreaseTokenQuota(id int, quota int) (err error) {
|
func IncreaseTokenQuota(id int, quota int64) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
@ -152,7 +152,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
return increaseTokenQuota(id, quota)
|
return increaseTokenQuota(id, quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
func increaseTokenQuota(id int, quota int) (err error) {
|
func increaseTokenQuota(id int, quota int64) (err error) {
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
||||||
@ -163,7 +163,7 @@ func increaseTokenQuota(id int, quota int) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecreaseTokenQuota(id int, quota int) (err error) {
|
func DecreaseTokenQuota(id int, quota int64) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
@ -174,7 +174,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
return decreaseTokenQuota(id, quota)
|
return decreaseTokenQuota(id, quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decreaseTokenQuota(id int, quota int) (err error) {
|
func decreaseTokenQuota(id int, quota int64) (err error) {
|
||||||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
||||||
@ -185,7 +185,7 @@ func decreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
|
func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
@ -235,7 +235,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
|
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
|
||||||
token, err := GetTokenById(tokenId)
|
token, err := GetTokenById(tokenId)
|
||||||
if quota > 0 {
|
if quota > 0 {
|
||||||
err = DecreaseUserQuota(token.UserId, quota)
|
err = DecreaseUserQuota(token.UserId, quota)
|
||||||
|
@ -2,6 +2,8 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Laisky/errors/v2"
|
"github.com/Laisky/errors/v2"
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/blacklist"
|
"github.com/songquanpeng/one-api/common/blacklist"
|
||||||
@ -9,7 +11,6 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
|
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
|
||||||
@ -26,10 +27,10 @@ type User struct {
|
|||||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||||
Quota int `json:"quota" gorm:"column:quota;type:int;default:0"`
|
Quota int64 `json:"quota" gorm:"type:int;default:0"`
|
||||||
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
|
UsedQuota int64 `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
|
||||||
RequestCount int `json:"request_count" gorm:"column:request_count;type:int;default:0;"` // request number
|
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
|
||||||
Group string `json:"group" gorm:"column:group;type:varchar(32);default:'default'"`
|
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||||
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
|
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
|
||||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||||
}
|
}
|
||||||
@ -274,12 +275,12 @@ func ValidateAccessToken(token string) (user *User) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserQuota(id int) (quota int, err error) {
|
func GetUserQuota(id int) (quota int64, err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
||||||
return quota, err
|
return quota, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserUsedQuota(id int) (quota int, err error) {
|
func GetUserUsedQuota(id int) (quota int64, err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error
|
||||||
return quota, err
|
return quota, err
|
||||||
}
|
}
|
||||||
@ -299,7 +300,7 @@ func GetUserGroup(id int) (group string, err error) {
|
|||||||
return group, err
|
return group, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func IncreaseUserQuota(id int, quota int) (err error) {
|
func IncreaseUserQuota(id int, quota int64) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
@ -310,12 +311,12 @@ func IncreaseUserQuota(id int, quota int) (err error) {
|
|||||||
return increaseUserQuota(id, quota)
|
return increaseUserQuota(id, quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
func increaseUserQuota(id int, quota int) (err error) {
|
func increaseUserQuota(id int, quota int64) (err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecreaseUserQuota(id int, quota int) (err error) {
|
func DecreaseUserQuota(id int, quota int64) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
}
|
}
|
||||||
@ -326,7 +327,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
|||||||
return decreaseUserQuota(id, quota)
|
return decreaseUserQuota(id, quota)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decreaseUserQuota(id int, quota int) (err error) {
|
func decreaseUserQuota(id int, quota int64) (err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -336,7 +337,7 @@ func GetRootUserEmail() (email string) {
|
|||||||
return email
|
return email
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) {
|
||||||
if config.BatchUpdateEnabled {
|
if config.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
|
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
|
||||||
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
|
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
|
||||||
@ -345,7 +346,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
|||||||
updateUserUsedQuotaAndRequestCount(id, quota, 1)
|
updateUserUsedQuotaAndRequestCount(id, quota, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) {
|
||||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
@ -357,7 +358,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateUserUsedQuota(id int, quota int) {
|
func updateUserUsedQuota(id int, quota int64) {
|
||||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
|
@ -16,12 +16,12 @@ const (
|
|||||||
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
|
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
|
||||||
)
|
)
|
||||||
|
|
||||||
var batchUpdateStores []map[int]int
|
var batchUpdateStores []map[int]int64
|
||||||
var batchUpdateLocks []sync.Mutex
|
var batchUpdateLocks []sync.Mutex
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
batchUpdateStores = append(batchUpdateStores, make(map[int]int))
|
batchUpdateStores = append(batchUpdateStores, make(map[int]int64))
|
||||||
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
|
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -35,7 +35,7 @@ func InitBatchUpdater() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func addNewRecord(type_ int, id int, value int) {
|
func addNewRecord(type_ int, id int, value int64) {
|
||||||
batchUpdateLocks[type_].Lock()
|
batchUpdateLocks[type_].Lock()
|
||||||
defer batchUpdateLocks[type_].Unlock()
|
defer batchUpdateLocks[type_].Unlock()
|
||||||
if _, ok := batchUpdateStores[type_][id]; !ok {
|
if _, ok := batchUpdateStores[type_][id]; !ok {
|
||||||
@ -50,7 +50,7 @@ func batchUpdate() {
|
|||||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||||
batchUpdateLocks[i].Lock()
|
batchUpdateLocks[i].Lock()
|
||||||
store := batchUpdateStores[i]
|
store := batchUpdateStores[i]
|
||||||
batchUpdateStores[i] = make(map[int]int)
|
batchUpdateStores[i] = make(map[int]int64)
|
||||||
batchUpdateLocks[i].Unlock()
|
batchUpdateLocks[i].Unlock()
|
||||||
// TODO: maybe we can combine updates with same key?
|
// TODO: maybe we can combine updates with same key?
|
||||||
for key, value := range store {
|
for key, value := range store {
|
||||||
@ -68,7 +68,7 @@ func batchUpdate() {
|
|||||||
case BatchUpdateTypeUsedQuota:
|
case BatchUpdateTypeUsedQuota:
|
||||||
updateUserUsedQuota(key, value)
|
updateUserUsedQuota(key, value)
|
||||||
case BatchUpdateTypeRequestCount:
|
case BatchUpdateTypeRequestCount:
|
||||||
updateUserRequestCount(key, value)
|
updateUserRequestCount(key, int(value))
|
||||||
case BatchUpdateTypeChannelUsedQuota:
|
case BatchUpdateTypeChannelUsedQuota:
|
||||||
updateChannelUsedQuota(key, value)
|
updateChannelUsedQuota(key, value)
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ package anthropic
|
|||||||
|
|
||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
"claude-instant-1.2", "claude-2.0", "claude-2.1",
|
"claude-instant-1.2", "claude-2.0", "claude-2.1",
|
||||||
"claude-3-haiku-20240229",
|
"claude-3-haiku-20240307",
|
||||||
"claude-3-sonnet-20240229",
|
"claude-3-sonnet-20240229",
|
||||||
"claude-3-opus-20240229",
|
"claude-3-opus-20240229",
|
||||||
}
|
}
|
||||||
|
@ -7,4 +7,7 @@ var ModelList = []string{
|
|||||||
"ERNIE-Speed",
|
"ERNIE-Speed",
|
||||||
"ERNIE-Bot-turbo",
|
"ERNIE-Bot-turbo",
|
||||||
"Embedding-V1",
|
"Embedding-V1",
|
||||||
|
"bge-large-zh",
|
||||||
|
"bge-large-en",
|
||||||
|
"tao-8k",
|
||||||
}
|
}
|
||||||
|
9
relay/channel/lingyiwanwu/constants.go
Normal file
9
relay/channel/lingyiwanwu/constants.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
package lingyiwanwu
|
||||||
|
|
||||||
|
// https://platform.lingyiwanwu.com/docs
|
||||||
|
|
||||||
|
var ModelList = []string{
|
||||||
|
"yi-34b-chat-0205",
|
||||||
|
"yi-34b-chat-200k",
|
||||||
|
"yi-vl-plus",
|
||||||
|
}
|
65
relay/channel/ollama/adaptor.go
Normal file
65
relay/channel/ollama/adaptor.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package ollama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/relay/channel"
|
||||||
|
"github.com/songquanpeng/one-api/relay/constant"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
"github.com/songquanpeng/one-api/relay/util"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Adaptor struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
|
||||||
|
// https://github.com/ollama/ollama/blob/main/docs/api.md
|
||||||
|
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
|
||||||
|
return fullRequestURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
|
||||||
|
channel.SetupCommonRequestHeader(c, req, meta)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||||
|
if request == nil {
|
||||||
|
return nil, errors.New("request is nil")
|
||||||
|
}
|
||||||
|
switch relayMode {
|
||||||
|
case constant.RelayModeEmbeddings:
|
||||||
|
return nil, errors.New("not supported")
|
||||||
|
default:
|
||||||
|
return ConvertRequest(*request), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
return channel.DoRequestHelper(a, c, meta, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||||
|
if meta.IsStream {
|
||||||
|
err, usage = StreamHandler(c, resp)
|
||||||
|
} else {
|
||||||
|
err, usage = Handler(c, resp)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetModelList() []string {
|
||||||
|
return ModelList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetChannelName() string {
|
||||||
|
return "ollama"
|
||||||
|
}
|
5
relay/channel/ollama/constants.go
Normal file
5
relay/channel/ollama/constants.go
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
package ollama
|
||||||
|
|
||||||
|
var ModelList = []string{
|
||||||
|
"qwen:0.5b-chat",
|
||||||
|
}
|
178
relay/channel/ollama/main.go
Normal file
178
relay/channel/ollama/main.go
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
package ollama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||||
|
"github.com/songquanpeng/one-api/relay/constant"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||||
|
ollamaRequest := ChatRequest{
|
||||||
|
Model: request.Model,
|
||||||
|
Options: &Options{
|
||||||
|
Seed: int(request.Seed),
|
||||||
|
Temperature: request.Temperature,
|
||||||
|
TopP: request.TopP,
|
||||||
|
FrequencyPenalty: request.FrequencyPenalty,
|
||||||
|
PresencePenalty: request.PresencePenalty,
|
||||||
|
},
|
||||||
|
Stream: request.Stream,
|
||||||
|
}
|
||||||
|
for _, message := range request.Messages {
|
||||||
|
ollamaRequest.Messages = append(ollamaRequest.Messages, Message{
|
||||||
|
Role: message.Role,
|
||||||
|
Content: message.StringContent(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &ollamaRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||||
|
choice := openai.TextResponseChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: model.Message{
|
||||||
|
Role: response.Message.Role,
|
||||||
|
Content: response.Message.Content,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if response.Done {
|
||||||
|
choice.FinishReason = "stop"
|
||||||
|
}
|
||||||
|
fullTextResponse := openai.TextResponse{
|
||||||
|
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: helper.GetTimestamp(),
|
||||||
|
Choices: []openai.TextResponseChoice{choice},
|
||||||
|
Usage: model.Usage{
|
||||||
|
PromptTokens: response.PromptEvalCount,
|
||||||
|
CompletionTokens: response.EvalCount,
|
||||||
|
TotalTokens: response.PromptEvalCount + response.EvalCount,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||||
|
var choice openai.ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Role = ollamaResponse.Message.Role
|
||||||
|
choice.Delta.Content = ollamaResponse.Message.Content
|
||||||
|
if ollamaResponse.Done {
|
||||||
|
choice.FinishReason = &constant.StopFinishReason
|
||||||
|
}
|
||||||
|
response := openai.ChatCompletionsStreamResponse{
|
||||||
|
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: helper.GetTimestamp(),
|
||||||
|
Model: ollamaResponse.Model,
|
||||||
|
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||||
|
}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
|
var usage model.Usage
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "}\n"); i >= 0 {
|
||||||
|
return i + 2, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := strings.TrimPrefix(scanner.Text(), "}")
|
||||||
|
dataChan <- data + "}"
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
common.SetEventStreamHeaders(c)
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
var ollamaResponse ChatResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &ollamaResponse)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if ollamaResponse.EvalCount != 0 {
|
||||||
|
usage.PromptTokens = ollamaResponse.PromptEvalCount
|
||||||
|
usage.CompletionTokens = ollamaResponse.EvalCount
|
||||||
|
usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
|
||||||
|
}
|
||||||
|
response := streamResponseOllama2OpenAI(&ollamaResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||||
|
ctx := context.TODO()
|
||||||
|
var ollamaResponse ChatResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
logger.Debugf(ctx, "ollama response: %s", string(responseBody))
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &ollamaResponse)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if ollamaResponse.Error != "" {
|
||||||
|
return &model.ErrorWithStatusCode{
|
||||||
|
Error: model.Error{
|
||||||
|
Message: ollamaResponse.Error,
|
||||||
|
Type: "ollama_error",
|
||||||
|
Param: "",
|
||||||
|
Code: "ollama_error",
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responseOllama2OpenAI(&ollamaResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
37
relay/channel/ollama/model.go
Normal file
37
relay/channel/ollama/model.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package ollama
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
Seed int `json:"seed,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopK int `json:"top_k,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||||
|
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Role string `json:"role,omitempty"`
|
||||||
|
Content string `json:"content,omitempty"`
|
||||||
|
Images []string `json:"images,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatRequest struct {
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
Messages []Message `json:"messages,omitempty"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
Options *Options `json:"options,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatResponse struct {
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
CreatedAt string `json:"created_at,omitempty"`
|
||||||
|
Message Message `json:"message,omitempty"`
|
||||||
|
Response string `json:"response,omitempty"` // for stream response
|
||||||
|
Done bool `json:"done,omitempty"`
|
||||||
|
TotalDuration int `json:"total_duration,omitempty"`
|
||||||
|
LoadDuration int `json:"load_duration,omitempty"`
|
||||||
|
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||||
|
EvalCount int `json:"eval_count,omitempty"`
|
||||||
|
EvalDuration int `json:"eval_duration,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
@ -5,6 +5,7 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/relay/channel/ai360"
|
"github.com/songquanpeng/one-api/relay/channel/ai360"
|
||||||
"github.com/songquanpeng/one-api/relay/channel/baichuan"
|
"github.com/songquanpeng/one-api/relay/channel/baichuan"
|
||||||
"github.com/songquanpeng/one-api/relay/channel/groq"
|
"github.com/songquanpeng/one-api/relay/channel/groq"
|
||||||
|
"github.com/songquanpeng/one-api/relay/channel/lingyiwanwu"
|
||||||
"github.com/songquanpeng/one-api/relay/channel/minimax"
|
"github.com/songquanpeng/one-api/relay/channel/minimax"
|
||||||
"github.com/songquanpeng/one-api/relay/channel/mistral"
|
"github.com/songquanpeng/one-api/relay/channel/mistral"
|
||||||
"github.com/songquanpeng/one-api/relay/channel/moonshot"
|
"github.com/songquanpeng/one-api/relay/channel/moonshot"
|
||||||
@ -18,6 +19,7 @@ var CompatibleChannels = []int{
|
|||||||
common.ChannelTypeMinimax,
|
common.ChannelTypeMinimax,
|
||||||
common.ChannelTypeMistral,
|
common.ChannelTypeMistral,
|
||||||
common.ChannelTypeGroq,
|
common.ChannelTypeGroq,
|
||||||
|
common.ChannelTypeLingYiWanWu,
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetCompatibleChannelMeta(channelType int) (string, []string) {
|
func GetCompatibleChannelMeta(channelType int) (string, []string) {
|
||||||
@ -36,6 +38,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
|
|||||||
return "mistralai", mistral.ModelList
|
return "mistralai", mistral.ModelList
|
||||||
case common.ChannelTypeGroq:
|
case common.ChannelTypeGroq:
|
||||||
return "groq", groq.ModelList
|
return "groq", groq.ModelList
|
||||||
|
case common.ChannelTypeLingYiWanWu:
|
||||||
|
return "lingyiwanwu", lingyiwanwu.ModelList
|
||||||
default:
|
default:
|
||||||
return "openai", ModelList
|
return "openai", ModelList
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,7 @@ const (
|
|||||||
APITypeAIProxyLibrary
|
APITypeAIProxyLibrary
|
||||||
APITypeTencent
|
APITypeTencent
|
||||||
APITypeGemini
|
APITypeGemini
|
||||||
|
APITypeOllama
|
||||||
|
|
||||||
APITypeDummy // this one is only for count, do not add any channel after this
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
@ -40,6 +41,8 @@ func ChannelType2APIType(channelType int) int {
|
|||||||
apiType = APITypeTencent
|
apiType = APITypeTencent
|
||||||
case common.ChannelTypeGemini:
|
case common.ChannelTypeGemini:
|
||||||
apiType = APITypeGemini
|
apiType = APITypeGemini
|
||||||
|
case common.ChannelTypeOllama:
|
||||||
|
apiType = APITypeOllama
|
||||||
}
|
}
|
||||||
return apiType
|
return apiType
|
||||||
}
|
}
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||||
|
ctx := c.Request.Context()
|
||||||
audioModel := "whisper-1"
|
audioModel := "whisper-1"
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
@ -51,16 +52,16 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
// groupRatio := common.GetGroupRatio(group)
|
// groupRatio := common.GetGroupRatio(group)
|
||||||
groupRatio := c.GetFloat64("channel_ratio")
|
groupRatio := c.GetFloat64("channel_ratio")
|
||||||
ratio := modelRatio * groupRatio
|
ratio := modelRatio * groupRatio
|
||||||
var quota int
|
var quota int64
|
||||||
var preConsumedQuota int
|
var preConsumedQuota int64
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case constant.RelayModeAudioSpeech:
|
case constant.RelayModeAudioSpeech:
|
||||||
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
|
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
|
||||||
quota = preConsumedQuota
|
quota = preConsumedQuota
|
||||||
default:
|
default:
|
||||||
preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio)
|
preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio)
|
||||||
}
|
}
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
userQuota, err := model.CacheGetUserQuota(ctx, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@ -185,7 +186,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
quota = openai.CountTokenText(text, audioModel)
|
quota = int64(openai.CountTokenText(text, audioModel))
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
@ -107,18 +107,18 @@ func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
|
func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 {
|
||||||
preConsumedTokens := config.PreConsumedQuota
|
preConsumedTokens := config.PreConsumedQuota
|
||||||
if textRequest.MaxTokens != 0 {
|
if textRequest.MaxTokens != 0 {
|
||||||
preConsumedTokens = promptTokens + textRequest.MaxTokens
|
preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens)
|
||||||
}
|
}
|
||||||
return int(float64(preConsumedTokens) * ratio)
|
return int64(float64(preConsumedTokens) * ratio)
|
||||||
}
|
}
|
||||||
|
|
||||||
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) {
|
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) {
|
||||||
preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
|
preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
|
||||||
|
|
||||||
userQuota, err := model.CacheGetUserQuota(meta.UserId)
|
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@ -144,16 +144,16 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
|
|||||||
return preConsumedQuota, nil
|
return preConsumedQuota, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
|
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
|
||||||
if usage == nil {
|
if usage == nil {
|
||||||
logger.Error(ctx, "usage is nil, which is unexpected")
|
logger.Error(ctx, "usage is nil, which is unexpected")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
quota := 0
|
var quota int64
|
||||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||||
promptTokens := usage.PromptTokens
|
promptTokens := usage.PromptTokens
|
||||||
completionTokens := usage.CompletionTokens
|
completionTokens := usage.CompletionTokens
|
||||||
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
||||||
if ratio != 0 && quota <= 0 {
|
if ratio != 0 && quota <= 0 {
|
||||||
quota = 1
|
quota = 1
|
||||||
}
|
}
|
||||||
@ -168,7 +168,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(ctx, "error consuming token remain quota: "+err.Error())
|
logger.Error(ctx, "error consuming token remain quota: "+err.Error())
|
||||||
}
|
}
|
||||||
err = model.CacheUpdateUserQuota(meta.UserId)
|
err = model.CacheUpdateUserQuota(ctx, meta.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(ctx, "error update user quota cache: "+err.Error())
|
logger.Error(ctx, "error update user quota cache: "+err.Error())
|
||||||
}
|
}
|
||||||
|
@ -81,9 +81,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
// groupRatio := common.GetGroupRatio(meta.Group)
|
// groupRatio := common.GetGroupRatio(meta.Group)
|
||||||
groupRatio := c.GetFloat64("channel_ratio") // pre-selected cheapest channel ratio
|
groupRatio := c.GetFloat64("channel_ratio") // pre-selected cheapest channel ratio
|
||||||
ratio := modelRatio * groupRatio
|
ratio := modelRatio * groupRatio
|
||||||
userQuota, err := model.CacheGetUserQuota(meta.UserId)
|
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
||||||
|
|
||||||
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
|
quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
|
||||||
|
|
||||||
if userQuota-quota < 0 {
|
if userQuota-quota < 0 {
|
||||||
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
@ -127,7 +127,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
logger.SysError("error consuming token remain quota: " + err.Error())
|
||||||
}
|
}
|
||||||
err = model.CacheUpdateUserQuota(meta.UserId)
|
err = model.CacheUpdateUserQuota(ctx, meta.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error update user quota cache: " + err.Error())
|
logger.SysError("error update user quota cache: " + err.Error())
|
||||||
}
|
}
|
||||||
|
@ -77,6 +77,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/relay/channel/aiproxy"
|
"github.com/songquanpeng/one-api/relay/channel/aiproxy"
|
||||||
"github.com/songquanpeng/one-api/relay/channel/anthropic"
|
"github.com/songquanpeng/one-api/relay/channel/anthropic"
|
||||||
"github.com/songquanpeng/one-api/relay/channel/gemini"
|
"github.com/songquanpeng/one-api/relay/channel/gemini"
|
||||||
|
"github.com/songquanpeng/one-api/relay/channel/ollama"
|
||||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||||
"github.com/songquanpeng/one-api/relay/channel/palm"
|
"github.com/songquanpeng/one-api/relay/channel/palm"
|
||||||
"github.com/songquanpeng/one-api/relay/constant"
|
"github.com/songquanpeng/one-api/relay/constant"
|
||||||
@ -26,12 +27,14 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
|||||||
return &openai.Adaptor{}
|
return &openai.Adaptor{}
|
||||||
case constant.APITypePaLM:
|
case constant.APITypePaLM:
|
||||||
return &palm.Adaptor{}
|
return &palm.Adaptor{}
|
||||||
// case constant.APITypeTencent:
|
// case constant.APITypeTencent:
|
||||||
// return &tencent.Adaptor{}
|
// return &tencent.Adaptor{}
|
||||||
// case constant.APITypeXunfei:
|
// case constant.APITypeXunfei:
|
||||||
// return &xunfei.Adaptor{}
|
// return &xunfei.Adaptor{}
|
||||||
// case constant.APITypeZhipu:
|
// case constant.APITypeZhipu:
|
||||||
// return &zhipu.Adaptor{}
|
// return &zhipu.Adaptor{}
|
||||||
|
case constant.APITypeOllama:
|
||||||
|
return &ollama.Adaptor{}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/model"
|
"github.com/songquanpeng/one-api/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) {
|
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
|
||||||
if preConsumedQuota != 0 {
|
if preConsumedQuota != 0 {
|
||||||
go func(ctx context.Context) {
|
go func(ctx context.Context) {
|
||||||
// return pre-consumed quota
|
// return pre-consumed quota
|
||||||
|
@ -35,10 +35,17 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
|
|||||||
return true
|
return true
|
||||||
case "permission_error":
|
case "permission_error":
|
||||||
return true
|
return true
|
||||||
|
case "forbidden":
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
|
||||||
|
return true
|
||||||
|
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,20 +155,20 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
|
|||||||
return fullRequestURL
|
return fullRequestURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||||
// quotaDelta is remaining quota to be consumed
|
// quotaDelta is remaining quota to be consumed
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
logger.SysError("error consuming token remain quota: " + err.Error())
|
||||||
}
|
}
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
err = model.CacheUpdateUserQuota(ctx, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error update user quota cache: " + err.Error())
|
logger.SysError("error update user quota cache: " + err.Error())
|
||||||
}
|
}
|
||||||
// totalQuota is total quota consumed
|
// totalQuota is total quota consumed
|
||||||
if totalQuota >= 0 {
|
if totalQuota >= 0 {
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
|
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
||||||
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
func SetDashboardRouter(router *gin.Engine) {
|
func SetDashboardRouter(router *gin.Engine) {
|
||||||
apiRouter := router.Group("/")
|
apiRouter := router.Group("/")
|
||||||
|
apiRouter.Use(middleware.CORS())
|
||||||
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
|
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||||
apiRouter.Use(middleware.GlobalAPIRateLimit())
|
apiRouter.Use(middleware.GlobalAPIRateLimit())
|
||||||
apiRouter.Use(middleware.TokenAuth())
|
apiRouter.Use(middleware.TokenAuth())
|
||||||
|
@ -95,6 +95,18 @@ export const CHANNEL_OPTIONS = {
|
|||||||
value: 29,
|
value: 29,
|
||||||
color: 'default'
|
color: 'default'
|
||||||
},
|
},
|
||||||
|
30: {
|
||||||
|
key: 30,
|
||||||
|
text: 'Ollama',
|
||||||
|
value: 30,
|
||||||
|
color: 'default'
|
||||||
|
},
|
||||||
|
31: {
|
||||||
|
key: 31,
|
||||||
|
text: '零一万物',
|
||||||
|
value: 31,
|
||||||
|
color: 'default'
|
||||||
|
},
|
||||||
8: {
|
8: {
|
||||||
key: 8,
|
key: 8,
|
||||||
text: '自定义渠道',
|
text: '自定义渠道',
|
||||||
|
@ -166,6 +166,12 @@ const typeConfig = {
|
|||||||
29: {
|
29: {
|
||||||
modelGroup: "groq",
|
modelGroup: "groq",
|
||||||
},
|
},
|
||||||
|
30: {
|
||||||
|
modelGroup: "ollama",
|
||||||
|
},
|
||||||
|
31: {
|
||||||
|
modelGroup: "lingyiwanwu",
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
export { defaultConfig, typeConfig };
|
export { defaultConfig, typeConfig };
|
||||||
|
@ -265,7 +265,7 @@ const OtherSetting = () => {
|
|||||||
multiline
|
multiline
|
||||||
maxRows={15}
|
maxRows={15}
|
||||||
id="Footer"
|
id="Footer"
|
||||||
label="公告"
|
label="页脚"
|
||||||
value={inputs.Footer}
|
value={inputs.Footer}
|
||||||
name="Footer"
|
name="Footer"
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
|
@ -31,7 +31,7 @@ const COPY_OPTIONS = [
|
|||||||
url: 'https://chat.oneapi.pro/#/?settings={"key":"sk-{key}","url":"{serverAddress}"}',
|
url: 'https://chat.oneapi.pro/#/?settings={"key":"sk-{key}","url":"{serverAddress}"}',
|
||||||
encode: false
|
encode: false
|
||||||
},
|
},
|
||||||
{ key: 'ama', text: 'AMA 问天', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true },
|
{ key: 'ama', text: 'BotGem', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true },
|
||||||
{ key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true }
|
{ key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true }
|
||||||
];
|
];
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ const COPY_OPTIONS = [
|
|||||||
];
|
];
|
||||||
|
|
||||||
const OPEN_LINK_OPTIONS = [
|
const OPEN_LINK_OPTIONS = [
|
||||||
{ key: 'ama', text: 'AMA 问天', value: 'ama' },
|
{ key: 'ama', text: 'BotGem', value: 'ama' },
|
||||||
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
|
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
|
||||||
];
|
];
|
||||||
|
|
||||||
|
@ -15,6 +15,8 @@ export const CHANNEL_OPTIONS = [
|
|||||||
{ key: 26, text: '百川大模型', value: 26, color: 'orange' },
|
{ key: 26, text: '百川大模型', value: 26, color: 'orange' },
|
||||||
{ key: 27, text: 'MiniMax', value: 27, color: 'red' },
|
{ key: 27, text: 'MiniMax', value: 27, color: 'red' },
|
||||||
{ key: 29, text: 'Groq', value: 29, color: 'orange' },
|
{ key: 29, text: 'Groq', value: 29, color: 'orange' },
|
||||||
|
{ key: 30, text: 'Ollama', value: 30, color: 'black' },
|
||||||
|
{ key: 31, text: '零一万物', value: 31, color: 'green' },
|
||||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||||
|
Loading…
Reference in New Issue
Block a user