diff --git a/common/constants.go b/common/constants.go index 9ee791df..a83de9c1 100644 --- a/common/constants.go +++ b/common/constants.go @@ -91,16 +91,16 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var RequestInterval = time.Duration(requestInterval) * time.Second -var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second +var SyncFrequency = GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second var BatchUpdateEnabled = false -var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) +var BatchUpdateInterval = GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) -var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second +var RelayTimeout = GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second -var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") +var GeminiSafetySetting = GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") -var Theme = GetOrDefaultString("THEME", "default") +var Theme = GetOrDefaultEnvString("THEME", "default") var ValidThemes = map[string]bool{ "default": true, "berry": true, @@ -127,10 +127,10 @@ var ( // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( - GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) + GlobalApiRateLimitNum = GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitDuration int64 = 3 * 60 - GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) + GlobalWebRateLimitNum = GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 @@ -199,29 +199,29 @@ const ( ) var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 - "https://fastgpt.run/api/openapi", // 22 - "https://hunyuan.cloud.tencent.com", //23 - "", //24 + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "https://generativelanguage.googleapis.com", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.cloud.tencent.com", // 23 + "https://generativelanguage.googleapis.com", // 24 } diff --git a/common/database.go b/common/database.go index 76f2cd55..8f659b57 100644 --- a/common/database.go +++ b/common/database.go @@ -4,4 +4,4 @@ var UsingSQLite = false var UsingPostgreSQL = false var SQLitePath = "one-api.db" -var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) +var SQLiteBusyTimeout = GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/group-ratio.go b/common/group-ratio.go index 1ec73c78..86e9e1f6 100644 --- a/common/group-ratio.go +++ b/common/group-ratio.go @@ -1,6 +1,9 @@ package common -import "encoding/json" +import ( + "encoding/json" + "one-api/common/logger" +) var GroupRatio = map[string]float64{ "default": 1, @@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{ func GroupRatio2JSONString() string { jsonBytes, err := json.Marshal(GroupRatio) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + logger.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error { func GetGroupRatio(name string) float64 { ratio, ok := GroupRatio[name] if !ok { - SysError("group ratio not found: " + name) + logger.SysError("group ratio not found: " + name) return 1 } return ratio diff --git a/common/init.go b/common/init.go index 12df5f51..9735c5b4 100644 --- a/common/init.go +++ b/common/init.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "log" + "one-api/common/logger" "os" "path/filepath" ) @@ -37,7 +38,7 @@ func init() { if os.Getenv("SESSION_SECRET") != "" { if os.Getenv("SESSION_SECRET") == "random_string" { - SysError("SESSION_SECRET is set to an example value, please change it to a random string.") + logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.") } else { SessionSecret = os.Getenv("SESSION_SECRET") } diff --git a/common/logger.go b/common/logger/logger.go similarity index 71% rename from common/logger.go rename to common/logger/logger.go index 61627217..4386bc6c 100644 --- a/common/logger.go +++ b/common/logger/logger.go @@ -1,4 +1,4 @@ -package common +package logger import ( "context" @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "io" "log" + "one-api/common" "os" "path/filepath" "sync" @@ -25,7 +26,7 @@ var setupLogLock sync.Mutex var setupLogWorking bool func SetupLogger() { - if *LogDir != "" { + if *common.LogDir != "" { ok := setupLogLock.TryLock() if !ok { log.Println("setup log is already working") @@ -35,7 +36,7 @@ func SetupLogger() { setupLogLock.Unlock() setupLogWorking = false }() - logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") @@ -55,24 +56,36 @@ func SysError(s string) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) } -func LogInfo(ctx context.Context, msg string) { +func Info(ctx context.Context, msg string) { logHelper(ctx, loggerINFO, msg) } -func LogWarn(ctx context.Context, msg string) { +func Warn(ctx context.Context, msg string) { logHelper(ctx, loggerWarn, msg) } -func LogError(ctx context.Context, msg string) { +func Error(ctx context.Context, msg string) { logHelper(ctx, loggerError, msg) } +func Infof(ctx context.Context, format string, a ...any) { + Info(ctx, fmt.Sprintf(format, a)) +} + +func Warnf(ctx context.Context, format string, a ...any) { + Warn(ctx, fmt.Sprintf(format, a)) +} + +func Errorf(ctx context.Context, format string, a ...any) { + Error(ctx, fmt.Sprintf(format, a)) +} + func logHelper(ctx context.Context, level string, msg string) { writer := gin.DefaultErrorWriter if level == loggerINFO { writer = gin.DefaultWriter } - id := ctx.Value(RequestIdKey) + id := ctx.Value(common.RequestIdKey) now := time.Now() _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) logCount++ // we don't need accurate count, so no lock here @@ -92,8 +105,8 @@ func FatalLog(v ...any) { } func LogQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) + if common.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit) } else { return fmt.Sprintf("%d 点额度", quota) } diff --git a/common/model-ratio.go b/common/model-ratio.go index 97cb060d..9f31e0d7 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -2,6 +2,7 @@ package common import ( "encoding/json" + "one-api/common/logger" "strings" "time" ) @@ -107,7 +108,7 @@ var ModelRatio = map[string]float64{ func ModelRatio2JSONString() string { jsonBytes, err := json.Marshal(ModelRatio) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + logger.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -123,7 +124,7 @@ func GetModelRatio(name string) float64 { } ratio, ok := ModelRatio[name] if !ok { - SysError("model ratio not found: " + name) + logger.SysError("model ratio not found: " + name) return 30 } return ratio diff --git a/common/redis.go b/common/redis.go index 12c477b8..ed3fcd9d 100644 --- a/common/redis.go +++ b/common/redis.go @@ -3,6 +3,7 @@ package common import ( "context" "github.com/go-redis/redis/v8" + "one-api/common/logger" "os" "time" ) @@ -14,18 +15,18 @@ var RedisEnabled = true func InitRedisClient() (err error) { if os.Getenv("REDIS_CONN_STRING") == "" { RedisEnabled = false - SysLog("REDIS_CONN_STRING not set, Redis is not enabled") + logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") return nil } if os.Getenv("SYNC_FREQUENCY") == "" { RedisEnabled = false - SysLog("SYNC_FREQUENCY not set, Redis is disabled") + logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") return nil } - SysLog("Redis is enabled") + logger.SysLog("Redis is enabled") opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) if err != nil { - FatalLog("failed to parse Redis connection string: " + err.Error()) + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) } RDB = redis.NewClient(opt) @@ -34,7 +35,7 @@ func InitRedisClient() (err error) { _, err = RDB.Ping(ctx).Result() if err != nil { - FatalLog("Redis ping test failed: " + err.Error()) + logger.FatalLog("Redis ping test failed: " + err.Error()) } return err } @@ -42,7 +43,7 @@ func InitRedisClient() (err error) { func ParseRedisOption() *redis.Options { opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) if err != nil { - FatalLog("failed to parse Redis connection string: " + err.Error()) + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) } return opt } diff --git a/common/utils.go b/common/utils.go index 9a7038e2..4e3312f9 100644 --- a/common/utils.go +++ b/common/utils.go @@ -7,6 +7,7 @@ import ( "log" "math/rand" "net" + "one-api/common/logger" "os" "os/exec" "runtime" @@ -184,25 +185,32 @@ func Max(a int, b int) int { } } -func GetOrDefault(env string, defaultValue int) int { +func GetOrDefaultEnvInt(env string, defaultValue int) int { if env == "" || os.Getenv(env) == "" { return defaultValue } num, err := strconv.Atoi(os.Getenv(env)) if err != nil { - SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) + logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) return defaultValue } return num } -func GetOrDefaultString(env string, defaultValue string) string { +func GetOrDefaultEnvString(env string, defaultValue string) string { if env == "" || os.Getenv(env) == "" { return defaultValue } return os.Getenv(env) } +func AssignOrDefault(value string, defaultValue string) string { + if len(value) != 0 { + return value + } + return defaultValue +} + func MessageWithRequestId(message string, id string) string { return fmt.Sprintf("%s (request id: %s)", message, id) } diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 29346cde..61a899a4 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "one-api/relay/util" "strconv" @@ -339,8 +340,8 @@ func UpdateAllChannelsBalance(c *gin.Context) { func AutomaticallyUpdateChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("updating all channels") + logger.SysLog("updating all channels") _ = updateAllChannelsBalance() - common.SysLog("channels update done") + logger.SysLog("channels update done") } } diff --git a/controller/channel-test.go b/controller/channel-test.go index f64f0ee3..73ff6bb2 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "one-api/relay/channel/openai" "one-api/relay/util" @@ -155,7 +156,7 @@ func notifyRootUser(subject string, content string) { } err := common.SendEmail(subject, common.RootUserEmail, content) if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } @@ -221,7 +222,7 @@ func testAllChannels(notify bool) error { if notify { err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } }() @@ -247,8 +248,8 @@ func TestAllChannels(c *gin.Context) { func AutomaticallyTestChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("testing all channels") + logger.SysLog("testing all channels") _ = testAllChannels(false) - common.SysLog("channel test finished") + logger.SysLog("channel test finished") } } diff --git a/controller/github.go b/controller/github.go index ee995379..68692b9d 100644 --- a/controller/github.go +++ b/controller/github.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "strconv" "time" @@ -46,7 +47,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { } res, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res.Body.Close() @@ -62,7 +63,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) res2, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res2.Body.Close() diff --git a/controller/relay.go b/controller/relay.go index 198d6c9a..e390ae75 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,43 +2,21 @@ package controller import ( "fmt" + "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" "one-api/relay/controller" "one-api/relay/util" "strconv" - "strings" - - "github.com/gin-gonic/gin" ) // https://platform.openai.com/docs/api-reference/chat func Relay(c *gin.Context) { - relayMode := constant.RelayModeUnknown - if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { - relayMode = constant.RelayModeChatCompletions - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { - relayMode = constant.RelayModeCompletions - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { - relayMode = constant.RelayModeEmbeddings - } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - relayMode = constant.RelayModeEmbeddings - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - relayMode = constant.RelayModeModerations - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - relayMode = constant.RelayModeImagesGenerations - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { - relayMode = constant.RelayModeEdits - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { - relayMode = constant.RelayModeAudioSpeech - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - relayMode = constant.RelayModeAudioTranscription - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - relayMode = constant.RelayModeAudioTranslation - } + relayMode := constant.Path2RelayMode(c.Request.URL.Path) var err *openai.ErrorWithStatusCode switch relayMode { case constant.RelayModeImagesGenerations: @@ -71,7 +49,7 @@ func Relay(c *gin.Context) { }) } channelId := c.GetInt("channel_id") - common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) + logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors if util.ShouldDisableChannel(&err.Error, err.StatusCode) { channelId := c.GetInt("channel_id") diff --git a/controller/user.go b/controller/user.go index 174300ed..d39bba3b 100644 --- a/controller/user.go +++ b/controller/user.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "strconv" "time" @@ -409,7 +410,7 @@ func UpdateUser(c *gin.Context) { return } if originUser.Quota != updatedUser.Quota { - model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) + model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota))) } c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/main.go b/main.go index 28a41287..9e4a88f2 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" "one-api/common" + "one-api/common/logger" "one-api/controller" "one-api/middleware" "one-api/model" @@ -20,42 +21,42 @@ import ( var buildFS embed.FS func main() { - common.SetupLogger() - common.SysLog(fmt.Sprintf("One API %s started", common.Version)) + logger.SetupLogger() + logger.SysLog(fmt.Sprintf("One API %s started", common.Version)) if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) } if common.DebugEnabled { - common.SysLog("running in debug mode") + logger.SysLog("running in debug mode") } // Initialize SQL Database err := model.InitDB() if err != nil { - common.FatalLog("failed to initialize database: " + err.Error()) + logger.FatalLog("failed to initialize database: " + err.Error()) } defer func() { err := model.CloseDB() if err != nil { - common.FatalLog("failed to close database: " + err.Error()) + logger.FatalLog("failed to close database: " + err.Error()) } }() // Initialize Redis err = common.InitRedisClient() if err != nil { - common.FatalLog("failed to initialize Redis: " + err.Error()) + logger.FatalLog("failed to initialize Redis: " + err.Error()) } // Initialize options model.InitOptionMap() - common.SysLog(fmt.Sprintf("using theme %s", common.Theme)) + logger.SysLog(fmt.Sprintf("using theme %s", common.Theme)) if common.RedisEnabled { // for compatibility with old versions common.MemoryCacheEnabled = true } if common.MemoryCacheEnabled { - common.SysLog("memory cache enabled") - common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) + logger.SysLog("memory cache enabled") + logger.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) model.InitChannelCache() } if common.MemoryCacheEnabled { @@ -65,20 +66,20 @@ func main() { if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) if err != nil { - common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) + logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) } go controller.AutomaticallyUpdateChannels(frequency) } if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) if err != nil { - common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) + logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) } go controller.AutomaticallyTestChannels(frequency) } if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { common.BatchUpdateEnabled = true - common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") model.InitBatchUpdater() } openai.InitTokenEncoders() @@ -101,6 +102,6 @@ func main() { } err = server.Run(":" + port) if err != nil { - common.FatalLog("failed to start HTTP server: " + err.Error()) + logger.FatalLog("failed to start HTTP server: " + err.Error()) } } diff --git a/middleware/distributor.go b/middleware/distributor.go index 81338130..6b607d68 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "strconv" "strings" @@ -69,7 +70,7 @@ func Distribute() func(c *gin.Context) { if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) if channel != nil { - common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" } abortWithMessage(c, http.StatusServiceUnavailable, message) diff --git a/middleware/recover.go b/middleware/recover.go index 8338a514..9d3edc27 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/common/logger" "runtime/debug" ) @@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - common.SysError(fmt.Sprintf("panic detected: %v", err)) - common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + logger.SysError(fmt.Sprintf("panic detected: %v", err)) + logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index 26688810..6f295864 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/common/logger" ) type turnstileCheckResponse struct { @@ -37,7 +38,7 @@ func TurnstileCheck() gin.HandlerFunc { "remoteip": {c.ClientIP()}, }) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc { var res turnstileCheckResponse err = json.NewDecoder(rawRes.Body).Decode(&res) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/middleware/utils.go b/middleware/utils.go index 536125cc..31620bf2 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -3,6 +3,7 @@ package middleware import ( "github.com/gin-gonic/gin" "one-api/common" + "one-api/common/logger" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { @@ -13,5 +14,5 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { }, }) c.Abort() - common.LogError(c.Request.Context(), message) + logger.Error(c.Request.Context(), message) } diff --git a/model/cache.go b/model/cache.go index c6d0c70a..eaed5bba 100644 --- a/model/cache.go +++ b/model/cache.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand" "one-api/common" + "one-api/common/logger" "sort" "strconv" "strings" @@ -42,7 +43,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { } err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set token error: " + err.Error()) + logger.SysError("Redis set token error: " + err.Error()) } return &token, nil } @@ -62,7 +63,7 @@ func CacheGetUserGroup(id int) (group string, err error) { } err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user group error: " + err.Error()) + logger.SysError("Redis set user group error: " + err.Error()) } } return group, err @@ -80,7 +81,7 @@ func CacheGetUserQuota(id int) (quota int, err error) { } err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user quota error: " + err.Error()) + logger.SysError("Redis set user quota error: " + err.Error()) } return quota, err } @@ -127,7 +128,7 @@ func CacheIsUserEnabled(userId int) (bool, error) { } err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user enabled error: " + err.Error()) + logger.SysError("Redis set user enabled error: " + err.Error()) } return userEnabled, err } @@ -178,13 +179,13 @@ func InitChannelCache() { channelSyncLock.Lock() group2model2channels = newGroup2model2channels channelSyncLock.Unlock() - common.SysLog("channels synced from database") + logger.SysLog("channels synced from database") } func SyncChannelCache(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing channels from database") + logger.SysLog("syncing channels from database") InitChannelCache() } } diff --git a/model/channel.go b/model/channel.go index 7e7b42e6..d89d1666 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,8 +1,11 @@ package model import ( + "encoding/json" + "fmt" "gorm.io/gorm" "one-api/common" + "one-api/common/logger" ) type Channel struct { @@ -86,11 +89,17 @@ func (channel *Channel) GetBaseURL() string { return *channel.BaseURL } -func (channel *Channel) GetModelMapping() string { - if channel.ModelMapping == nil { - return "" +func (channel *Channel) GetModelMapping() map[string]string { + if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" { + return nil } - return *channel.ModelMapping + modelMapping := make(map[string]string) + err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping) + if err != nil { + logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error())) + return nil + } + return modelMapping } func (channel *Channel) Insert() error { @@ -120,7 +129,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { ResponseTime: int(responseTime), }).Error if err != nil { - common.SysError("failed to update response time: " + err.Error()) + logger.SysError("failed to update response time: " + err.Error()) } } @@ -130,7 +139,7 @@ func (channel *Channel) UpdateBalance(balance float64) { Balance: balance, }).Error if err != nil { - common.SysError("failed to update balance: " + err.Error()) + logger.SysError("failed to update balance: " + err.Error()) } } @@ -147,11 +156,11 @@ func (channel *Channel) Delete() error { func UpdateChannelStatusById(id int, status int) { err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) if err != nil { - common.SysError("failed to update ability status: " + err.Error()) + logger.SysError("failed to update ability status: " + err.Error()) } err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error if err != nil { - common.SysError("failed to update channel status: " + err.Error()) + logger.SysError("failed to update channel status: " + err.Error()) } } @@ -166,7 +175,7 @@ func UpdateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { - common.SysError("failed to update channel used quota: " + err.Error()) + logger.SysError("failed to update channel used quota: " + err.Error()) } } diff --git a/model/log.go b/model/log.go index 06085acf..728c4b17 100644 --- a/model/log.go +++ b/model/log.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "one-api/common" + "one-api/common/logger" "gorm.io/gorm" ) @@ -44,12 +45,12 @@ func RecordLog(userId int, logType int, content string) { } err := DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + logger.SysError("failed to record log: " + err.Error()) } } func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { - common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) + logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !common.LogConsumeEnabled { return } @@ -68,7 +69,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke } err := DB.Create(log).Error if err != nil { - common.LogError(ctx, "failed to record log: "+err.Error()) + logger.Error(ctx, "failed to record log: "+err.Error()) } } diff --git a/model/main.go b/model/main.go index 9723e638..0b9c4f2b 100644 --- a/model/main.go +++ b/model/main.go @@ -7,6 +7,7 @@ import ( "gorm.io/driver/sqlite" "gorm.io/gorm" "one-api/common" + "one-api/common/logger" "os" "strings" "time" @@ -18,7 +19,7 @@ func createRootAccountIfNeed() error { var user User //if user.Status != util.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { - common.SysLog("no user exists, create a root user for you: username is root, password is 123456") + logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") hashedPassword, err := common.Password2Hash("123456") if err != nil { return err @@ -42,7 +43,7 @@ func chooseDB() (*gorm.DB, error) { dsn := os.Getenv("SQL_DSN") if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL - common.SysLog("using PostgreSQL as database") + logger.SysLog("using PostgreSQL as database") common.UsingPostgreSQL = true return gorm.Open(postgres.New(postgres.Config{ DSN: dsn, @@ -52,13 +53,13 @@ func chooseDB() (*gorm.DB, error) { }) } // Use MySQL - common.SysLog("using MySQL as database") + logger.SysLog("using MySQL as database") return gorm.Open(mysql.Open(dsn), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } // Use SQLite - common.SysLog("SQL_DSN not set, using SQLite as database") + logger.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ @@ -77,14 +78,14 @@ func InitDB() (err error) { if err != nil { return err } - sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) + sqlDB.SetMaxIdleConns(common.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(common.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) if !common.IsMasterNode { return nil } - common.SysLog("database migration started") + logger.SysLog("database migration started") err = db.AutoMigrate(&Channel{}) if err != nil { return err @@ -113,11 +114,11 @@ func InitDB() (err error) { if err != nil { return err } - common.SysLog("database migrated") + logger.SysLog("database migrated") err = createRootAccountIfNeed() return err } else { - common.FatalLog(err) + logger.FatalLog(err) } return err } diff --git a/model/option.go b/model/option.go index 20575c9a..80abff20 100644 --- a/model/option.go +++ b/model/option.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/common/logger" "strconv" "strings" "time" @@ -82,7 +83,7 @@ func loadOptionsFromDatabase() { for _, option := range options { err := updateOptionMap(option.Key, option.Value) if err != nil { - common.SysError("failed to update option map: " + err.Error()) + logger.SysError("failed to update option map: " + err.Error()) } } } @@ -90,7 +91,7 @@ func loadOptionsFromDatabase() { func SyncOptions(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing options from database") + logger.SysLog("syncing options from database") loadOptionsFromDatabase() } } diff --git a/model/redemption.go b/model/redemption.go index f16412b5..ba1e1077 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -5,6 +5,7 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/common/logger" ) type Redemption struct { @@ -75,7 +76,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return 0, errors.New("兑换失败," + err.Error()) } - RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) + RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", logger.LogQuota(redemption.Quota))) return redemption.Quota, nil } diff --git a/model/token.go b/model/token.go index 2e53ac0b..570de47d 100644 --- a/model/token.go +++ b/model/token.go @@ -5,6 +5,7 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/common/logger" ) type Token struct { @@ -39,7 +40,7 @@ func ValidateUserToken(key string) (token *Token, err error) { } token, err = CacheGetTokenByKey(key) if err != nil { - common.SysError("CacheGetTokenByKey failed: " + err.Error()) + logger.SysError("CacheGetTokenByKey failed: " + err.Error()) if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("无效的令牌") } @@ -58,7 +59,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExpired err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + logger.SysError("failed to update token status" + err.Error()) } } return nil, errors.New("该令牌已过期") @@ -69,7 +70,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExhausted err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + logger.SysError("failed to update token status" + err.Error()) } } return nil, errors.New("该令牌额度已用尽") @@ -202,7 +203,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { go func() { email, err := GetUserEmail(token.UserId) if err != nil { - common.SysError("failed to fetch user email: " + err.Error()) + logger.SysError("failed to fetch user email: " + err.Error()) } prompt := "您的额度即将用尽" if noMoreQuota { @@ -213,7 +214,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { err = common.SendEmail(prompt, email, fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) if err != nil { - common.SysError("failed to send email" + err.Error()) + logger.SysError("failed to send email" + err.Error()) } } }() diff --git a/model/user.go b/model/user.go index 1c2c0a75..17f94d9f 100644 --- a/model/user.go +++ b/model/user.go @@ -5,6 +5,7 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/common/logger" "strings" ) @@ -97,16 +98,16 @@ func (user *User) Insert(inviterId int) error { return result.Error } if common.QuotaForNewUser > 0 { - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser))) } if inviterId != 0 { if common.QuotaForInvitee > 0 { _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee))) } if common.QuotaForInviter > 0 { _ = IncreaseUserQuota(inviterId, common.QuotaForInviter) - RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) + RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter))) } } return nil @@ -232,7 +233,7 @@ func IsAdmin(userId int) bool { var user User err := DB.Where("id = ?", userId).Select("role").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) + logger.SysError("no such user " + err.Error()) return false } return user.Role >= common.RoleAdminUser @@ -341,7 +342,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota and request count: " + err.Error()) + logger.SysError("failed to update user used quota and request count: " + err.Error()) } } @@ -352,14 +353,14 @@ func updateUserUsedQuota(id int, quota int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota: " + err.Error()) + logger.SysError("failed to update user used quota: " + err.Error()) } } func updateUserRequestCount(id int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error if err != nil { - common.SysError("failed to update user request count: " + err.Error()) + logger.SysError("failed to update user request count: " + err.Error()) } } diff --git a/model/utils.go b/model/utils.go index 1c28340b..e4797a78 100644 --- a/model/utils.go +++ b/model/utils.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/common/logger" "sync" "time" ) @@ -45,7 +46,7 @@ func addNewRecord(type_ int, id int, value int) { } func batchUpdate() { - common.SysLog("batch update started") + logger.SysLog("batch update started") for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() store := batchUpdateStores[i] @@ -57,12 +58,12 @@ func batchUpdate() { case BatchUpdateTypeUserQuota: err := increaseUserQuota(key, value) if err != nil { - common.SysError("failed to batch update user quota: " + err.Error()) + logger.SysError("failed to batch update user quota: " + err.Error()) } case BatchUpdateTypeTokenQuota: err := increaseTokenQuota(key, value) if err != nil { - common.SysError("failed to batch update token quota: " + err.Error()) + logger.SysError("failed to batch update token quota: " + err.Error()) } case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) @@ -73,5 +74,5 @@ func batchUpdate() { } } } - common.SysLog("batch update finished") + logger.SysLog("batch update finished") } diff --git a/relay/channel/aiproxy/adaptor.go b/relay/channel/aiproxy/adaptor.go new file mode 100644 index 00000000..44b6f58d --- /dev/null +++ b/relay/channel/aiproxy/adaptor.go @@ -0,0 +1,22 @@ +package aiproxy + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/aiproxy/main.go b/relay/channel/aiproxy/main.go index bee4d9d3..63fef55e 100644 --- a/relay/channel/aiproxy/main.go +++ b/relay/channel/aiproxy/main.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" "strconv" @@ -122,7 +123,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus var AIProxyLibraryResponse LibraryStreamResponse err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } if len(AIProxyLibraryResponse.Documents) != 0 { @@ -131,7 +132,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -140,7 +141,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus response := documentsAIProxyLibrary(documents) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go new file mode 100644 index 00000000..49022cfc --- /dev/null +++ b/relay/channel/ali/adaptor.go @@ -0,0 +1,22 @@ +package ali + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index f45a515a..c5ada0d7 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/relay/channel/openai" "strings" ) @@ -185,7 +186,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus var aliResponse ChatResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } if aliResponse.Usage.OutputTokens != 0 { @@ -198,7 +199,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus //lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) diff --git a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go new file mode 100644 index 00000000..55577228 --- /dev/null +++ b/relay/channel/anthropic/adaptor.go @@ -0,0 +1,22 @@ +package anthropic + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/anthropic/main.go b/relay/channel/anthropic/main.go index a4272d7b..006779b2 100644 --- a/relay/channel/anthropic/main.go +++ b/relay/channel/anthropic/main.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/relay/channel/openai" "strings" ) @@ -125,7 +126,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus var claudeResponse Response err := json.Unmarshal([]byte(data), &claudeResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } responseText += claudeResponse.Completion @@ -134,7 +135,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus response.Created = createdTime jsonStr, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go new file mode 100644 index 00000000..498b664a --- /dev/null +++ b/relay/channel/baidu/adaptor.go @@ -0,0 +1,22 @@ +package baidu + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/baidu/main.go b/relay/channel/baidu/main.go index 47969492..f5b98155 100644 --- a/relay/channel/baidu/main.go +++ b/relay/channel/baidu/main.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" "one-api/relay/util" @@ -19,49 +20,49 @@ import ( // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 -type BaiduTokenResponse struct { +type TokenResponse struct { ExpiresIn int `json:"expires_in"` AccessToken string `json:"access_token"` } -type BaiduMessage struct { +type Message struct { Role string `json:"role"` Content string `json:"content"` } -type BaiduChatRequest struct { - Messages []BaiduMessage `json:"messages"` - Stream bool `json:"stream"` - UserId string `json:"user_id,omitempty"` +type ChatRequest struct { + Messages []Message `json:"messages"` + Stream bool `json:"stream"` + UserId string `json:"user_id,omitempty"` } -type BaiduError struct { +type Error struct { ErrorCode int `json:"error_code"` ErrorMsg string `json:"error_msg"` } var baiduTokenStore sync.Map -func ConvertRequest(request openai.GeneralOpenAIRequest) *BaiduChatRequest { - messages := make([]BaiduMessage, 0, len(request.Messages)) +func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { + messages := make([]Message, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { - messages = append(messages, BaiduMessage{ + messages = append(messages, Message{ Role: "user", Content: message.StringContent(), }) - messages = append(messages, BaiduMessage{ + messages = append(messages, Message{ Role: "assistant", Content: "Okay", }) } else { - messages = append(messages, BaiduMessage{ + messages = append(messages, Message{ Role: message.Role, Content: message.StringContent(), }) } } - return &BaiduChatRequest{ + return &ChatRequest{ Messages: messages, Stream: request.Stream, } @@ -160,7 +161,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus var baiduResponse ChatStreamResponse err := json.Unmarshal([]byte(data), &baiduResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } if baiduResponse.Usage.TotalTokens != 0 { @@ -171,7 +172,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus response := streamResponseBaidu2OpenAI(&baiduResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) diff --git a/relay/channel/baidu/model.go b/relay/channel/baidu/model.go index caaebafb..e182f5dd 100644 --- a/relay/channel/baidu/model.go +++ b/relay/channel/baidu/model.go @@ -13,7 +13,7 @@ type ChatResponse struct { IsTruncated bool `json:"is_truncated"` NeedClearHistory bool `json:"need_clear_history"` Usage openai.Usage `json:"usage"` - BaiduError + Error } type ChatStreamResponse struct { @@ -38,7 +38,7 @@ type EmbeddingResponse struct { Created int64 `json:"created"` Data []EmbeddingData `json:"data"` Usage openai.Usage `json:"usage"` - BaiduError + Error } type AccessToken struct { diff --git a/relay/channel/google/adaptor.go b/relay/channel/google/adaptor.go new file mode 100644 index 00000000..b328db32 --- /dev/null +++ b/relay/channel/google/adaptor.go @@ -0,0 +1,22 @@ +package google + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/google/gemini.go b/relay/channel/google/gemini.go index f49caadf..0f4e606c 100644 --- a/relay/channel/google/gemini.go +++ b/relay/channel/google/gemini.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/common/image" + "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" "strings" @@ -237,7 +238,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus } jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) diff --git a/relay/channel/google/palm.go b/relay/channel/google/palm.go index 77d8cbd6..c2518a07 100644 --- a/relay/channel/google/palm.go +++ b/relay/channel/google/palm.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" ) @@ -78,20 +79,20 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt go func() { responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.SysError("error reading stream response: " + err.Error()) + logger.SysError("error reading stream response: " + err.Error()) stopChan <- true return } err = resp.Body.Close() if err != nil { - common.SysError("error closing stream response: " + err.Error()) + logger.SysError("error closing stream response: " + err.Error()) stopChan <- true return } var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) stopChan <- true return } @@ -103,7 +104,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt } jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) stopChan <- true return } diff --git a/relay/channel/interface.go b/relay/channel/interface.go new file mode 100644 index 00000000..7a0fcbd3 --- /dev/null +++ b/relay/channel/interface.go @@ -0,0 +1,15 @@ +package channel + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor interface { + GetRequestURL() string + Auth(c *gin.Context) error + ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) + DoRequest(request *openai.GeneralOpenAIRequest) error + DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go new file mode 100644 index 00000000..cc302611 --- /dev/null +++ b/relay/channel/openai/adaptor.go @@ -0,0 +1,21 @@ +package openai + +import ( + "github.com/gin-gonic/gin" + "net/http" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*ErrorWithStatusCode, *Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go index 848a6fa4..5f464249 100644 --- a/relay/channel/openai/main.go +++ b/relay/channel/openai/main.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/relay/constant" "strings" ) @@ -46,7 +47,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWi var streamResponse ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &streamResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) continue // just ignore the error } for _, choice := range streamResponse.Choices { @@ -56,7 +57,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWi var streamResponse CompletionsStreamResponse err := json.Unmarshal([]byte(data), &streamResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) continue } for _, choice := range streamResponse.Choices { diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go index c831ce19..937fb424 100644 --- a/relay/channel/openai/model.go +++ b/relay/channel/openai/model.go @@ -207,6 +207,11 @@ type Usage struct { TotalTokens int `json:"total_tokens"` } +type UsageOrResponseText struct { + *Usage + ResponseText string +} + type Error struct { Message string `json:"message"` Type string `json:"type"` diff --git a/relay/channel/openai/token.go b/relay/channel/openai/token.go index 4b40b228..b398c220 100644 --- a/relay/channel/openai/token.go +++ b/relay/channel/openai/token.go @@ -7,6 +7,7 @@ import ( "math" "one-api/common" "one-api/common/image" + "one-api/common/logger" "strings" ) @@ -15,15 +16,15 @@ var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} var defaultTokenEncoder *tiktoken.Tiktoken func InitTokenEncoders() { - common.SysLog("initializing token encoders") + logger.SysLog("initializing token encoders") gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") if err != nil { - common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) + logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) } defaultTokenEncoder = gpt35TokenEncoder gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") if err != nil { - common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) + logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } for model, _ := range common.ModelRatio { if strings.HasPrefix(model, "gpt-3.5") { @@ -34,7 +35,7 @@ func InitTokenEncoders() { tokenEncoderMap[model] = nil } } - common.SysLog("token encoders initialized") + logger.SysLog("token encoders initialized") } func getTokenEncoder(model string) *tiktoken.Tiktoken { @@ -45,7 +46,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { if ok { tokenEncoder, err := tiktoken.EncodingForModel(model) if err != nil { - common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) + logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) tokenEncoder = defaultTokenEncoder } tokenEncoderMap[model] = tokenEncoder @@ -99,7 +100,7 @@ func CountTokenMessages(messages []Message, model string) int { } imageTokens, err := countImageTokens(url, detail) if err != nil { - common.SysError("error counting image tokens: " + err.Error()) + logger.SysError("error counting image tokens: " + err.Error()) } else { tokenNum += imageTokens } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go new file mode 100644 index 00000000..e9f86aff --- /dev/null +++ b/relay/channel/tencent/adaptor.go @@ -0,0 +1,22 @@ +package tencent + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go index 60e275a9..9203249a 100644 --- a/relay/channel/tencent/main.go +++ b/relay/channel/tencent/main.go @@ -12,6 +12,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" "sort" @@ -131,7 +132,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus var TencentResponse ChatResponse err := json.Unmarshal([]byte(data), &TencentResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } response := streamResponseTencent2OpenAI(&TencentResponse) @@ -140,7 +141,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus } jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go new file mode 100644 index 00000000..89e58485 --- /dev/null +++ b/relay/channel/xunfei/adaptor.go @@ -0,0 +1,22 @@ +package xunfei + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go index 1cc0b664..1c55cc09 100644 --- a/relay/channel/xunfei/main.go +++ b/relay/channel/xunfei/main.go @@ -12,6 +12,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" "strings" @@ -140,7 +141,7 @@ func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appI response := streamResponseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -215,20 +216,20 @@ func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, for { _, msg, err := conn.ReadMessage() if err != nil { - common.SysError("error reading stream response: " + err.Error()) + logger.SysError("error reading stream response: " + err.Error()) break } var response ChatResponse err = json.Unmarshal(msg, &response) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) break } dataChan <- response if response.Payload.Choices.Status == 2 { err := conn.Close() if err != nil { - common.SysError("error closing websocket connection: " + err.Error()) + logger.SysError("error closing websocket connection: " + err.Error()) } break } @@ -247,7 +248,7 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, } if apiVersion == "" { apiVersion = "v1.1" - common.SysLog("api_version not found, use default: " + apiVersion) + logger.SysLog("api_version not found, use default: " + apiVersion) } domain := "general" if apiVersion != "v1.1" { diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go new file mode 100644 index 00000000..6d901bc3 --- /dev/null +++ b/relay/channel/zhipu/adaptor.go @@ -0,0 +1,22 @@ +package zhipu + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/relay/channel/openai" +) + +type Adaptor struct { +} + +func (a *Adaptor) Auth(c *gin.Context) error { + return nil +} + +func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) { + return nil, nil, nil +} diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go index 3dc613a4..c818c80e 100644 --- a/relay/channel/zhipu/main.go +++ b/relay/channel/zhipu/main.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" "strings" @@ -34,7 +35,7 @@ func GetToken(apikey string) string { split := strings.Split(apikey, ".") if len(split) != 2 { - common.SysError("invalid zhipu key: " + apikey) + logger.SysError("invalid zhipu key: " + apikey) return "" } @@ -193,7 +194,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus response := streamResponseZhipu2OpenAI(data) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -202,13 +203,13 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus var zhipuResponse StreamMetaResponse err := json.Unmarshal([]byte(data), &zhipuResponse) if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) + logger.SysError("error unmarshalling stream response: " + err.Error()) return true } response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) jsonResponse, err := json.Marshal(response) if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) + logger.SysError("error marshalling stream response: " + err.Error()) return true } usage = zhipuUsage diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go new file mode 100644 index 00000000..658bfb90 --- /dev/null +++ b/relay/constant/api_type.go @@ -0,0 +1,69 @@ +package constant + +import ( + "one-api/common" +) + +const ( + APITypeOpenAI = iota + APITypeClaude + APITypePaLM + APITypeBaidu + APITypeZhipu + APITypeAli + APITypeXunfei + APITypeAIProxyLibrary + APITypeTencent + APITypeGemini +) + +func ChannelType2APIType(channelType int) int { + apiType := APITypeOpenAI + switch channelType { + case common.ChannelTypeAnthropic: + apiType = APITypeClaude + case common.ChannelTypeBaidu: + apiType = APITypeBaidu + case common.ChannelTypePaLM: + apiType = APITypePaLM + case common.ChannelTypeZhipu: + apiType = APITypeZhipu + case common.ChannelTypeAli: + apiType = APITypeAli + case common.ChannelTypeXunfei: + apiType = APITypeXunfei + case common.ChannelTypeAIProxyLibrary: + apiType = APITypeAIProxyLibrary + case common.ChannelTypeTencent: + apiType = APITypeTencent + case common.ChannelTypeGemini: + apiType = APITypeGemini + } + return apiType +} + +//func GetAdaptor(apiType int) channel.Adaptor { +// switch apiType { +// case APITypeOpenAI: +// return &openai.Adaptor{} +// case APITypeClaude: +// return &anthropic.Adaptor{} +// case APITypePaLM: +// return &google.Adaptor{} +// case APITypeZhipu: +// return &baidu.Adaptor{} +// case APITypeBaidu: +// return &baidu.Adaptor{} +// case APITypeAli: +// return &ali.Adaptor{} +// case APITypeXunfei: +// return &xunfei.Adaptor{} +// case APITypeAIProxyLibrary: +// return &aiproxy.Adaptor{} +// case APITypeTencent: +// return &tencent.Adaptor{} +// case APITypeGemini: +// return &google.Adaptor{} +// } +// return nil +//} diff --git a/relay/constant/common.go b/relay/constant/common.go new file mode 100644 index 00000000..b6606cc6 --- /dev/null +++ b/relay/constant/common.go @@ -0,0 +1,3 @@ +package constant + +var StopFinishReason = "stop" diff --git a/relay/constant/main.go b/relay/constant/main.go deleted file mode 100644 index b3aeaaff..00000000 --- a/relay/constant/main.go +++ /dev/null @@ -1,16 +0,0 @@ -package constant - -const ( - RelayModeUnknown = iota - RelayModeChatCompletions - RelayModeCompletions - RelayModeEmbeddings - RelayModeModerations - RelayModeImagesGenerations - RelayModeEdits - RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation -) - -var StopFinishReason = "stop" diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go new file mode 100644 index 00000000..5e2fe574 --- /dev/null +++ b/relay/constant/relay_mode.go @@ -0,0 +1,42 @@ +package constant + +import "strings" + +const ( + RelayModeUnknown = iota + RelayModeChatCompletions + RelayModeCompletions + RelayModeEmbeddings + RelayModeModerations + RelayModeImagesGenerations + RelayModeEdits + RelayModeAudioSpeech + RelayModeAudioTranscription + RelayModeAudioTranslation +) + +func Path2RelayMode(path string) int { + relayMode := RelayModeUnknown + if strings.HasPrefix(path, "/v1/chat/completions") { + relayMode = RelayModeChatCompletions + } else if strings.HasPrefix(path, "/v1/completions") { + relayMode = RelayModeCompletions + } else if strings.HasPrefix(path, "/v1/embeddings") { + relayMode = RelayModeEmbeddings + } else if strings.HasSuffix(path, "embeddings") { + relayMode = RelayModeEmbeddings + } else if strings.HasPrefix(path, "/v1/moderations") { + relayMode = RelayModeModerations + } else if strings.HasPrefix(path, "/v1/images/generations") { + relayMode = RelayModeImagesGenerations + } else if strings.HasPrefix(path, "/v1/edits") { + relayMode = RelayModeEdits + } else if strings.HasPrefix(path, "/v1/audio/speech") { + relayMode = RelayModeAudioSpeech + } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { + relayMode = RelayModeAudioTranscription + } else if strings.HasPrefix(path, "/v1/audio/translations") { + relayMode = RelayModeAudioTranslation + } + return relayMode +} diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 08d9af2a..d8a896de 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "one-api/relay/channel/openai" "one-api/relay/constant" @@ -102,7 +103,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiVersion := util.GetAPIVersion(c) + apiVersion := util.GetAzureAPIVersion(c) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) } @@ -191,7 +192,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode // negative means add quota back for token & user err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) if err != nil { - common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) + logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) } }() }(c.Request.Context()) diff --git a/relay/controller/image.go b/relay/controller/image.go index be5fc3dd..9502a4d7 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "one-api/relay/channel/openai" "one-api/relay/util" @@ -112,7 +113,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) if channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := util.GetAPIVersion(c) + apiVersion := util.GetAzureAPIVersion(c) // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) } @@ -175,11 +176,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode } err := model.PostConsumeTokenQuota(tokenId, quota) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + logger.SysError("error consuming token remain quota: " + err.Error()) } err = model.CacheUpdateUserQuota(userId) if err != nil { - common.SysError("error update user quota cache: " + err.Error()) + logger.SysError("error update user quota cache: " + err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") diff --git a/relay/controller/text.go b/relay/controller/text.go index b17ff950..968cc751 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -1,206 +1,46 @@ package controller import ( - "bytes" "context" - "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" - "io" "math" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" - "one-api/relay/channel/aiproxy" - "one-api/relay/channel/ali" - "one-api/relay/channel/anthropic" - "one-api/relay/channel/baidu" - "one-api/relay/channel/google" "one-api/relay/channel/openai" - "one-api/relay/channel/tencent" - "one-api/relay/channel/xunfei" - "one-api/relay/channel/zhipu" "one-api/relay/constant" "one-api/relay/util" "strings" ) -const ( - APITypeOpenAI = iota - APITypeClaude - APITypePaLM - APITypeBaidu - APITypeZhipu - APITypeAli - APITypeXunfei - APITypeAIProxyLibrary - APITypeTencent - APITypeGemini -) - func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - group := c.GetString("group") + ctx := c.Request.Context() + meta := util.GetRelayMeta(c) var textRequest openai.GeneralOpenAIRequest err := common.UnmarshalBodyReusable(c, &textRequest) if err != nil { return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } - if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { - return openai.ErrorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest) - } if relayMode == constant.RelayModeModerations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" { textRequest.Model = c.Param("model") } - // request validation - if textRequest.Model == "" { - return openai.ErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) + err = util.ValidateTextRequest(&textRequest, relayMode) + if err != nil { + return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) } - switch relayMode { - case constant.RelayModeCompletions: - if textRequest.Prompt == "" { - return openai.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) - } - case constant.RelayModeChatCompletions: - if textRequest.Messages == nil || len(textRequest.Messages) == 0 { - return openai.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) - } - case constant.RelayModeEmbeddings: - case constant.RelayModeModerations: - if textRequest.Input == "" { - return openai.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) - } - case constant.RelayModeEdits: - if textRequest.Instruction == "" { - return openai.ErrorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) - } - } - // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[textRequest.Model] != "" { - textRequest.Model = modelMap[textRequest.Model] - isModelMapped = true - } - } - apiType := APITypeOpenAI - switch channelType { - case common.ChannelTypeAnthropic: - apiType = APITypeClaude - case common.ChannelTypeBaidu: - apiType = APITypeBaidu - case common.ChannelTypePaLM: - apiType = APITypePaLM - case common.ChannelTypeZhipu: - apiType = APITypeZhipu - case common.ChannelTypeAli: - apiType = APITypeAli - case common.ChannelTypeXunfei: - apiType = APITypeXunfei - case common.ChannelTypeAIProxyLibrary: - apiType = APITypeAIProxyLibrary - case common.ChannelTypeTencent: - apiType = APITypeTencent - case common.ChannelTypeGemini: - apiType = APITypeGemini - } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) - switch apiType { - case APITypeOpenAI: - if channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api - apiVersion := util.GetAPIVersion(c) - requestURL := strings.Split(requestURL, "?")[0] - requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) - baseURL = c.GetString("base_url") - task := strings.TrimPrefix(requestURL, "/v1/") - model_ := textRequest.Model - model_ = strings.Replace(model_, ".", "", -1) - // https://github.com/songquanpeng/one-api/issues/67 - model_ = strings.TrimSuffix(model_, "-0301") - model_ = strings.TrimSuffix(model_, "-0314") - model_ = strings.TrimSuffix(model_, "-0613") - - requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - fullRequestURL = util.GetFullRequestURL(baseURL, requestURL, channelType) - } - case APITypeClaude: - fullRequestURL = "https://api.anthropic.com/v1/complete" - if baseURL != "" { - fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) - } - case APITypeBaidu: - switch textRequest.Model { - case "ERNIE-Bot": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" - case "ERNIE-Bot-turbo": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" - case "ERNIE-Bot-4": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" - case "BLOOMZ-7B": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" - case "Embedding-V1": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" - } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - var err error - if apiKey, err = baidu.GetAccessToken(apiKey); err != nil { - return openai.ErrorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) - } - fullRequestURL += "?access_token=" + apiKey - case APITypePaLM: - fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" - if baseURL != "" { - fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) - } - case APITypeGemini: - requestBaseURL := "https://generativelanguage.googleapis.com" - if baseURL != "" { - requestBaseURL = baseURL - } - version := "v1" - if c.GetString("api_version") != "" { - version = c.GetString("api_version") - } - action := "generateContent" - if textRequest.Stream { - action = "streamGenerateContent" - } - fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) - case APITypeZhipu: - method := "invoke" - if textRequest.Stream { - method = "sse-invoke" - } - fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) - case APITypeAli: - fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" - if relayMode == constant.RelayModeEmbeddings { - fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" - } - case APITypeTencent: - fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" - case APITypeAIProxyLibrary: - fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) + var isModelMapped bool + textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping) + apiType := constant.ChannelType2APIType(meta.ChannelType) + fullRequestURL, err := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest) + if err != nil { + logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error())) + return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError) } var promptTokens int var completionTokens int @@ -217,17 +57,17 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode preConsumedTokens = promptTokens + textRequest.MaxTokens } modelRatio := common.GetModelRatio(textRequest.Model) - groupRatio := common.GetGroupRatio(group) + groupRatio := common.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio preConsumedQuota := int(float64(preConsumedTokens) * ratio) - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.CacheGetUserQuota(meta.UserId) if err != nil { return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } if userQuota-preConsumedQuota < 0 { return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota) if err != nil { return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) } @@ -235,165 +75,28 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 - common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) + logger.Info(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota)) } if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota) if err != nil { return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } - var requestBody io.Reader - if isModelMapped { - jsonStr, err := json.Marshal(textRequest) - if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body + requestBody, err := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode) + if err != nil { + return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError) } - switch apiType { - case APITypeClaude: - claudeRequest := anthropic.ConvertRequest(textRequest) - jsonStr, err := json.Marshal(claudeRequest) - if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeBaidu: - var jsonData []byte - var err error - switch relayMode { - case constant.RelayModeEmbeddings: - baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest) - jsonData, err = json.Marshal(baiduEmbeddingRequest) - default: - baiduRequest := baidu.ConvertRequest(textRequest) - jsonData, err = json.Marshal(baiduRequest) - } - if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonData) - case APITypePaLM: - palmRequest := google.ConvertPaLMRequest(textRequest) - jsonStr, err := json.Marshal(palmRequest) - if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeGemini: - geminiChatRequest := google.ConvertGeminiRequest(textRequest) - jsonStr, err := json.Marshal(geminiChatRequest) - if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeZhipu: - zhipuRequest := zhipu.ConvertRequest(textRequest) - jsonStr, err := json.Marshal(zhipuRequest) - if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeAli: - var jsonStr []byte - var err error - switch relayMode { - case constant.RelayModeEmbeddings: - aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest) - jsonStr, err = json.Marshal(aliEmbeddingRequest) - default: - aliRequest := ali.ConvertRequest(textRequest) - jsonStr, err = json.Marshal(aliRequest) - } - if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeTencent: - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - appId, secretId, secretKey, err := tencent.ParseConfig(apiKey) - if err != nil { - return openai.ErrorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) - } - tencentRequest := tencent.ConvertRequest(textRequest) - tencentRequest.AppId = appId - tencentRequest.SecretId = secretId - jsonStr, err := json.Marshal(tencentRequest) - if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - sign := tencent.GetSign(*tencentRequest, secretKey) - c.Request.Header.Set("Authorization", sign) - requestBody = bytes.NewBuffer(jsonStr) - case APITypeAIProxyLibrary: - aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest) - aiProxyLibraryRequest.LibraryId = c.GetString("library_id") - jsonStr, err := json.Marshal(aiProxyLibraryRequest) - if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } - var req *http.Request var resp *http.Response isStream := textRequest.Stream - if apiType != APITypeXunfei { // cause xunfei use websocket + if apiType != constant.APITypeXunfei { // cause xunfei use websocket req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - switch apiType { - case APITypeOpenAI: - if channelType == common.ChannelTypeAzure { - req.Header.Set("api-key", apiKey) - } else { - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - if channelType == common.ChannelTypeOpenRouter { - req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") - req.Header.Set("X-Title", "One API") - } - } - case APITypeClaude: - req.Header.Set("x-api-key", apiKey) - anthropicVersion := c.Request.Header.Get("anthropic-version") - if anthropicVersion == "" { - anthropicVersion = "2023-06-01" - } - req.Header.Set("anthropic-version", anthropicVersion) - case APITypeZhipu: - token := zhipu.GetToken(apiKey) - req.Header.Set("Authorization", token) - case APITypeAli: - req.Header.Set("Authorization", "Bearer "+apiKey) - if textRequest.Stream { - req.Header.Set("X-DashScope-SSE", "enable") - } - if c.GetString("plugin") != "" { - req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) - } - case APITypeTencent: - req.Header.Set("Authorization", apiKey) - case APITypePaLM: - req.Header.Set("x-goog-api-key", apiKey) - case APITypeGemini: - req.Header.Set("x-goog-api-key", apiKey) - default: - req.Header.Set("Authorization", "Bearer "+apiKey) - } - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - if isStream && c.Request.Header.Get("Accept") == "" { - req.Header.Set("Accept", "text/event-stream") - } - //req.Header.Set("Connection", c.Request.Header.Get("Connection")) + SetupRequestHeaders(c, req, apiType, meta, isStream) resp, err = util.HTTPClient.Do(req) if err != nil { return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) @@ -409,29 +112,31 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") if resp.StatusCode != http.StatusOK { - if preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } + util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return util.RelayErrorHandler(resp) } } - var textResponse openai.SlimTextResponse - tokenName := c.GetString("token_name") + var respErr *openai.ErrorWithStatusCode + var usage *openai.Usage defer func(ctx context.Context) { - // c.Writer.Flush() + // Why we use defer here? Because if error happened, we will have to return the pre-consumed quota. + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + return + } + if usage == nil { + logger.Error(ctx, "usage is nil, which is unexpected") + return + } + go func() { quota := 0 completionRatio := common.GetCompletionRatio(textRequest.Model) - promptTokens = textResponse.Usage.PromptTokens - completionTokens = textResponse.Usage.CompletionTokens + promptTokens = usage.PromptTokens + completionTokens = usage.CompletionTokens quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) if ratio != 0 && quota <= 0 { quota = 1 @@ -443,239 +148,25 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode quota = 0 } quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta) if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + logger.Error(ctx, "error consuming token remain quota: "+err.Error()) } - err = model.CacheUpdateUserQuota(userId) + err = model.CacheUpdateUserQuota(meta.UserId) if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) + logger.Error(ctx, "error update user quota cache: "+err.Error()) } if quota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) + model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) + model.UpdateChannelUsedQuota(meta.ChannelId, quota) } - }() - }(c.Request.Context()) - switch apiType { - case APITypeOpenAI: - if isStream { - err, responseText := openai.StreamHandler(c, resp, relayMode) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := openai.Handler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeClaude: - if isStream { - err, responseText := anthropic.StreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := anthropic.Handler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeBaidu: - if isStream { - err, usage := baidu.StreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - var err *openai.ErrorWithStatusCode - var usage *openai.Usage - switch relayMode { - case constant.RelayModeEmbeddings: - err, usage = baidu.EmbeddingHandler(c, resp) - default: - err, usage = baidu.Handler(c, resp) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypePaLM: - if textRequest.Stream { // PaLM2 API does not support stream - err, responseText := google.PaLMStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := google.PaLMHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeGemini: - if textRequest.Stream { - err, responseText := google.StreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := google.GeminiHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeZhipu: - if isStream { - err, usage := zhipu.StreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - // zhipu's API does not return prompt tokens & completion tokens - textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens - return nil - } else { - err, usage := zhipu.Handler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - // zhipu's API does not return prompt tokens & completion tokens - textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens - return nil - } - case APITypeAli: - if isStream { - err, usage := ali.StreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - var err *openai.ErrorWithStatusCode - var usage *openai.Usage - switch relayMode { - case constant.RelayModeEmbeddings: - err, usage = ali.EmbeddingHandler(c, resp) - default: - err, usage = ali.Handler(c, resp) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeXunfei: - auth := c.Request.Header.Get("Authorization") - auth = strings.TrimPrefix(auth, "Bearer ") - splits := strings.Split(auth, "|") - if len(splits) != 3 { - return openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) - } - var err *openai.ErrorWithStatusCode - var usage *openai.Usage - if isStream { - err, usage = xunfei.StreamHandler(c, textRequest, splits[0], splits[1], splits[2]) - } else { - err, usage = xunfei.Handler(c, textRequest, splits[0], splits[1], splits[2]) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - case APITypeAIProxyLibrary: - if isStream { - err, usage := aiproxy.StreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - err, usage := aiproxy.Handler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeTencent: - if isStream { - err, responseText := tencent.StreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := tencent.Handler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - default: - return openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) + }(ctx) + usage, respErr = DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens) + if respErr != nil { + return respErr } + return nil } diff --git a/relay/controller/util.go b/relay/controller/util.go new file mode 100644 index 00000000..cdb10dbf --- /dev/null +++ b/relay/controller/util.go @@ -0,0 +1,336 @@ +package controller + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/relay/channel/aiproxy" + "one-api/relay/channel/ali" + "one-api/relay/channel/anthropic" + "one-api/relay/channel/baidu" + "one-api/relay/channel/google" + "one-api/relay/channel/openai" + "one-api/relay/channel/tencent" + "one-api/relay/channel/xunfei" + "one-api/relay/channel/zhipu" + "one-api/relay/constant" + "one-api/relay/util" + "strings" +) + +func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) { + fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) + switch apiType { + case constant.APITypeOpenAI: + if meta.ChannelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api + requestURL := strings.Split(requestURL, "?")[0] + requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) + task := strings.TrimPrefix(requestURL, "/v1/") + model_ := textRequest.Model + model_ = strings.Replace(model_, ".", "", -1) + // https://github.com/songquanpeng/one-api/issues/67 + model_ = strings.TrimSuffix(model_, "-0301") + model_ = strings.TrimSuffix(model_, "-0314") + model_ = strings.TrimSuffix(model_, "-0613") + + requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) + fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) + } + case constant.APITypeClaude: + fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL) + case constant.APITypeBaidu: + switch textRequest.Model { + case "ERNIE-Bot": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" + case "ERNIE-Bot-turbo": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" + case "ERNIE-Bot-4": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" + case "BLOOMZ-7B": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" + case "Embedding-V1": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" + } + var accessToken string + var err error + if accessToken, err = baidu.GetAccessToken(meta.APIKey); err != nil { + return "", fmt.Errorf("failed to get baidu access token: %w", err) + } + fullRequestURL += "?access_token=" + accessToken + case constant.APITypePaLM: + fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL) + case constant.APITypeGemini: + version := common.AssignOrDefault(meta.APIVersion, "v1") + action := "generateContent" + if textRequest.Stream { + action = "streamGenerateContent" + } + fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, textRequest.Model, action) + case constant.APITypeZhipu: + method := "invoke" + if textRequest.Stream { + method = "sse-invoke" + } + fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) + case constant.APITypeAli: + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + if relayMode == constant.RelayModeEmbeddings { + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" + } + case constant.APITypeTencent: + fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" + case constant.APITypeAIProxyLibrary: + fullRequestURL = fmt.Sprintf("%s/api/library/ask", meta.BaseURL) + } + return fullRequestURL, nil +} + +func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) { + var requestBody io.Reader + if isModelMapped { + jsonStr, err := json.Marshal(textRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + } else { + requestBody = c.Request.Body + } + switch apiType { + case constant.APITypeClaude: + claudeRequest := anthropic.ConvertRequest(textRequest) + jsonStr, err := json.Marshal(claudeRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeBaidu: + var jsonData []byte + var err error + switch relayMode { + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest) + jsonData, err = json.Marshal(baiduEmbeddingRequest) + default: + baiduRequest := baidu.ConvertRequest(textRequest) + jsonData, err = json.Marshal(baiduRequest) + } + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonData) + case constant.APITypePaLM: + palmRequest := google.ConvertPaLMRequest(textRequest) + jsonStr, err := json.Marshal(palmRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeGemini: + geminiChatRequest := google.ConvertGeminiRequest(textRequest) + jsonStr, err := json.Marshal(geminiChatRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeZhipu: + zhipuRequest := zhipu.ConvertRequest(textRequest) + jsonStr, err := json.Marshal(zhipuRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeAli: + var jsonStr []byte + var err error + switch relayMode { + case constant.RelayModeEmbeddings: + aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest) + jsonStr, err = json.Marshal(aliEmbeddingRequest) + default: + aliRequest := ali.ConvertRequest(textRequest) + jsonStr, err = json.Marshal(aliRequest) + } + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeTencent: + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + appId, secretId, secretKey, err := tencent.ParseConfig(apiKey) + if err != nil { + return nil, err + } + tencentRequest := tencent.ConvertRequest(textRequest) + tencentRequest.AppId = appId + tencentRequest.SecretId = secretId + jsonStr, err := json.Marshal(tencentRequest) + if err != nil { + return nil, err + } + sign := tencent.GetSign(*tencentRequest, secretKey) + c.Request.Header.Set("Authorization", sign) + requestBody = bytes.NewBuffer(jsonStr) + case constant.APITypeAIProxyLibrary: + aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest) + aiProxyLibraryRequest.LibraryId = c.GetString("library_id") + jsonStr, err := json.Marshal(aiProxyLibraryRequest) + if err != nil { + return nil, err + } + requestBody = bytes.NewBuffer(jsonStr) + } + return requestBody, nil +} + +func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) { + SetupAuthHeaders(c, req, apiType, meta, isStream) + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + if isStream && c.Request.Header.Get("Accept") == "" { + req.Header.Set("Accept", "text/event-stream") + } +} + +func SetupAuthHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) { + apiKey := meta.APIKey + switch apiType { + case constant.APITypeOpenAI: + if meta.ChannelType == common.ChannelTypeAzure { + req.Header.Set("api-key", apiKey) + } else { + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + if meta.ChannelType == common.ChannelTypeOpenRouter { + req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") + req.Header.Set("X-Title", "One API") + } + } + case constant.APITypeClaude: + req.Header.Set("x-api-key", apiKey) + anthropicVersion := c.Request.Header.Get("anthropic-version") + if anthropicVersion == "" { + anthropicVersion = "2023-06-01" + } + req.Header.Set("anthropic-version", anthropicVersion) + case constant.APITypeZhipu: + token := zhipu.GetToken(apiKey) + req.Header.Set("Authorization", token) + case constant.APITypeAli: + req.Header.Set("Authorization", "Bearer "+apiKey) + if isStream { + req.Header.Set("X-DashScope-SSE", "enable") + } + if c.GetString("plugin") != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) + } + case constant.APITypeTencent: + req.Header.Set("Authorization", apiKey) + case constant.APITypePaLM: + req.Header.Set("x-goog-api-key", apiKey) + case constant.APITypeGemini: + req.Header.Set("x-goog-api-key", apiKey) + default: + req.Header.Set("Authorization", "Bearer "+apiKey) + } +} + +func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) { + var responseText string + switch apiType { + case constant.APITypeOpenAI: + if isStream { + err, responseText = openai.StreamHandler(c, resp, relayMode) + } else { + err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model) + } + case constant.APITypeClaude: + if isStream { + err, responseText = anthropic.StreamHandler(c, resp) + } else { + err, usage = anthropic.Handler(c, resp, promptTokens, textRequest.Model) + } + case constant.APITypeBaidu: + if isStream { + err, usage = baidu.StreamHandler(c, resp) + } else { + switch relayMode { + case constant.RelayModeEmbeddings: + err, usage = baidu.EmbeddingHandler(c, resp) + default: + err, usage = baidu.Handler(c, resp) + } + } + case constant.APITypePaLM: + if isStream { // PaLM2 API does not support stream + err, responseText = google.PaLMStreamHandler(c, resp) + } else { + err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model) + } + case constant.APITypeGemini: + if isStream { + err, responseText = google.StreamHandler(c, resp) + } else { + err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model) + } + case constant.APITypeZhipu: + if isStream { + err, usage = zhipu.StreamHandler(c, resp) + } else { + err, usage = zhipu.Handler(c, resp) + } + case constant.APITypeAli: + if isStream { + err, usage = ali.StreamHandler(c, resp) + } else { + switch relayMode { + case constant.RelayModeEmbeddings: + err, usage = ali.EmbeddingHandler(c, resp) + default: + err, usage = ali.Handler(c, resp) + } + } + case constant.APITypeXunfei: + auth := c.Request.Header.Get("Authorization") + auth = strings.TrimPrefix(auth, "Bearer ") + splits := strings.Split(auth, "|") + if len(splits) != 3 { + return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) + } + if isStream { + err, usage = xunfei.StreamHandler(c, *textRequest, splits[0], splits[1], splits[2]) + } else { + err, usage = xunfei.Handler(c, *textRequest, splits[0], splits[1], splits[2]) + } + case constant.APITypeAIProxyLibrary: + if isStream { + err, usage = aiproxy.StreamHandler(c, resp) + } else { + err, usage = aiproxy.Handler(c, resp) + } + case constant.APITypeTencent: + if isStream { + err, responseText = tencent.StreamHandler(c, resp) + } else { + err, usage = tencent.Handler(c, resp) + } + default: + return nil, openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) + } + if err != nil { + return nil, err + } + if usage == nil && responseText != "" { + usage = &openai.Usage{} + usage.PromptTokens = promptTokens + usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + return usage, nil +} diff --git a/relay/util/billing.go b/relay/util/billing.go new file mode 100644 index 00000000..35fb28a4 --- /dev/null +++ b/relay/util/billing.go @@ -0,0 +1,19 @@ +package util + +import ( + "context" + "one-api/common/logger" + "one-api/model" +) + +func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) { + if preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + logger.Error(ctx, "error return pre-consumed quota: "+err.Error()) + } + }(ctx) + } +} diff --git a/relay/util/common.go b/relay/util/common.go index 9d13b12e..d7596188 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "one-api/relay/channel/openai" "strconv" @@ -138,11 +139,11 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo // quotaDelta is remaining quota to be consumed err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + logger.SysError("error consuming token remain quota: " + err.Error()) } err = model.CacheUpdateUserQuota(userId) if err != nil { - common.SysError("error update user quota cache: " + err.Error()) + logger.SysError("error update user quota cache: " + err.Error()) } // totalQuota is total quota consumed if totalQuota != 0 { @@ -152,11 +153,11 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo model.UpdateChannelUsedQuota(channelId, totalQuota) } if totalQuota <= 0 { - common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) + logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) } } -func GetAPIVersion(c *gin.Context) string { +func GetAzureAPIVersion(c *gin.Context) string { query := c.Request.URL.Query() apiVersion := query.Get("api-version") if apiVersion == "" { diff --git a/relay/util/model_mapping.go b/relay/util/model_mapping.go new file mode 100644 index 00000000..39e062a1 --- /dev/null +++ b/relay/util/model_mapping.go @@ -0,0 +1,12 @@ +package util + +func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) { + if mapping == nil { + return modelName, false + } + mappedModelName := mapping[modelName] + if mappedModelName != "" { + return mappedModelName, true + } + return modelName, false +} diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go new file mode 100644 index 00000000..19936e49 --- /dev/null +++ b/relay/util/relay_meta.go @@ -0,0 +1,44 @@ +package util + +import ( + "github.com/gin-gonic/gin" + "one-api/common" + "strings" +) + +type RelayMeta struct { + ChannelType int + ChannelId int + TokenId int + TokenName string + UserId int + Group string + ModelMapping map[string]string + BaseURL string + APIVersion string + APIKey string + Config map[string]string +} + +func GetRelayMeta(c *gin.Context) *RelayMeta { + meta := RelayMeta{ + ChannelType: c.GetInt("channel"), + ChannelId: c.GetInt("channel_id"), + TokenId: c.GetInt("token_id"), + TokenName: c.GetString("token_name"), + UserId: c.GetInt("id"), + Group: c.GetString("group"), + ModelMapping: c.GetStringMapString("model_mapping"), + BaseURL: c.GetString("base_url"), + APIVersion: c.GetString("api_version"), + APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + Config: nil, + } + if meta.ChannelType == common.ChannelTypeAzure { + meta.APIVersion = GetAzureAPIVersion(c) + } + if meta.BaseURL == "" { + meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType] + } + return &meta +} diff --git a/relay/util/validation.go b/relay/util/validation.go new file mode 100644 index 00000000..48b42d94 --- /dev/null +++ b/relay/util/validation.go @@ -0,0 +1,37 @@ +package util + +import ( + "errors" + "math" + "one-api/relay/channel/openai" + "one-api/relay/constant" +) + +func ValidateTextRequest(textRequest *openai.GeneralOpenAIRequest, relayMode int) error { + if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { + return errors.New("max_tokens is invalid") + } + if textRequest.Model == "" { + return errors.New("model is required") + } + switch relayMode { + case constant.RelayModeCompletions: + if textRequest.Prompt == "" { + return errors.New("field prompt is required") + } + case constant.RelayModeChatCompletions: + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return errors.New("field messages is required") + } + case constant.RelayModeEmbeddings: + case constant.RelayModeModerations: + if textRequest.Input == "" { + return errors.New("field input is required") + } + case constant.RelayModeEdits: + if textRequest.Instruction == "" { + return errors.New("field instruction is required") + } + } + return nil +} diff --git a/router/main.go b/router/main.go index 85127a1a..733a1033 100644 --- a/router/main.go +++ b/router/main.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/logger" "os" "strings" ) @@ -17,7 +18,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS) { frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") if common.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" - common.SysLog("FRONTEND_BASE_URL is ignored on master node") + logger.SysLog("FRONTEND_BASE_URL is ignored on master node") } if frontendBaseUrl == "" { SetWebRouter(router, buildFS)