refactor: refactor relay part

This commit is contained in:
JustSong
2024-01-21 22:56:20 +08:00
parent e2ed0399f0
commit b373882814
64 changed files with 1092 additions and 793 deletions

View File

@@ -91,16 +91,16 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var SyncFrequency = GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var BatchUpdateInterval = GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
var RelayTimeout = GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var GeminiSafetySetting = GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var Theme = GetOrDefaultString("THEME", "default")
var Theme = GetOrDefaultEnvString("THEME", "default")
var ValidThemes = map[string]bool{
"default": true,
"berry": true,
@@ -127,10 +127,10 @@ var (
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitNum = GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitNum = GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
@@ -199,29 +199,29 @@ const (
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
"", //24
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"https://generativelanguage.googleapis.com", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
}

View File

@@ -4,4 +4,4 @@ var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)
var SQLiteBusyTimeout = GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000)

View File

@@ -1,6 +1,9 @@
package common
import "encoding/json"
import (
"encoding/json"
"one-api/common/logger"
)
var GroupRatio = map[string]float64{
"default": 1,
@@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error {
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
if !ok {
SysError("group ratio not found: " + name)
logger.SysError("group ratio not found: " + name)
return 1
}
return ratio

View File

@@ -4,6 +4,7 @@ import (
"flag"
"fmt"
"log"
"one-api/common/logger"
"os"
"path/filepath"
)
@@ -37,7 +38,7 @@ func init() {
if os.Getenv("SESSION_SECRET") != "" {
if os.Getenv("SESSION_SECRET") == "random_string" {
SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
SessionSecret = os.Getenv("SESSION_SECRET")
}

View File

@@ -1,4 +1,4 @@
package common
package logger
import (
"context"
@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"io"
"log"
"one-api/common"
"os"
"path/filepath"
"sync"
@@ -25,7 +26,7 @@ var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
if *LogDir != "" {
if *common.LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
@@ -35,7 +36,7 @@ func SetupLogger() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
@@ -55,24 +56,36 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func LogInfo(ctx context.Context, msg string) {
func Info(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
func LogWarn(ctx context.Context, msg string) {
func Warn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
func LogError(ctx context.Context, msg string) {
func Error(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
func Infof(ctx context.Context, format string, a ...any) {
Info(ctx, fmt.Sprintf(format, a))
}
func Warnf(ctx context.Context, format string, a ...any) {
Warn(ctx, fmt.Sprintf(format, a))
}
func Errorf(ctx context.Context, format string, a ...any) {
Error(ctx, fmt.Sprintf(format, a))
}
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
id := ctx.Value(common.RequestIdKey)
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
@@ -92,8 +105,8 @@ func FatalLog(v ...any) {
}
func LogQuota(quota int) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/QuotaPerUnit)
if common.DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/common.QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", quota)
}

View File

@@ -2,6 +2,7 @@ package common
import (
"encoding/json"
"one-api/common/logger"
"strings"
"time"
)
@@ -107,7 +108,7 @@ var ModelRatio = map[string]float64{
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -123,7 +124,7 @@ func GetModelRatio(name string) float64 {
}
ratio, ok := ModelRatio[name]
if !ok {
SysError("model ratio not found: " + name)
logger.SysError("model ratio not found: " + name)
return 30
}
return ratio

View File

@@ -3,6 +3,7 @@ package common
import (
"context"
"github.com/go-redis/redis/v8"
"one-api/common/logger"
"os"
"time"
)
@@ -14,18 +15,18 @@ var RedisEnabled = true
func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" {
RedisEnabled = false
SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
return nil
}
if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false
SysLog("SYNC_FREQUENCY not set, Redis is disabled")
logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil
}
SysLog("Redis is enabled")
logger.SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
RDB = redis.NewClient(opt)
@@ -34,7 +35,7 @@ func InitRedisClient() (err error) {
_, err = RDB.Ping(ctx).Result()
if err != nil {
FatalLog("Redis ping test failed: " + err.Error())
logger.FatalLog("Redis ping test failed: " + err.Error())
}
return err
}
@@ -42,7 +43,7 @@ func InitRedisClient() (err error) {
func ParseRedisOption() *redis.Options {
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
return opt
}

View File

@@ -7,6 +7,7 @@ import (
"log"
"math/rand"
"net"
"one-api/common/logger"
"os"
"os/exec"
"runtime"
@@ -184,25 +185,32 @@ func Max(a int, b int) int {
}
}
func GetOrDefault(env string, defaultValue int) int {
func GetOrDefaultEnvInt(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultString(env string, defaultValue string) string {
func GetOrDefaultEnvString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func AssignOrDefault(value string, defaultValue string) string {
if len(value) != 0 {
return value
}
return defaultValue
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}