diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml index e81ab09f..4b5694e7 100644 --- a/.github/workflows/linux-release.yml +++ b/.github/workflows/linux-release.yml @@ -20,6 +20,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 0 + fi - uses: actions/setup-node@v3 with: node-version: 16 diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml index 13415276..8304de05 100644 --- a/.github/workflows/macos-release.yml +++ b/.github/workflows/macos-release.yml @@ -20,6 +20,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 0 + fi - uses: actions/setup-node@v3 with: node-version: 16 diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml index 8b1160b4..eb1cbe21 100644 --- a/.github/workflows/windows-release.yml +++ b/.github/workflows/windows-release.yml @@ -23,6 +23,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 0 + fi - uses: actions/setup-node@v3 with: node-version: 16 diff --git a/common/config/config.go b/common/config/config.go index 400f65c5..86446506 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/env" ) func init() { @@ -90,14 +90,14 @@ var MessagePusherToken = "" var TurnstileSiteKey = "" var TurnstileSecretKey = "" -var QuotaForNewUser = 0 -var QuotaForInviter = 0 -var QuotaForInvitee = 0 +var QuotaForNewUser int64 = 0 +var QuotaForInviter int64 = 0 +var QuotaForInvitee int64 = 0 var ChannelDisableThreshold = 5.0 var AutomaticDisableChannelEnabled = false var AutomaticEnableChannelEnabled = false -var QuotaRemindThreshold = 1000 -var PreConsumedQuota = 500 +var QuotaRemindThreshold int64 = 1000 +var PreConsumedQuota int64 = 500 var ApproximateTokenEnabled = false var RetryTimes = 0 @@ -108,17 +108,17 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var RequestInterval = time.Duration(requestInterval) * time.Second -var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second +var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second var BatchUpdateEnabled = false -var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) +var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5) -var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second -var IdleTimeout = helper.GetOrDefaultEnvInt("IDLE_TIMEOUT", 30) // unit is second +var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second +var IdleTimeout = env.Int("IDLE_TIMEOUT", 30) // unit is second -var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") +var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE") -var Theme = helper.GetOrDefaultEnvString("THEME", "default") +var Theme = env.String("THEME", "default") var ValidThemes = map[string]bool{ "default": true, "berry": true, @@ -127,10 +127,10 @@ var ValidThemes = map[string]bool{ // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( - GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) + GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitDuration int64 = 3 * 60 - GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) + GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 @@ -145,8 +145,8 @@ var ( var RateLimitKeyExpirationDuration = 20 * time.Minute -var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false) -var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10) -var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) -var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024) -var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128) +var EnableMetric = env.Bool("ENABLE_METRIC", false) +var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) +var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) +var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024) +var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) diff --git a/common/constants.go b/common/constants.go index de71bc7a..849bdce7 100644 --- a/common/constants.go +++ b/common/constants.go @@ -69,6 +69,8 @@ const ( ChannelTypeMinimax ChannelTypeMistral ChannelTypeGroq + ChannelTypeOllama + ChannelTypeLingYiWanWu ChannelTypeDummy ) @@ -104,6 +106,8 @@ var ChannelBaseURLs = []string{ "https://api.minimax.chat", // 27 "https://api.mistral.ai", // 28 "https://api.groq.com/openai", // 29 + "http://localhost:11434", // 30 + "https://api.lingyiwanwu.com", // 31 } const ( diff --git a/common/database.go b/common/database.go index df60bdd5..f2db759f 100644 --- a/common/database.go +++ b/common/database.go @@ -1,10 +1,12 @@ package common -import "github.com/songquanpeng/one-api/common/helper" +import ( + "github.com/songquanpeng/one-api/common/env" +) var UsingSQLite = false var UsingPostgreSQL = false var UsingMySQL = false var SQLitePath = "one-api.db" -var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) +var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/env/helper.go b/common/env/helper.go new file mode 100644 index 00000000..fdb9f827 --- /dev/null +++ b/common/env/helper.go @@ -0,0 +1,42 @@ +package env + +import ( + "os" + "strconv" +) + +func Bool(env string, defaultValue bool) bool { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) == "true" +} + +func Int(env string, defaultValue int) int { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.Atoi(os.Getenv(env)) + if err != nil { + return defaultValue + } + return num +} + +func Float64(env string, defaultValue float64) float64 { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.ParseFloat(os.Getenv(env), 64) + if err != nil { + return defaultValue + } + return num +} + +func String(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} diff --git a/common/helper/helper.go b/common/helper/helper.go index 23578842..db41ac74 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -3,12 +3,10 @@ package helper import ( "fmt" "github.com/google/uuid" - "github.com/songquanpeng/one-api/common/logger" "html/template" "log" "math/rand" "net" - "os" "os/exec" "runtime" "strconv" @@ -187,6 +185,10 @@ func GetTimeString() string { return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) } +func GenRequestID() string { + return GetTimeString() + GetRandomNumberString(8) +} + func Max(a int, b int) int { if a >= b { return a @@ -195,44 +197,6 @@ func Max(a int, b int) int { } } -func GetOrDefaultEnvBool(env string, defaultValue bool) bool { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) == "true" -} - -func GetOrDefaultEnvInt(env string, defaultValue int) int { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.Atoi(os.Getenv(env)) - if err != nil { - logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.ParseFloat(os.Getenv(env), 64) - if err != nil { - logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultEnvString(env string, defaultValue string) string { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) -} - func AssignOrDefault(value string, defaultValue string) string { if len(value) != 0 { return value diff --git a/common/logger/logger.go b/common/logger/logger.go index 41b98ca3..957d8a11 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" "io" "log" "os" @@ -54,7 +56,9 @@ func SysError(s string) { } func Debug(ctx context.Context, msg string) { - logHelper(ctx, loggerDEBUG, msg) + if config.DebugEnabled { + logHelper(ctx, loggerDEBUG, msg) + } } func Info(ctx context.Context, msg string) { @@ -91,6 +95,9 @@ func logHelper(ctx context.Context, level string, msg string) { writer = gin.DefaultWriter } id := ctx.Value(RequestIdKey) + if id == nil { + id = helper.GenRequestID() + } now := time.Now() _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) if !setupLogWorking { diff --git a/common/model-ratio.go b/common/model-ratio.go index 7de028be..546975a8 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -69,7 +69,7 @@ var ModelRatio = map[string]float64{ "claude-instant-1.2": 0.8 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD, "claude-2.1": 8.0 / 1000 * USD, - "claude-3-haiku-20240229": 0.25 / 1000 * USD, + "claude-3-haiku-20240307": 0.25 / 1000 * USD, "claude-3-sonnet-20240229": 3.0 / 1000 * USD, "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 @@ -78,6 +78,9 @@ var ModelRatio = map[string]float64{ "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens "ERNIE-Bot-8k": 0.024 * RMB, "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens + "bge-large-zh": 0.002 * RMB, + "bge-large-en": 0.002 * RMB, + "bge-large-8k": 0.002 * RMB, "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens @@ -130,6 +133,10 @@ var ModelRatio = map[string]float64{ "llama2-7b-2048": 0.1 / 1000 * USD, "mixtral-8x7b-32768": 0.27 / 1000 * USD, "gemma-7b-it": 0.1 / 1000 * USD, + // https://platform.lingyiwanwu.com/docs#-计费单元 + "yi-34b-chat-0205": 2.5 / 1000000 * RMB, + "yi-34b-chat-200k": 12.0 / 1000000 * RMB, + "yi-vl-plus": 6.0 / 1000000 * RMB, } var CompletionRatio = map[string]float64{} diff --git a/common/utils.go b/common/utils.go index 24615225..ecee2c8e 100644 --- a/common/utils.go +++ b/common/utils.go @@ -5,7 +5,7 @@ import ( "github.com/songquanpeng/one-api/common/config" ) -func LogQuota(quota int) string { +func LogQuota(quota int64) string { if config.DisplayInCurrencyEnabled { return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) } else { diff --git a/controller/billing.go b/controller/billing.go index 7317913d..dd518678 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -8,8 +8,8 @@ import ( ) func GetSubscription(c *gin.Context) { - var remainQuota int - var usedQuota int + var remainQuota int64 + var usedQuota int64 var err error var token *model.Token var expiredTime int64 @@ -60,7 +60,7 @@ func GetSubscription(c *gin.Context) { } func GetUsage(c *gin.Context) { - var quota int + var quota int64 var err error var token *model.Token if config.DisplayTokenStatEnabled { diff --git a/controller/channel-test.go b/controller/channel-test.go index e982bc71..5791e1c4 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -30,7 +30,7 @@ import ( func buildTestRequest() *relaymodel.GeneralOpenAIRequest { testRequest := &relaymodel.GeneralOpenAIRequest{ - MaxTokens: 1, + MaxTokens: 2, Stream: false, Model: "gpt-3.5-turbo", } diff --git a/controller/token.go b/controller/token.go index 1e24d741..a94617e1 100644 --- a/controller/token.go +++ b/controller/token.go @@ -234,18 +234,18 @@ func UpdateToken(c *gin.Context) { tokenInDB.ExpiredTime = *tokenPatch.ExpiredTime } if tokenPatch.RemainQuota != nil { - tokenInDB.RemainQuota = *tokenPatch.RemainQuota + tokenInDB.RemainQuota = int64(*tokenPatch.RemainQuota) } if tokenPatch.UnlimitedQuota != nil { tokenInDB.UnlimitedQuota = *tokenPatch.UnlimitedQuota } } - tokenInDB.RemainQuota -= tokenPatch.AddUsedQuota - tokenInDB.UsedQuota += tokenPatch.AddUsedQuota + tokenInDB.RemainQuota -= int64(tokenPatch.AddUsedQuota) + tokenInDB.UsedQuota += int64(tokenPatch.AddUsedQuota) if tokenPatch.AddUsedQuota != 0 { - model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("外部(%s)消耗 %s", tokenPatch.AddReason, common.LogQuota(tokenPatch.AddUsedQuota))) + model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("外部(%s)消耗 %s", tokenPatch.AddReason, common.LogQuota(int64(tokenPatch.AddUsedQuota)))) } if err = tokenInDB.Update(); err != nil { diff --git a/docker-compose.yml b/docker-compose.yml index 30edb281..1325a818 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.4' services: one-api: - image: justsong/one-api:latest + image: "${REGISTRY:-docker.io}/justsong/one-api:latest" container_name: one-api restart: always command: --log-dir /app/logs @@ -29,12 +29,12 @@ services: retries: 3 redis: - image: redis:latest + image: "${REGISTRY:-docker.io}/redis:latest" container_name: redis restart: always db: - image: mysql:8.2.0 + image: "${REGISTRY:-docker.io}/mysql:8.2.0" restart: always container_name: mysql volumes: diff --git a/main.go b/main.go index 2bb719f3..9c727152 100644 --- a/main.go +++ b/main.go @@ -32,11 +32,25 @@ func main() { if config.DebugEnabled { logger.SysLog("running in debug mode") } + var err error // Initialize SQL Database - err := model.InitDB() + model.DB, err = model.InitDB("SQL_DSN") if err != nil { logger.FatalLog("failed to initialize database: " + err.Error()) } + if os.Getenv("LOG_SQL_DSN") != "" { + logger.SysLog("using secondary database for table logs") + model.LOG_DB, err = model.InitDB("LOG_SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize secondary database: " + err.Error()) + } + } else { + model.LOG_DB = model.DB + } + err = model.CreateRootAccountIfNeed() + if err != nil { + logger.FatalLog("database init error: " + err.Error()) + } defer func() { err := model.CloseDB() if err != nil { diff --git a/middleware/request-id.go b/middleware/request-id.go index 234a93d8..a4c49ddb 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -9,7 +9,7 @@ import ( func RequestId() func(c *gin.Context) { return func(c *gin.Context) { - id := helper.GetTimeString() + helper.GetRandomNumberString(8) + id := helper.GenRequestID() c.Set(logger.RequestIdKey, id) ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) c.Request = c.Request.WithContext(ctx) diff --git a/model/cache.go b/model/cache.go index d0c74f52..e684e12d 100644 --- a/model/cache.go +++ b/model/cache.go @@ -1,6 +1,7 @@ package model import ( + "context" "encoding/json" "fmt" "github.com/Laisky/errors/v2" @@ -71,31 +72,42 @@ func CacheGetUserGroup(id int) (group string, err error) { return group, err } -func CacheGetUserQuota(id int) (quota int, err error) { +func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) { + quota, err = GetUserQuota(id) + if err != nil { + return 0, err + } + err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + if err != nil { + logger.Error(ctx, "Redis set user quota error: "+err.Error()) + } + return +} + +func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) { if !common.RedisEnabled { return GetUserQuota(id) } quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) if err != nil { - quota, err = GetUserQuota(id) - if err != nil { - return 0, err - } - err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) - if err != nil { - logger.SysError("Redis set user quota error: " + err.Error()) - } - return quota, err + return fetchAndUpdateUserQuota(ctx, id) } - quota, err = strconv.Atoi(quotaString) - return quota, err + quota, err = strconv.ParseInt(quotaString, 10, 64) + if err != nil { + return 0, nil + } + if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db + logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id) + return fetchAndUpdateUserQuota(ctx, id) + } + return quota, nil } -func CacheUpdateUserQuota(id int) error { +func CacheUpdateUserQuota(ctx context.Context, id int) error { if !common.RedisEnabled { return nil } - quota, err := CacheGetUserQuota(id) + quota, err := CacheGetUserQuota(ctx, id) if err != nil { return err } @@ -103,7 +115,7 @@ func CacheUpdateUserQuota(id int) error { return err } -func CacheDecreaseUserQuota(id int, quota int) error { +func CacheDecreaseUserQuota(id int, quota int64) error { if !common.RedisEnabled { return nil } diff --git a/model/channel.go b/model/channel.go index 605c6d17..fc4905b1 100644 --- a/model/channel.go +++ b/model/channel.go @@ -178,7 +178,7 @@ func UpdateChannelStatusById(id int, status int) { } } -func UpdateChannelUsedQuota(id int, quota int) { +func UpdateChannelUsedQuota(id int, quota int64) { if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) return @@ -186,7 +186,7 @@ func UpdateChannelUsedQuota(id int, quota int) { updateChannelUsedQuota(id, quota) } -func updateChannelUsedQuota(id int, quota int) { +func updateChannelUsedQuota(id int, quota int64) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { logger.SysError("failed to update channel used quota: " + err.Error()) diff --git a/model/log.go b/model/log.go index 9615c237..4409f73e 100644 --- a/model/log.go +++ b/model/log.go @@ -45,13 +45,13 @@ func RecordLog(userId int, logType int, content string) { Type: logType, Content: content, } - err := DB.Create(log).Error + err := LOG_DB.Create(log).Error if err != nil { logger.SysError("failed to record log: " + err.Error()) } } -func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { +func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !config.LogConsumeEnabled { return @@ -66,10 +66,10 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke CompletionTokens: completionTokens, TokenName: tokenName, ModelName: modelName, - Quota: quota, + Quota: int(quota), ChannelId: channelId, } - err := DB.Create(log).Error + err := LOG_DB.Create(log).Error if err != nil { logger.Error(ctx, "failed to record log: "+err.Error()) } @@ -78,9 +78,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { - tx = DB + tx = LOG_DB } else { - tx = DB.Where("type = ?", logType) + tx = LOG_DB.Where("type = ?", logType) } if modelName != "" { tx = tx.Where("model_name = ?", modelName) @@ -107,9 +107,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { - tx = DB.Where("user_id = ?", userId) + tx = LOG_DB.Where("user_id = ?", userId) } else { - tx = DB.Where("user_id = ? and type = ?", userId, logType) + tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType) } if modelName != "" { tx = tx.Where("model_name = ?", modelName) @@ -128,17 +128,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int } func SearchAllLogs(keyword string) (logs []*Log, err error) { - err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error + err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error return logs, err } func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { - err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error + err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error return logs, err } -func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { - tx := DB.Table("logs").Select("ifnull(sum(quota),0)") +func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { + tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") if username != "" { tx = tx.Where("username = ?", username) } @@ -162,7 +162,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa } func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { - tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") + tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") if username != "" { tx = tx.Where("username = ?", username) } @@ -183,7 +183,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa } func DeleteOldLog(targetTimestamp int64) (int64, error) { - result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) + result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) return result.RowsAffected, result.Error } @@ -207,7 +207,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" } - err = DB.Raw(` + err = LOG_DB.Raw(` SELECT `+groupSelect+`, model_name, count(1) as request_count, sum(quota) as quota, diff --git a/model/main.go b/model/main.go index bda8e943..0ef26c8d 100644 --- a/model/main.go +++ b/model/main.go @@ -9,6 +9,7 @@ import ( "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/env" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/driver/mysql" @@ -18,8 +19,9 @@ import ( ) var DB *gorm.DB +var LOG_DB *gorm.DB -func createRootAccountIfNeed() error { +func CreateRootAccountIfNeed() error { var user User //if user.Status != util.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { @@ -42,9 +44,9 @@ func createRootAccountIfNeed() error { return nil } -func chooseDB() (*gorm.DB, error) { - if os.Getenv("SQL_DSN") != "" { - dsn := os.Getenv("SQL_DSN") +func chooseDB(envName string) (*gorm.DB, error) { + if os.Getenv(envName) != "" { + dsn := os.Getenv(envName) if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL logger.SysLog("using PostgreSQL as database") @@ -72,23 +74,22 @@ func chooseDB() (*gorm.DB, error) { }) } -func InitDB() (err error) { - db, err := chooseDB() +func InitDB(envName string) (db *gorm.DB, err error) { + db, err = chooseDB(envName) if err == nil { if config.DebugSQLEnabled { db = db.Debug() } - DB = db - sqlDB, err := DB.DB() + sqlDB, err := db.DB() if err != nil { - return errors.WithStack(err) + return nil, err } - sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) + sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) if !config.IsMasterNode { - return nil + return db, err } if common.UsingMySQL { _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded @@ -96,46 +97,55 @@ func InitDB() (err error) { logger.SysLog("database migration started") err = db.AutoMigrate(&Channel{}) if err != nil { - return errors.WithStack(err) + return nil, err } err = db.AutoMigrate(&Token{}) if err != nil { - return errors.WithStack(err) + return nil, err } err = db.AutoMigrate(&User{}) if err != nil { - return errors.WithStack(err) + return nil, err } err = db.AutoMigrate(&Option{}) if err != nil { - return errors.WithStack(err) + return nil, err } err = db.AutoMigrate(&Redemption{}) if err != nil { - return errors.WithStack(err) + return nil, err } err = db.AutoMigrate(&Ability{}) if err != nil { - return errors.WithStack(err) + return nil, err } err = db.AutoMigrate(&Log{}) if err != nil { - return errors.WithStack(err) + return nil, err } logger.SysLog("database migrated") - err = createRootAccountIfNeed() - return errors.WithStack(err) + return db, err } else { logger.FatalLog(err) } - return errors.WithStack(err) + return db, err } -func CloseDB() error { - sqlDB, err := DB.DB() +func closeDB(db *gorm.DB) error { + sqlDB, err := db.DB() if err != nil { return errors.WithStack(err) } err = sqlDB.Close() return errors.WithStack(err) } + +func CloseDB() error { + if LOG_DB != DB { + err := closeDB(LOG_DB) + if err != nil { + return err + } + } + return closeDB(DB) +} diff --git a/model/option.go b/model/option.go index e129b9f0..1d1c28b4 100644 --- a/model/option.go +++ b/model/option.go @@ -61,11 +61,11 @@ func InitOptionMap() { config.OptionMap["MessagePusherToken"] = "" config.OptionMap["TurnstileSiteKey"] = "" config.OptionMap["TurnstileSecretKey"] = "" - config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) - config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) - config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) - config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) - config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) + config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10) + config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10) + config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) + config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) + config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() @@ -193,15 +193,15 @@ func updateOptionMap(key string, value string) (err error) { case "TurnstileSecretKey": config.TurnstileSecretKey = value case "QuotaForNewUser": - config.QuotaForNewUser, _ = strconv.Atoi(value) + config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64) case "QuotaForInviter": - config.QuotaForInviter, _ = strconv.Atoi(value) + config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64) case "QuotaForInvitee": - config.QuotaForInvitee, _ = strconv.Atoi(value) + config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64) case "QuotaRemindThreshold": - config.QuotaRemindThreshold, _ = strconv.Atoi(value) + config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64) case "PreConsumedQuota": - config.PreConsumedQuota, _ = strconv.Atoi(value) + config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64) case "RetryTimes": config.RetryTimes, _ = strconv.Atoi(value) case "ModelRatio": diff --git a/model/redemption.go b/model/redemption.go index b2493622..47c75d68 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -14,7 +14,7 @@ type Redemption struct { Key string `json:"key" gorm:"type:char(32);uniqueIndex"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` - Quota int `json:"quota" gorm:"default:100"` + Quota int64 `json:"quota" gorm:"default:100"` CreatedTime int64 `json:"created_time" gorm:"bigint"` RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"` Count int `json:"count" gorm:"-:all"` // only for api request @@ -42,7 +42,7 @@ func GetRedemptionById(id int) (*Redemption, error) { return &redemption, err } -func Redeem(key string, userId int) (quota int, err error) { +func Redeem(key string, userId int) (quota int64, err error) { if key == "" { return 0, errors.New("未提供兑换码") } diff --git a/model/token.go b/model/token.go index 68fbd847..fda4a563 100644 --- a/model/token.go +++ b/model/token.go @@ -21,9 +21,9 @@ type Token struct { CreatedTime int64 `json:"created_time" gorm:"bigint"` AccessedTime int64 `json:"accessed_time" gorm:"bigint"` ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired - RemainQuota int `json:"remain_quota" gorm:"default:0"` + RemainQuota int64 `json:"remain_quota" gorm:"default:0"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` - UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota + UsedQuota int64 `json:"used_quota" gorm:"default:0"` // used quota } func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { @@ -141,7 +141,7 @@ func DeleteTokenById(id int, userId int) (err error) { return token.Delete() } -func IncreaseTokenQuota(id int, quota int) (err error) { +func IncreaseTokenQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -152,7 +152,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { return increaseTokenQuota(id, quota) } -func increaseTokenQuota(id int, quota int) (err error) { +func increaseTokenQuota(id int, quota int64) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), @@ -163,7 +163,7 @@ func increaseTokenQuota(id int, quota int) (err error) { return err } -func DecreaseTokenQuota(id int, quota int) (err error) { +func DecreaseTokenQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -174,7 +174,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { return decreaseTokenQuota(id, quota) } -func decreaseTokenQuota(id int, quota int) (err error) { +func decreaseTokenQuota(id int, quota int64) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), @@ -185,7 +185,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { return err } -func PreConsumeTokenQuota(tokenId int, quota int) (err error) { +func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -235,7 +235,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { return err } -func PostConsumeTokenQuota(tokenId int, quota int) (err error) { +func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { token, err := GetTokenById(tokenId) if quota > 0 { err = DecreaseUserQuota(token.UserId, quota) diff --git a/model/user.go b/model/user.go index 7935823d..26279c5f 100644 --- a/model/user.go +++ b/model/user.go @@ -2,6 +2,8 @@ package model import ( "fmt" + "strings" + "github.com/Laisky/errors/v2" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/blacklist" @@ -9,7 +11,6 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" - "strings" ) // User if you add sensitive fields, don't forget to clean them in setupLogin function. @@ -26,10 +27,10 @@ type User struct { WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management - Quota int `json:"quota" gorm:"column:quota;type:int;default:0"` - UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota - RequestCount int `json:"request_count" gorm:"column:request_count;type:int;default:0;"` // request number - Group string `json:"group" gorm:"column:group;type:varchar(32);default:'default'"` + Quota int64 `json:"quota" gorm:"type:int;default:0"` + UsedQuota int64 `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota + RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number + Group string `json:"group" gorm:"type:varchar(32);default:'default'"` AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` } @@ -274,12 +275,12 @@ func ValidateAccessToken(token string) (user *User) { return nil } -func GetUserQuota(id int) (quota int, err error) { +func GetUserQuota(id int) (quota int64, err error) { err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error return quota, err } -func GetUserUsedQuota(id int) (quota int, err error) { +func GetUserUsedQuota(id int) (quota int64, err error) { err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error return quota, err } @@ -299,7 +300,7 @@ func GetUserGroup(id int) (group string, err error) { return group, err } -func IncreaseUserQuota(id int, quota int) (err error) { +func IncreaseUserQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -310,12 +311,12 @@ func IncreaseUserQuota(id int, quota int) (err error) { return increaseUserQuota(id, quota) } -func increaseUserQuota(id int, quota int) (err error) { +func increaseUserQuota(id int, quota int64) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error return err } -func DecreaseUserQuota(id int, quota int) (err error) { +func DecreaseUserQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -326,7 +327,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { return decreaseUserQuota(id, quota) } -func decreaseUserQuota(id int, quota int) (err error) { +func decreaseUserQuota(id int, quota int64) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error return err } @@ -336,7 +337,7 @@ func GetRootUserEmail() (email string) { return email } -func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { +func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) { if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUsedQuota, id, quota) addNewRecord(BatchUpdateTypeRequestCount, id, 1) @@ -345,7 +346,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { updateUserUsedQuotaAndRequestCount(id, quota, 1) } -func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { +func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), @@ -357,7 +358,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { } } -func updateUserUsedQuota(id int, quota int) { +func updateUserUsedQuota(id int, quota int64) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), diff --git a/model/utils.go b/model/utils.go index d481973a..a55eb4b6 100644 --- a/model/utils.go +++ b/model/utils.go @@ -16,12 +16,12 @@ const ( BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock ) -var batchUpdateStores []map[int]int +var batchUpdateStores []map[int]int64 var batchUpdateLocks []sync.Mutex func init() { for i := 0; i < BatchUpdateTypeCount; i++ { - batchUpdateStores = append(batchUpdateStores, make(map[int]int)) + batchUpdateStores = append(batchUpdateStores, make(map[int]int64)) batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) } } @@ -35,7 +35,7 @@ func InitBatchUpdater() { }() } -func addNewRecord(type_ int, id int, value int) { +func addNewRecord(type_ int, id int, value int64) { batchUpdateLocks[type_].Lock() defer batchUpdateLocks[type_].Unlock() if _, ok := batchUpdateStores[type_][id]; !ok { @@ -50,7 +50,7 @@ func batchUpdate() { for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() store := batchUpdateStores[i] - batchUpdateStores[i] = make(map[int]int) + batchUpdateStores[i] = make(map[int]int64) batchUpdateLocks[i].Unlock() // TODO: maybe we can combine updates with same key? for key, value := range store { @@ -68,7 +68,7 @@ func batchUpdate() { case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) case BatchUpdateTypeRequestCount: - updateUserRequestCount(key, value) + updateUserRequestCount(key, int(value)) case BatchUpdateTypeChannelUsedQuota: updateChannelUsedQuota(key, value) } diff --git a/relay/channel/anthropic/constants.go b/relay/channel/anthropic/constants.go index fcc0c2a3..cadcedc8 100644 --- a/relay/channel/anthropic/constants.go +++ b/relay/channel/anthropic/constants.go @@ -2,7 +2,7 @@ package anthropic var ModelList = []string{ "claude-instant-1.2", "claude-2.0", "claude-2.1", - "claude-3-haiku-20240229", + "claude-3-haiku-20240307", "claude-3-sonnet-20240229", "claude-3-opus-20240229", } diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go index 0fa8f2d6..45a4e901 100644 --- a/relay/channel/baidu/constants.go +++ b/relay/channel/baidu/constants.go @@ -7,4 +7,7 @@ var ModelList = []string{ "ERNIE-Speed", "ERNIE-Bot-turbo", "Embedding-V1", + "bge-large-zh", + "bge-large-en", + "tao-8k", } diff --git a/relay/channel/lingyiwanwu/constants.go b/relay/channel/lingyiwanwu/constants.go new file mode 100644 index 00000000..30000e9d --- /dev/null +++ b/relay/channel/lingyiwanwu/constants.go @@ -0,0 +1,9 @@ +package lingyiwanwu + +// https://platform.lingyiwanwu.com/docs + +var ModelList = []string{ + "yi-34b-chat-0205", + "yi-34b-chat-200k", + "yi-vl-plus", +} diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go new file mode 100644 index 00000000..06c66101 --- /dev/null +++ b/relay/channel/ollama/adaptor.go @@ -0,0 +1,65 @@ +package ollama + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" + "net/http" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *util.RelayMeta) { + +} + +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + // https://github.com/ollama/ollama/blob/main/docs/api.md + fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case constant.RelayModeEmbeddings: + return nil, errors.New("not supported") + default: + return ConvertRequest(*request), nil + } +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channel.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, resp) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "ollama" +} diff --git a/relay/channel/ollama/constants.go b/relay/channel/ollama/constants.go new file mode 100644 index 00000000..32f82b2a --- /dev/null +++ b/relay/channel/ollama/constants.go @@ -0,0 +1,5 @@ +package ollama + +var ModelList = []string{ + "qwen:0.5b-chat", +} diff --git a/relay/channel/ollama/main.go b/relay/channel/ollama/main.go new file mode 100644 index 00000000..7ec646a3 --- /dev/null +++ b/relay/channel/ollama/main.go @@ -0,0 +1,178 @@ +package ollama + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" +) + +func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { + ollamaRequest := ChatRequest{ + Model: request.Model, + Options: &Options{ + Seed: int(request.Seed), + Temperature: request.Temperature, + TopP: request.TopP, + FrequencyPenalty: request.FrequencyPenalty, + PresencePenalty: request.PresencePenalty, + }, + Stream: request.Stream, + } + for _, message := range request.Messages { + ollamaRequest.Messages = append(ollamaRequest.Messages, Message{ + Role: message.Role, + Content: message.StringContent(), + }) + } + return &ollamaRequest +} + +func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse { + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: response.Message.Role, + Content: response.Message.Content, + }, + } + if response.Done { + choice.FinishReason = "stop" + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + Usage: model.Usage{ + PromptTokens: response.PromptEvalCount, + CompletionTokens: response.EvalCount, + TotalTokens: response.PromptEvalCount + response.EvalCount, + }, + } + return &fullTextResponse +} + +func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Role = ollamaResponse.Message.Role + choice.Delta.Content = ollamaResponse.Message.Content + if ollamaResponse.Done { + choice.FinishReason = &constant.StopFinishReason + } + response := openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: ollamaResponse.Model, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage model.Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "}\n"); i >= 0 { + return i + 2, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := strings.TrimPrefix(scanner.Text(), "}") + dataChan <- data + "}" + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var ollamaResponse ChatResponse + err := json.Unmarshal([]byte(data), &ollamaResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if ollamaResponse.EvalCount != 0 { + usage.PromptTokens = ollamaResponse.PromptEvalCount + usage.CompletionTokens = ollamaResponse.EvalCount + usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount + } + response := streamResponseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + ctx := context.TODO() + var ollamaResponse ChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + logger.Debugf(ctx, "ollama response: %s", string(responseBody)) + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &ollamaResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if ollamaResponse.Error != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: ollamaResponse.Error, + Type: "ollama_error", + Param: "", + Code: "ollama_error", + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/relay/channel/ollama/model.go b/relay/channel/ollama/model.go new file mode 100644 index 00000000..a8ef1ffc --- /dev/null +++ b/relay/channel/ollama/model.go @@ -0,0 +1,37 @@ +package ollama + +type Options struct { + Seed int `json:"seed,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` +} + +type Message struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Images []string `json:"images,omitempty"` +} + +type ChatRequest struct { + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Stream bool `json:"stream"` + Options *Options `json:"options,omitempty"` +} + +type ChatResponse struct { + Model string `json:"model,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + Message Message `json:"message,omitempty"` + Response string `json:"response,omitempty"` // for stream response + Done bool `json:"done,omitempty"` + TotalDuration int `json:"total_duration,omitempty"` + LoadDuration int `json:"load_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration int `json:"eval_duration,omitempty"` + Error string `json:"error,omitempty"` +} diff --git a/relay/channel/openai/compatible.go b/relay/channel/openai/compatible.go index 767eec4b..e4951a34 100644 --- a/relay/channel/openai/compatible.go +++ b/relay/channel/openai/compatible.go @@ -5,6 +5,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/groq" + "github.com/songquanpeng/one-api/relay/channel/lingyiwanwu" "github.com/songquanpeng/one-api/relay/channel/minimax" "github.com/songquanpeng/one-api/relay/channel/mistral" "github.com/songquanpeng/one-api/relay/channel/moonshot" @@ -18,6 +19,7 @@ var CompatibleChannels = []int{ common.ChannelTypeMinimax, common.ChannelTypeMistral, common.ChannelTypeGroq, + common.ChannelTypeLingYiWanWu, } func GetCompatibleChannelMeta(channelType int) (string, []string) { @@ -36,6 +38,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { return "mistralai", mistral.ModelList case common.ChannelTypeGroq: return "groq", groq.ModelList + case common.ChannelTypeLingYiWanWu: + return "lingyiwanwu", lingyiwanwu.ModelList default: return "openai", ModelList } diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index d2184dac..b249f6a2 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -15,6 +15,7 @@ const ( APITypeAIProxyLibrary APITypeTencent APITypeGemini + APITypeOllama APITypeDummy // this one is only for count, do not add any channel after this ) @@ -40,6 +41,8 @@ func ChannelType2APIType(channelType int) int { apiType = APITypeTencent case common.ChannelTypeGemini: apiType = APITypeGemini + case common.ChannelTypeOllama: + apiType = APITypeOllama } return apiType } diff --git a/relay/controller/audio.go b/relay/controller/audio.go index beea6184..96f40e7f 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -23,6 +23,7 @@ import ( ) func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { + ctx := c.Request.Context() audioModel := "whisper-1" tokenId := c.GetInt("token_id") @@ -51,16 +52,16 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus // groupRatio := common.GetGroupRatio(group) groupRatio := c.GetFloat64("channel_ratio") ratio := modelRatio * groupRatio - var quota int - var preConsumedQuota int + var quota int64 + var preConsumedQuota int64 switch relayMode { case constant.RelayModeAudioSpeech: - preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) + preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota default: - preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio) + preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) } - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.CacheGetUserQuota(ctx, userId) if err != nil { return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } @@ -185,7 +186,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) } - quota = openai.CountTokenText(text, audioModel) + quota = int64(openai.CountTokenText(text, audioModel)) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { diff --git a/relay/controller/helper.go b/relay/controller/helper.go index a121a429..71dd653e 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -107,18 +107,18 @@ func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int return 0 } -func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int { +func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 { preConsumedTokens := config.PreConsumedQuota if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + textRequest.MaxTokens + preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens) } - return int(float64(preConsumedTokens) * ratio) + return int64(float64(preConsumedTokens) * ratio) } -func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) { +func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) { preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) - userQuota, err := model.CacheGetUserQuota(meta.UserId) + userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) if err != nil { return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } @@ -144,16 +144,16 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return } - quota := 0 + var quota int64 completionRatio := common.GetCompletionRatio(textRequest.Model) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens - quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) + quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) if ratio != 0 && quota <= 0 { quota = 1 } @@ -168,7 +168,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R if err != nil { logger.Error(ctx, "error consuming token remain quota: "+err.Error()) } - err = model.CacheUpdateUserQuota(meta.UserId) + err = model.CacheUpdateUserQuota(ctx, meta.UserId) if err != nil { logger.Error(ctx, "error update user quota cache: "+err.Error()) } diff --git a/relay/controller/image.go b/relay/controller/image.go index 5e946c81..f5e4e74f 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -81,9 +81,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus // groupRatio := common.GetGroupRatio(meta.Group) groupRatio := c.GetFloat64("channel_ratio") // pre-selected cheapest channel ratio ratio := modelRatio * groupRatio - userQuota, err := model.CacheGetUserQuota(meta.UserId) + userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) - quota := int(ratio*imageCostRatio*1000) * imageRequest.N + quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) if userQuota-quota < 0 { return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) @@ -127,7 +127,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(meta.UserId) + err = model.CacheUpdateUserQuota(ctx, meta.UserId) if err != nil { logger.SysError("error update user quota cache: " + err.Error()) } diff --git a/relay/controller/text.go b/relay/controller/text.go index 5dc07d23..282e8f25 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -77,6 +77,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { if err != nil { return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) } + logger.Debugf(ctx, "converted request: \n%s", string(jsonData)) requestBody = bytes.NewBuffer(jsonData) } diff --git a/relay/helper/main.go b/relay/helper/main.go index 6aa70e88..18bbe51a 100644 --- a/relay/helper/main.go +++ b/relay/helper/main.go @@ -5,6 +5,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/aiproxy" "github.com/songquanpeng/one-api/relay/channel/anthropic" "github.com/songquanpeng/one-api/relay/channel/gemini" + "github.com/songquanpeng/one-api/relay/channel/ollama" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/palm" "github.com/songquanpeng/one-api/relay/constant" @@ -26,12 +27,14 @@ func GetAdaptor(apiType int) channel.Adaptor { return &openai.Adaptor{} case constant.APITypePaLM: return &palm.Adaptor{} - // case constant.APITypeTencent: - // return &tencent.Adaptor{} - // case constant.APITypeXunfei: - // return &xunfei.Adaptor{} - // case constant.APITypeZhipu: - // return &zhipu.Adaptor{} + // case constant.APITypeTencent: + // return &tencent.Adaptor{} + // case constant.APITypeXunfei: + // return &xunfei.Adaptor{} + // case constant.APITypeZhipu: + // return &zhipu.Adaptor{} + case constant.APITypeOllama: + return &ollama.Adaptor{} } return nil } diff --git a/relay/util/billing.go b/relay/util/billing.go index 1e2b09ea..495d011e 100644 --- a/relay/util/billing.go +++ b/relay/util/billing.go @@ -6,7 +6,7 @@ import ( "github.com/songquanpeng/one-api/model" ) -func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) { +func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { if preConsumedQuota != 0 { go func(ctx context.Context) { // return pre-consumed quota diff --git a/relay/util/common.go b/relay/util/common.go index 416121e8..518d0b00 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -35,10 +35,17 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { return true case "permission_error": return true + case "forbidden": + return true } if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { return true } + if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic + return true + } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { + return true + } return false } @@ -148,20 +155,20 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin return fullRequestURL } -func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { +func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { // quotaDelta is remaining quota to be consumed err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(userId) + err = model.CacheUpdateUserQuota(ctx, userId) if err != nil { logger.SysError("error update user quota cache: " + err.Error()) } // totalQuota is total quota consumed if totalQuota >= 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) model.UpdateChannelUsedQuota(channelId, totalQuota) } diff --git a/router/dashboard.go b/router/dashboard.go index 0b539d44..5952d698 100644 --- a/router/dashboard.go +++ b/router/dashboard.go @@ -9,6 +9,7 @@ import ( func SetDashboardRouter(router *gin.Engine) { apiRouter := router.Group("/") + apiRouter.Use(middleware.CORS()) apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) apiRouter.Use(middleware.GlobalAPIRateLimit()) apiRouter.Use(middleware.TokenAuth()) diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index 8e9fc97c..06597b93 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -95,6 +95,18 @@ export const CHANNEL_OPTIONS = { value: 29, color: 'default' }, + 30: { + key: 30, + text: 'Ollama', + value: 30, + color: 'default' + }, + 31: { + key: 31, + text: '零一万物', + value: 31, + color: 'default' + }, 8: { key: 8, text: '自定义渠道', diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js index 897db189..8dfe77a4 100644 --- a/web/berry/src/views/Channel/type/Config.js +++ b/web/berry/src/views/Channel/type/Config.js @@ -166,6 +166,12 @@ const typeConfig = { 29: { modelGroup: "groq", }, + 30: { + modelGroup: "ollama", + }, + 31: { + modelGroup: "lingyiwanwu", + }, }; export { defaultConfig, typeConfig }; diff --git a/web/berry/src/views/Setting/component/OtherSetting.js b/web/berry/src/views/Setting/component/OtherSetting.js index 01f92f77..426b8c81 100644 --- a/web/berry/src/views/Setting/component/OtherSetting.js +++ b/web/berry/src/views/Setting/component/OtherSetting.js @@ -265,7 +265,7 @@ const OtherSetting = () => { multiline maxRows={15} id="Footer" - label="公告" + label="页脚" value={inputs.Footer} name="Footer" onChange={handleInputChange} diff --git a/web/berry/src/views/Token/component/TableRow.js b/web/berry/src/views/Token/component/TableRow.js index 19594b4c..2753764c 100644 --- a/web/berry/src/views/Token/component/TableRow.js +++ b/web/berry/src/views/Token/component/TableRow.js @@ -31,7 +31,7 @@ const COPY_OPTIONS = [ url: 'https://chat.oneapi.pro/#/?settings={"key":"sk-{key}","url":"{serverAddress}"}', encode: false }, - { key: 'ama', text: 'AMA 问天', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true }, + { key: 'ama', text: 'BotGem', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true }, { key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true } ]; diff --git a/web/default/src/components/TokensTable.js b/web/default/src/components/TokensTable.js index 295996aa..5170e765 100644 --- a/web/default/src/components/TokensTable.js +++ b/web/default/src/components/TokensTable.js @@ -12,7 +12,7 @@ const COPY_OPTIONS = [ ]; const OPEN_LINK_OPTIONS = [ - { key: 'ama', text: 'AMA 问天', value: 'ama' }, + { key: 'ama', text: 'BotGem', value: 'ama' }, { key: 'opencat', text: 'OpenCat', value: 'opencat' }, ]; diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index f6db46c3..8d536e58 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -15,6 +15,8 @@ export const CHANNEL_OPTIONS = [ { key: 26, text: '百川大模型', value: 26, color: 'orange' }, { key: 27, text: 'MiniMax', value: 27, color: 'red' }, { key: 29, text: 'Groq', value: 29, color: 'orange' }, + { key: 30, text: 'Ollama', value: 30, color: 'black' }, + { key: 31, text: '零一万物', value: 31, color: 'green' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },