diff --git a/common/constants.go b/common/constants.go
index 9ee791df..a83de9c1 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -91,16 +91,16 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
-var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
+var SyncFrequency = GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
-var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
+var BatchUpdateInterval = GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
-var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
+var RelayTimeout = GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
-var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
+var GeminiSafetySetting = GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
-var Theme = GetOrDefaultString("THEME", "default")
+var Theme = GetOrDefaultEnvString("THEME", "default")
var ValidThemes = map[string]bool{
"default": true,
"berry": true,
@@ -127,10 +127,10 @@ var (
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
- GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
+ GlobalApiRateLimitNum = GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
- GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
+ GlobalWebRateLimitNum = GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
@@ -199,29 +199,29 @@ const (
)
var ChannelBaseURLs = []string{
- "", // 0
- "https://api.openai.com", // 1
- "https://oa.api2d.net", // 2
- "", // 3
- "https://api.closeai-proxy.xyz", // 4
- "https://api.openai-sb.com", // 5
- "https://api.openaimax.com", // 6
- "https://api.ohmygpt.com", // 7
- "", // 8
- "https://api.caipacity.com", // 9
- "https://api.aiproxy.io", // 10
- "", // 11
- "https://api.api2gpt.com", // 12
- "https://api.aigc2d.com", // 13
- "https://api.anthropic.com", // 14
- "https://aip.baidubce.com", // 15
- "https://open.bigmodel.cn", // 16
- "https://dashscope.aliyuncs.com", // 17
- "", // 18
- "https://ai.360.cn", // 19
- "https://openrouter.ai/api", // 20
- "https://api.aiproxy.io", // 21
- "https://fastgpt.run/api/openapi", // 22
- "https://hunyuan.cloud.tencent.com", //23
- "", //24
+ "", // 0
+ "https://api.openai.com", // 1
+ "https://oa.api2d.net", // 2
+ "", // 3
+ "https://api.closeai-proxy.xyz", // 4
+ "https://api.openai-sb.com", // 5
+ "https://api.openaimax.com", // 6
+ "https://api.ohmygpt.com", // 7
+ "", // 8
+ "https://api.caipacity.com", // 9
+ "https://api.aiproxy.io", // 10
+ "https://generativelanguage.googleapis.com", // 11
+ "https://api.api2gpt.com", // 12
+ "https://api.aigc2d.com", // 13
+ "https://api.anthropic.com", // 14
+ "https://aip.baidubce.com", // 15
+ "https://open.bigmodel.cn", // 16
+ "https://dashscope.aliyuncs.com", // 17
+ "", // 18
+ "https://ai.360.cn", // 19
+ "https://openrouter.ai/api", // 20
+ "https://api.aiproxy.io", // 21
+ "https://fastgpt.run/api/openapi", // 22
+ "https://hunyuan.cloud.tencent.com", // 23
+ "https://generativelanguage.googleapis.com", // 24
}
diff --git a/common/database.go b/common/database.go
index 76f2cd55..8f659b57 100644
--- a/common/database.go
+++ b/common/database.go
@@ -4,4 +4,4 @@ var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"
-var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)
+var SQLiteBusyTimeout = GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000)
diff --git a/common/group-ratio.go b/common/group-ratio.go
index 1ec73c78..86e9e1f6 100644
--- a/common/group-ratio.go
+++ b/common/group-ratio.go
@@ -1,6 +1,9 @@
package common
-import "encoding/json"
+import (
+ "encoding/json"
+ "one-api/common/logger"
+)
var GroupRatio = map[string]float64{
"default": 1,
@@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
if err != nil {
- SysError("error marshalling model ratio: " + err.Error())
+ logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error {
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
if !ok {
- SysError("group ratio not found: " + name)
+ logger.SysError("group ratio not found: " + name)
return 1
}
return ratio
diff --git a/common/init.go b/common/init.go
index 12df5f51..9735c5b4 100644
--- a/common/init.go
+++ b/common/init.go
@@ -4,6 +4,7 @@ import (
"flag"
"fmt"
"log"
+ "one-api/common/logger"
"os"
"path/filepath"
)
@@ -37,7 +38,7 @@ func init() {
if os.Getenv("SESSION_SECRET") != "" {
if os.Getenv("SESSION_SECRET") == "random_string" {
- SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
+ logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
SessionSecret = os.Getenv("SESSION_SECRET")
}
diff --git a/common/logger.go b/common/logger/logger.go
similarity index 71%
rename from common/logger.go
rename to common/logger/logger.go
index 61627217..4386bc6c 100644
--- a/common/logger.go
+++ b/common/logger/logger.go
@@ -1,4 +1,4 @@
-package common
+package logger
import (
"context"
@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"io"
"log"
+ "one-api/common"
"os"
"path/filepath"
"sync"
@@ -25,7 +26,7 @@ var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
- if *LogDir != "" {
+ if *common.LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
@@ -35,7 +36,7 @@ func SetupLogger() {
setupLogLock.Unlock()
setupLogWorking = false
}()
- logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
+ logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
@@ -55,24 +56,36 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
-func LogInfo(ctx context.Context, msg string) {
+func Info(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
-func LogWarn(ctx context.Context, msg string) {
+func Warn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
-func LogError(ctx context.Context, msg string) {
+func Error(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
+func Infof(ctx context.Context, format string, a ...any) {
+	Info(ctx, fmt.Sprintf(format, a...))
+}
+
+func Warnf(ctx context.Context, format string, a ...any) {
+	Warn(ctx, fmt.Sprintf(format, a...))
+}
+
+func Errorf(ctx context.Context, format string, a ...any) {
+	Error(ctx, fmt.Sprintf(format, a...))
+}
+
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
- id := ctx.Value(RequestIdKey)
+ id := ctx.Value(common.RequestIdKey)
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
@@ -92,8 +105,8 @@ func FatalLog(v ...any) {
}
func LogQuota(quota int) string {
- if DisplayInCurrencyEnabled {
- return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
+ if common.DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", quota)
}
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 97cb060d..9f31e0d7 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -2,6 +2,7 @@ package common
import (
"encoding/json"
+ "one-api/common/logger"
"strings"
"time"
)
@@ -107,7 +108,7 @@ var ModelRatio = map[string]float64{
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
- SysError("error marshalling model ratio: " + err.Error())
+ logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -123,7 +124,7 @@ func GetModelRatio(name string) float64 {
}
ratio, ok := ModelRatio[name]
if !ok {
- SysError("model ratio not found: " + name)
+ logger.SysError("model ratio not found: " + name)
return 30
}
return ratio
diff --git a/common/redis.go b/common/redis.go
index 12c477b8..ed3fcd9d 100644
--- a/common/redis.go
+++ b/common/redis.go
@@ -3,6 +3,7 @@ package common
import (
"context"
"github.com/go-redis/redis/v8"
+ "one-api/common/logger"
"os"
"time"
)
@@ -14,18 +15,18 @@ var RedisEnabled = true
func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" {
RedisEnabled = false
- SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
+ logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
return nil
}
if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false
- SysLog("SYNC_FREQUENCY not set, Redis is disabled")
+ logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil
}
- SysLog("Redis is enabled")
+ logger.SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
- FatalLog("failed to parse Redis connection string: " + err.Error())
+ logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
RDB = redis.NewClient(opt)
@@ -34,7 +35,7 @@ func InitRedisClient() (err error) {
_, err = RDB.Ping(ctx).Result()
if err != nil {
- FatalLog("Redis ping test failed: " + err.Error())
+ logger.FatalLog("Redis ping test failed: " + err.Error())
}
return err
}
@@ -42,7 +43,7 @@ func InitRedisClient() (err error) {
func ParseRedisOption() *redis.Options {
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
- FatalLog("failed to parse Redis connection string: " + err.Error())
+ logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
return opt
}
diff --git a/common/utils.go b/common/utils.go
index 9a7038e2..4e3312f9 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -7,6 +7,7 @@ import (
"log"
"math/rand"
"net"
+ "one-api/common/logger"
"os"
"os/exec"
"runtime"
@@ -184,25 +185,32 @@ func Max(a int, b int) int {
}
}
-func GetOrDefault(env string, defaultValue int) int {
+func GetOrDefaultEnvInt(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
- SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
+ logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
-func GetOrDefaultString(env string, defaultValue string) string {
+func GetOrDefaultEnvString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
+func AssignOrDefault(value string, defaultValue string) string {
+ if len(value) != 0 {
+ return value
+ }
+ return defaultValue
+}
+
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 29346cde..61a899a4 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/model"
"one-api/relay/util"
"strconv"
@@ -339,8 +340,8 @@ func UpdateAllChannelsBalance(c *gin.Context) {
func AutomaticallyUpdateChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
- common.SysLog("updating all channels")
+ logger.SysLog("updating all channels")
_ = updateAllChannelsBalance()
- common.SysLog("channels update done")
+ logger.SysLog("channels update done")
}
}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index f64f0ee3..73ff6bb2 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/model"
"one-api/relay/channel/openai"
"one-api/relay/util"
@@ -155,7 +156,7 @@ func notifyRootUser(subject string, content string) {
}
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
- common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+ logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
@@ -221,7 +222,7 @@ func testAllChannels(notify bool) error {
if notify {
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
- common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+ logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
}()
@@ -247,8 +248,8 @@ func TestAllChannels(c *gin.Context) {
func AutomaticallyTestChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
- common.SysLog("testing all channels")
+ logger.SysLog("testing all channels")
_ = testAllChannels(false)
- common.SysLog("channel test finished")
+ logger.SysLog("channel test finished")
}
}
diff --git a/controller/github.go b/controller/github.go
index ee995379..68692b9d 100644
--- a/controller/github.go
+++ b/controller/github.go
@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/model"
"strconv"
"time"
@@ -46,7 +47,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
}
res, err := client.Do(req)
if err != nil {
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res.Body.Close()
@@ -62,7 +63,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
- common.SysLog(err.Error())
+ logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res2.Body.Close()
diff --git a/controller/relay.go b/controller/relay.go
index 198d6c9a..e390ae75 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -2,43 +2,21 @@ package controller
import (
"fmt"
+ "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/relay/channel/openai"
"one-api/relay/constant"
"one-api/relay/controller"
"one-api/relay/util"
"strconv"
- "strings"
-
- "github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/chat
func Relay(c *gin.Context) {
- relayMode := constant.RelayModeUnknown
- if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
- relayMode = constant.RelayModeChatCompletions
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
- relayMode = constant.RelayModeCompletions
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
- relayMode = constant.RelayModeEmbeddings
- } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
- relayMode = constant.RelayModeEmbeddings
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
- relayMode = constant.RelayModeModerations
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
- relayMode = constant.RelayModeImagesGenerations
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
- relayMode = constant.RelayModeEdits
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
- relayMode = constant.RelayModeAudioSpeech
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
- relayMode = constant.RelayModeAudioTranscription
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
- relayMode = constant.RelayModeAudioTranslation
- }
+ relayMode := constant.Path2RelayMode(c.Request.URL.Path)
var err *openai.ErrorWithStatusCode
switch relayMode {
case constant.RelayModeImagesGenerations:
@@ -71,7 +49,7 @@ func Relay(c *gin.Context) {
})
}
channelId := c.GetInt("channel_id")
- common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
+ logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
channelId := c.GetInt("channel_id")
diff --git a/controller/user.go b/controller/user.go
index 174300ed..d39bba3b 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/model"
"strconv"
"time"
@@ -409,7 +410,7 @@ func UpdateUser(c *gin.Context) {
return
}
if originUser.Quota != updatedUser.Quota {
- model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
+ model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
}
c.JSON(http.StatusOK, gin.H{
"success": true,
diff --git a/main.go b/main.go
index 28a41287..9e4a88f2 100644
--- a/main.go
+++ b/main.go
@@ -7,6 +7,7 @@ import (
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"one-api/common"
+ "one-api/common/logger"
"one-api/controller"
"one-api/middleware"
"one-api/model"
@@ -20,42 +21,42 @@ import (
var buildFS embed.FS
func main() {
- common.SetupLogger()
- common.SysLog(fmt.Sprintf("One API %s started", common.Version))
+ logger.SetupLogger()
+ logger.SysLog(fmt.Sprintf("One API %s started", common.Version))
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
}
if common.DebugEnabled {
- common.SysLog("running in debug mode")
+ logger.SysLog("running in debug mode")
}
// Initialize SQL Database
err := model.InitDB()
if err != nil {
- common.FatalLog("failed to initialize database: " + err.Error())
+ logger.FatalLog("failed to initialize database: " + err.Error())
}
defer func() {
err := model.CloseDB()
if err != nil {
- common.FatalLog("failed to close database: " + err.Error())
+ logger.FatalLog("failed to close database: " + err.Error())
}
}()
// Initialize Redis
err = common.InitRedisClient()
if err != nil {
- common.FatalLog("failed to initialize Redis: " + err.Error())
+ logger.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize options
model.InitOptionMap()
- common.SysLog(fmt.Sprintf("using theme %s", common.Theme))
+ logger.SysLog(fmt.Sprintf("using theme %s", common.Theme))
if common.RedisEnabled {
// for compatibility with old versions
common.MemoryCacheEnabled = true
}
if common.MemoryCacheEnabled {
- common.SysLog("memory cache enabled")
- common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
+ logger.SysLog("memory cache enabled")
+		logger.SysLog(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
model.InitChannelCache()
}
if common.MemoryCacheEnabled {
@@ -65,20 +66,20 @@ func main() {
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil {
- common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
+ logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyUpdateChannels(frequency)
}
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil {
- common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
+ logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyTestChannels(frequency)
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
- common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
+ logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
model.InitBatchUpdater()
}
openai.InitTokenEncoders()
@@ -101,6 +102,6 @@ func main() {
}
err = server.Run(":" + port)
if err != nil {
- common.FatalLog("failed to start HTTP server: " + err.Error())
+ logger.FatalLog("failed to start HTTP server: " + err.Error())
}
}
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 81338130..6b607d68 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/model"
"strconv"
"strings"
@@ -69,7 +70,7 @@ func Distribute() func(c *gin.Context) {
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {
- common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
+ logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
abortWithMessage(c, http.StatusServiceUnavailable, message)
diff --git a/middleware/recover.go b/middleware/recover.go
index 8338a514..9d3edc27 100644
--- a/middleware/recover.go
+++ b/middleware/recover.go
@@ -4,7 +4,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
- "one-api/common"
+ "one-api/common/logger"
"runtime/debug"
)
@@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
- common.SysError(fmt.Sprintf("panic detected: %v", err))
- common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
+ logger.SysError(fmt.Sprintf("panic detected: %v", err))
+ logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go
index 26688810..6f295864 100644
--- a/middleware/turnstile-check.go
+++ b/middleware/turnstile-check.go
@@ -7,6 +7,7 @@ import (
"net/http"
"net/url"
"one-api/common"
+ "one-api/common/logger"
)
type turnstileCheckResponse struct {
@@ -37,7 +38,7 @@ func TurnstileCheck() gin.HandlerFunc {
"remoteip": {c.ClientIP()},
})
if err != nil {
- common.SysError(err.Error())
+ logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc {
var res turnstileCheckResponse
err = json.NewDecoder(rawRes.Body).Decode(&res)
if err != nil {
- common.SysError(err.Error())
+ logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
diff --git a/middleware/utils.go b/middleware/utils.go
index 536125cc..31620bf2 100644
--- a/middleware/utils.go
+++ b/middleware/utils.go
@@ -3,6 +3,7 @@ package middleware
import (
"github.com/gin-gonic/gin"
"one-api/common"
+ "one-api/common/logger"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
@@ -13,5 +14,5 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
},
})
c.Abort()
- common.LogError(c.Request.Context(), message)
+ logger.Error(c.Request.Context(), message)
}
diff --git a/model/cache.go b/model/cache.go
index c6d0c70a..eaed5bba 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -6,6 +6,7 @@ import (
"fmt"
"math/rand"
"one-api/common"
+ "one-api/common/logger"
"sort"
"strconv"
"strings"
@@ -42,7 +43,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
}
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set token error: " + err.Error())
+ logger.SysError("Redis set token error: " + err.Error())
}
return &token, nil
}
@@ -62,7 +63,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
}
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set user group error: " + err.Error())
+ logger.SysError("Redis set user group error: " + err.Error())
}
}
return group, err
@@ -80,7 +81,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set user quota error: " + err.Error())
+ logger.SysError("Redis set user quota error: " + err.Error())
}
return quota, err
}
@@ -127,7 +128,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
}
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
- common.SysError("Redis set user enabled error: " + err.Error())
+ logger.SysError("Redis set user enabled error: " + err.Error())
}
return userEnabled, err
}
@@ -178,13 +179,13 @@ func InitChannelCache() {
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
- common.SysLog("channels synced from database")
+ logger.SysLog("channels synced from database")
}
func SyncChannelCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing channels from database")
+ logger.SysLog("syncing channels from database")
InitChannelCache()
}
}
diff --git a/model/channel.go b/model/channel.go
index 7e7b42e6..d89d1666 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -1,8 +1,11 @@
package model
import (
+ "encoding/json"
+ "fmt"
"gorm.io/gorm"
"one-api/common"
+ "one-api/common/logger"
)
type Channel struct {
@@ -86,11 +89,17 @@ func (channel *Channel) GetBaseURL() string {
return *channel.BaseURL
}
-func (channel *Channel) GetModelMapping() string {
- if channel.ModelMapping == nil {
- return ""
+func (channel *Channel) GetModelMapping() map[string]string {
+ if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" {
+ return nil
}
- return *channel.ModelMapping
+ modelMapping := make(map[string]string)
+ err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping)
+ if err != nil {
+ logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error()))
+ return nil
+ }
+ return modelMapping
}
func (channel *Channel) Insert() error {
@@ -120,7 +129,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
ResponseTime: int(responseTime),
}).Error
if err != nil {
- common.SysError("failed to update response time: " + err.Error())
+ logger.SysError("failed to update response time: " + err.Error())
}
}
@@ -130,7 +139,7 @@ func (channel *Channel) UpdateBalance(balance float64) {
Balance: balance,
}).Error
if err != nil {
- common.SysError("failed to update balance: " + err.Error())
+ logger.SysError("failed to update balance: " + err.Error())
}
}
@@ -147,11 +156,11 @@ func (channel *Channel) Delete() error {
func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil {
- common.SysError("failed to update ability status: " + err.Error())
+ logger.SysError("failed to update ability status: " + err.Error())
}
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
- common.SysError("failed to update channel status: " + err.Error())
+ logger.SysError("failed to update channel status: " + err.Error())
}
}
@@ -166,7 +175,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
func updateChannelUsedQuota(id int, quota int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
- common.SysError("failed to update channel used quota: " + err.Error())
+ logger.SysError("failed to update channel used quota: " + err.Error())
}
}
diff --git a/model/log.go b/model/log.go
index 06085acf..728c4b17 100644
--- a/model/log.go
+++ b/model/log.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"one-api/common"
+ "one-api/common/logger"
"gorm.io/gorm"
)
@@ -44,12 +45,12 @@ func RecordLog(userId int, logType int, content string) {
}
err := DB.Create(log).Error
if err != nil {
- common.SysError("failed to record log: " + err.Error())
+ logger.SysError("failed to record log: " + err.Error())
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
- common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
+ logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
@@ -68,7 +69,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
}
err := DB.Create(log).Error
if err != nil {
- common.LogError(ctx, "failed to record log: "+err.Error())
+ logger.Error(ctx, "failed to record log: "+err.Error())
}
}
diff --git a/model/main.go b/model/main.go
index 9723e638..0b9c4f2b 100644
--- a/model/main.go
+++ b/model/main.go
@@ -7,6 +7,7 @@ import (
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"one-api/common"
+ "one-api/common/logger"
"os"
"strings"
"time"
@@ -18,7 +19,7 @@ func createRootAccountIfNeed() error {
var user User
//if user.Status != util.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil {
- common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
+ logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
hashedPassword, err := common.Password2Hash("123456")
if err != nil {
return err
@@ -42,7 +43,7 @@ func chooseDB() (*gorm.DB, error) {
dsn := os.Getenv("SQL_DSN")
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
- common.SysLog("using PostgreSQL as database")
+ logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
@@ -52,13 +53,13 @@ func chooseDB() (*gorm.DB, error) {
})
}
// Use MySQL
- common.SysLog("using MySQL as database")
+ logger.SysLog("using MySQL as database")
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use SQLite
- common.SysLog("SQL_DSN not set, using SQLite as database")
+ logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
@@ -77,14 +78,14 @@ func InitDB() (err error) {
if err != nil {
return err
}
- sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
- sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
- sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
+ sqlDB.SetMaxIdleConns(common.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
+ sqlDB.SetMaxOpenConns(common.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
+ sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
return nil
}
- common.SysLog("database migration started")
+ logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
return err
@@ -113,11 +114,11 @@ func InitDB() (err error) {
if err != nil {
return err
}
- common.SysLog("database migrated")
+ logger.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
} else {
- common.FatalLog(err)
+ logger.FatalLog(err)
}
return err
}
diff --git a/model/option.go b/model/option.go
index 20575c9a..80abff20 100644
--- a/model/option.go
+++ b/model/option.go
@@ -2,6 +2,7 @@ package model
import (
"one-api/common"
+ "one-api/common/logger"
"strconv"
"strings"
"time"
@@ -82,7 +83,7 @@ func loadOptionsFromDatabase() {
for _, option := range options {
err := updateOptionMap(option.Key, option.Value)
if err != nil {
- common.SysError("failed to update option map: " + err.Error())
+ logger.SysError("failed to update option map: " + err.Error())
}
}
}
@@ -90,7 +91,7 @@ func loadOptionsFromDatabase() {
func SyncOptions(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing options from database")
+ logger.SysLog("syncing options from database")
loadOptionsFromDatabase()
}
}
diff --git a/model/redemption.go b/model/redemption.go
index f16412b5..ba1e1077 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -5,6 +5,7 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
+ "one-api/common/logger"
)
type Redemption struct {
@@ -75,7 +76,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return 0, errors.New("兑换失败," + err.Error())
}
- RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
+ RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", logger.LogQuota(redemption.Quota)))
return redemption.Quota, nil
}
diff --git a/model/token.go b/model/token.go
index 2e53ac0b..570de47d 100644
--- a/model/token.go
+++ b/model/token.go
@@ -5,6 +5,7 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
+ "one-api/common/logger"
)
type Token struct {
@@ -39,7 +40,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
}
token, err = CacheGetTokenByKey(key)
if err != nil {
- common.SysError("CacheGetTokenByKey failed: " + err.Error())
+ logger.SysError("CacheGetTokenByKey failed: " + err.Error())
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("无效的令牌")
}
@@ -58,7 +59,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
- common.SysError("failed to update token status" + err.Error())
+ logger.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌已过期")
@@ -69,7 +70,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
- common.SysError("failed to update token status" + err.Error())
+ logger.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌额度已用尽")
@@ -202,7 +203,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
go func() {
email, err := GetUserEmail(token.UserId)
if err != nil {
- common.SysError("failed to fetch user email: " + err.Error())
+ logger.SysError("failed to fetch user email: " + err.Error())
}
prompt := "您的额度即将用尽"
if noMoreQuota {
@@ -213,7 +214,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
err = common.SendEmail(prompt, email,
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
- common.SysError("failed to send email" + err.Error())
+ logger.SysError("failed to send email" + err.Error())
}
}
}()
diff --git a/model/user.go b/model/user.go
index 1c2c0a75..17f94d9f 100644
--- a/model/user.go
+++ b/model/user.go
@@ -5,6 +5,7 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
+ "one-api/common/logger"
"strings"
)
@@ -97,16 +98,16 @@ func (user *User) Insert(inviterId int) error {
return result.Error
}
if common.QuotaForNewUser > 0 {
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
- RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
+ RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
}
}
return nil
@@ -232,7 +233,7 @@ func IsAdmin(userId int) bool {
var user User
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
if err != nil {
- common.SysError("no such user " + err.Error())
+ logger.SysError("no such user " + err.Error())
return false
}
return user.Role >= common.RoleAdminUser
@@ -341,7 +342,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
},
).Error
if err != nil {
- common.SysError("failed to update user used quota and request count: " + err.Error())
+ logger.SysError("failed to update user used quota and request count: " + err.Error())
}
}
@@ -352,14 +353,14 @@ func updateUserUsedQuota(id int, quota int) {
},
).Error
if err != nil {
- common.SysError("failed to update user used quota: " + err.Error())
+ logger.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
- common.SysError("failed to update user request count: " + err.Error())
+ logger.SysError("failed to update user request count: " + err.Error())
}
}
diff --git a/model/utils.go b/model/utils.go
index 1c28340b..e4797a78 100644
--- a/model/utils.go
+++ b/model/utils.go
@@ -2,6 +2,7 @@ package model
import (
"one-api/common"
+ "one-api/common/logger"
"sync"
"time"
)
@@ -45,7 +46,7 @@ func addNewRecord(type_ int, id int, value int) {
}
func batchUpdate() {
- common.SysLog("batch update started")
+ logger.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
store := batchUpdateStores[i]
@@ -57,12 +58,12 @@ func batchUpdate() {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
- common.SysError("failed to batch update user quota: " + err.Error())
+ logger.SysError("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
- common.SysError("failed to batch update token quota: " + err.Error())
+ logger.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
@@ -73,5 +74,5 @@ func batchUpdate() {
}
}
}
- common.SysLog("batch update finished")
+ logger.SysLog("batch update finished")
}
diff --git a/relay/channel/aiproxy/adaptor.go b/relay/channel/aiproxy/adaptor.go
new file mode 100644
index 00000000..44b6f58d
--- /dev/null
+++ b/relay/channel/aiproxy/adaptor.go
@@ -0,0 +1,22 @@
+package aiproxy
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/relay/channel/aiproxy/main.go b/relay/channel/aiproxy/main.go
index bee4d9d3..63fef55e 100644
--- a/relay/channel/aiproxy/main.go
+++ b/relay/channel/aiproxy/main.go
@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/relay/channel/openai"
"one-api/relay/constant"
"strconv"
@@ -122,7 +123,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var AIProxyLibraryResponse LibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
@@ -131,7 +132,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -140,7 +141,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
new file mode 100644
index 00000000..49022cfc
--- /dev/null
+++ b/relay/channel/ali/adaptor.go
@@ -0,0 +1,22 @@
+package ali
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go
index f45a515a..c5ada0d7 100644
--- a/relay/channel/ali/main.go
+++ b/relay/channel/ali/main.go
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/relay/channel/openai"
"strings"
)
@@ -185,7 +186,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var aliResponse ChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
@@ -198,7 +199,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
diff --git a/relay/channel/anthropic/adaptor.go b/relay/channel/anthropic/adaptor.go
new file mode 100644
index 00000000..55577228
--- /dev/null
+++ b/relay/channel/anthropic/adaptor.go
@@ -0,0 +1,22 @@
+package anthropic
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/relay/channel/anthropic/main.go b/relay/channel/anthropic/main.go
index a4272d7b..006779b2 100644
--- a/relay/channel/anthropic/main.go
+++ b/relay/channel/anthropic/main.go
@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/relay/channel/openai"
"strings"
)
@@ -125,7 +126,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var claudeResponse Response
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
@@ -134,7 +135,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go
new file mode 100644
index 00000000..498b664a
--- /dev/null
+++ b/relay/channel/baidu/adaptor.go
@@ -0,0 +1,22 @@
+package baidu
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/relay/channel/baidu/main.go b/relay/channel/baidu/main.go
index 47969492..f5b98155 100644
--- a/relay/channel/baidu/main.go
+++ b/relay/channel/baidu/main.go
@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/relay/channel/openai"
"one-api/relay/constant"
"one-api/relay/util"
@@ -19,49 +20,49 @@ import (
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
-type BaiduTokenResponse struct {
+type TokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}
-type BaiduMessage struct {
+type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
-type BaiduChatRequest struct {
- Messages []BaiduMessage `json:"messages"`
- Stream bool `json:"stream"`
- UserId string `json:"user_id,omitempty"`
+type ChatRequest struct {
+ Messages []Message `json:"messages"`
+ Stream bool `json:"stream"`
+ UserId string `json:"user_id,omitempty"`
}
-type BaiduError struct {
+type Error struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
var baiduTokenStore sync.Map
-func ConvertRequest(request openai.GeneralOpenAIRequest) *BaiduChatRequest {
- messages := make([]BaiduMessage, 0, len(request.Messages))
+func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
+ messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
- messages = append(messages, BaiduMessage{
+ messages = append(messages, Message{
Role: "user",
Content: message.StringContent(),
})
- messages = append(messages, BaiduMessage{
+ messages = append(messages, Message{
Role: "assistant",
Content: "Okay",
})
} else {
- messages = append(messages, BaiduMessage{
+ messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
}
- return &BaiduChatRequest{
+ return &ChatRequest{
Messages: messages,
Stream: request.Stream,
}
@@ -160,7 +161,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
@@ -171,7 +172,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
diff --git a/relay/channel/baidu/model.go b/relay/channel/baidu/model.go
index caaebafb..e182f5dd 100644
--- a/relay/channel/baidu/model.go
+++ b/relay/channel/baidu/model.go
@@ -13,7 +13,7 @@ type ChatResponse struct {
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage openai.Usage `json:"usage"`
- BaiduError
+ Error
}
type ChatStreamResponse struct {
@@ -38,7 +38,7 @@ type EmbeddingResponse struct {
Created int64 `json:"created"`
Data []EmbeddingData `json:"data"`
Usage openai.Usage `json:"usage"`
- BaiduError
+ Error
}
type AccessToken struct {
diff --git a/relay/channel/google/adaptor.go b/relay/channel/google/adaptor.go
new file mode 100644
index 00000000..b328db32
--- /dev/null
+++ b/relay/channel/google/adaptor.go
@@ -0,0 +1,22 @@
+package google
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/relay/channel/google/gemini.go b/relay/channel/google/gemini.go
index f49caadf..0f4e606c 100644
--- a/relay/channel/google/gemini.go
+++ b/relay/channel/google/gemini.go
@@ -8,6 +8,7 @@ import (
"net/http"
"one-api/common"
"one-api/common/image"
+ "one-api/common/logger"
"one-api/relay/channel/openai"
"one-api/relay/constant"
"strings"
@@ -237,7 +238,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
}
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
diff --git a/relay/channel/google/palm.go b/relay/channel/google/palm.go
index 77d8cbd6..c2518a07 100644
--- a/relay/channel/google/palm.go
+++ b/relay/channel/google/palm.go
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/relay/channel/openai"
"one-api/relay/constant"
)
@@ -78,20 +79,20 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- common.SysError("error reading stream response: " + err.Error())
+ logger.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
err = resp.Body.Close()
if err != nil {
- common.SysError("error closing stream response: " + err.Error())
+ logger.SysError("error closing stream response: " + err.Error())
stopChan <- true
return
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
@@ -103,7 +104,7 @@ func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithSt
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
diff --git a/relay/channel/interface.go b/relay/channel/interface.go
new file mode 100644
index 00000000..7a0fcbd3
--- /dev/null
+++ b/relay/channel/interface.go
@@ -0,0 +1,15 @@
+package channel
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor interface {
+ GetRequestURL() string
+ Auth(c *gin.Context) error
+ ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error)
+ DoRequest(request *openai.GeneralOpenAIRequest) error
+ DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error)
+}
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
new file mode 100644
index 00000000..cc302611
--- /dev/null
+++ b/relay/channel/openai/adaptor.go
@@ -0,0 +1,21 @@
+package openai
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*ErrorWithStatusCode, *Usage, error) {
+ return nil, nil, nil
+}
diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go
index 848a6fa4..5f464249 100644
--- a/relay/channel/openai/main.go
+++ b/relay/channel/openai/main.go
@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/relay/constant"
"strings"
)
@@ -46,7 +47,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWi
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
continue // just ignore the error
}
for _, choice := range streamResponse.Choices {
@@ -56,7 +57,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWi
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
for _, choice := range streamResponse.Choices {
diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go
index c831ce19..937fb424 100644
--- a/relay/channel/openai/model.go
+++ b/relay/channel/openai/model.go
@@ -207,6 +207,11 @@ type Usage struct {
TotalTokens int `json:"total_tokens"`
}
+type UsageOrResponseText struct {
+ *Usage
+ ResponseText string
+}
+
type Error struct {
Message string `json:"message"`
Type string `json:"type"`
diff --git a/relay/channel/openai/token.go b/relay/channel/openai/token.go
index 4b40b228..b398c220 100644
--- a/relay/channel/openai/token.go
+++ b/relay/channel/openai/token.go
@@ -7,6 +7,7 @@ import (
"math"
"one-api/common"
"one-api/common/image"
+ "one-api/common/logger"
"strings"
)
@@ -15,15 +16,15 @@ var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
- common.SysLog("initializing token encoders")
+ logger.SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
+ logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
+ logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model, _ := range common.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
@@ -34,7 +35,7 @@ func InitTokenEncoders() {
tokenEncoderMap[model] = nil
}
}
- common.SysLog("token encoders initialized")
+ logger.SysLog("token encoders initialized")
}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
@@ -45,7 +46,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
if ok {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
- common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
+ logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
@@ -99,7 +100,7 @@ func CountTokenMessages(messages []Message, model string) int {
}
imageTokens, err := countImageTokens(url, detail)
if err != nil {
- common.SysError("error counting image tokens: " + err.Error())
+ logger.SysError("error counting image tokens: " + err.Error())
} else {
tokenNum += imageTokens
}
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
new file mode 100644
index 00000000..e9f86aff
--- /dev/null
+++ b/relay/channel/tencent/adaptor.go
@@ -0,0 +1,22 @@
+package tencent
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go
index 60e275a9..9203249a 100644
--- a/relay/channel/tencent/main.go
+++ b/relay/channel/tencent/main.go
@@ -12,6 +12,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/relay/channel/openai"
"one-api/relay/constant"
"sort"
@@ -131,7 +132,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var TencentResponse ChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
@@ -140,7 +141,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
}
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go
new file mode 100644
index 00000000..89e58485
--- /dev/null
+++ b/relay/channel/xunfei/adaptor.go
@@ -0,0 +1,22 @@
+package xunfei
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go
index 1cc0b664..1c55cc09 100644
--- a/relay/channel/xunfei/main.go
+++ b/relay/channel/xunfei/main.go
@@ -12,6 +12,7 @@ import (
"net/http"
"net/url"
"one-api/common"
+ "one-api/common/logger"
"one-api/relay/channel/openai"
"one-api/relay/constant"
"strings"
@@ -140,7 +141,7 @@ func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appI
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -215,20 +216,20 @@ func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl,
for {
_, msg, err := conn.ReadMessage()
if err != nil {
- common.SysError("error reading stream response: " + err.Error())
+ logger.SysError("error reading stream response: " + err.Error())
break
}
var response ChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
- common.SysError("error closing websocket connection: " + err.Error())
+ logger.SysError("error closing websocket connection: " + err.Error())
}
break
}
@@ -247,7 +248,7 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string,
}
if apiVersion == "" {
apiVersion = "v1.1"
- common.SysLog("api_version not found, use default: " + apiVersion)
+ logger.SysLog("api_version not found, use default: " + apiVersion)
}
domain := "general"
if apiVersion != "v1.1" {
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
new file mode 100644
index 00000000..6d901bc3
--- /dev/null
+++ b/relay/channel/zhipu/adaptor.go
@@ -0,0 +1,22 @@
+package zhipu
+
+import (
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/relay/channel/openai"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) Auth(c *gin.Context) error {
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(request *openai.GeneralOpenAIRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage, error) {
+ return nil, nil, nil
+}
diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go
index 3dc613a4..c818c80e 100644
--- a/relay/channel/zhipu/main.go
+++ b/relay/channel/zhipu/main.go
@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/relay/channel/openai"
"one-api/relay/constant"
"strings"
@@ -34,7 +35,7 @@ func GetToken(apikey string) string {
split := strings.Split(apikey, ".")
if len(split) != 2 {
- common.SysError("invalid zhipu key: " + apikey)
+ logger.SysError("invalid zhipu key: " + apikey)
return ""
}
@@ -193,7 +194,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -202,13 +203,13 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var zhipuResponse StreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ logger.SysError("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
new file mode 100644
index 00000000..658bfb90
--- /dev/null
+++ b/relay/constant/api_type.go
@@ -0,0 +1,69 @@
+package constant
+
+import (
+ "one-api/common"
+)
+
+const (
+ APITypeOpenAI = iota
+ APITypeClaude
+ APITypePaLM
+ APITypeBaidu
+ APITypeZhipu
+ APITypeAli
+ APITypeXunfei
+ APITypeAIProxyLibrary
+ APITypeTencent
+ APITypeGemini
+)
+
+func ChannelType2APIType(channelType int) int {
+ apiType := APITypeOpenAI
+ switch channelType {
+ case common.ChannelTypeAnthropic:
+ apiType = APITypeClaude
+ case common.ChannelTypeBaidu:
+ apiType = APITypeBaidu
+ case common.ChannelTypePaLM:
+ apiType = APITypePaLM
+ case common.ChannelTypeZhipu:
+ apiType = APITypeZhipu
+ case common.ChannelTypeAli:
+ apiType = APITypeAli
+ case common.ChannelTypeXunfei:
+ apiType = APITypeXunfei
+ case common.ChannelTypeAIProxyLibrary:
+ apiType = APITypeAIProxyLibrary
+ case common.ChannelTypeTencent:
+ apiType = APITypeTencent
+ case common.ChannelTypeGemini:
+ apiType = APITypeGemini
+ }
+ return apiType
+}
+
+//func GetAdaptor(apiType int) channel.Adaptor {
+// switch apiType {
+// case APITypeOpenAI:
+// return &openai.Adaptor{}
+// case APITypeClaude:
+// return &anthropic.Adaptor{}
+// case APITypePaLM:
+// return &google.Adaptor{}
+// case APITypeZhipu:
+//		return &zhipu.Adaptor{}
+// case APITypeBaidu:
+// return &baidu.Adaptor{}
+// case APITypeAli:
+// return &ali.Adaptor{}
+// case APITypeXunfei:
+// return &xunfei.Adaptor{}
+// case APITypeAIProxyLibrary:
+// return &aiproxy.Adaptor{}
+// case APITypeTencent:
+// return &tencent.Adaptor{}
+// case APITypeGemini:
+// return &google.Adaptor{}
+// }
+// return nil
+//}
diff --git a/relay/constant/common.go b/relay/constant/common.go
new file mode 100644
index 00000000..b6606cc6
--- /dev/null
+++ b/relay/constant/common.go
@@ -0,0 +1,3 @@
+package constant
+
+var StopFinishReason = "stop"
diff --git a/relay/constant/main.go b/relay/constant/main.go
deleted file mode 100644
index b3aeaaff..00000000
--- a/relay/constant/main.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package constant
-
-const (
- RelayModeUnknown = iota
- RelayModeChatCompletions
- RelayModeCompletions
- RelayModeEmbeddings
- RelayModeModerations
- RelayModeImagesGenerations
- RelayModeEdits
- RelayModeAudioSpeech
- RelayModeAudioTranscription
- RelayModeAudioTranslation
-)
-
-var StopFinishReason = "stop"
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
new file mode 100644
index 00000000..5e2fe574
--- /dev/null
+++ b/relay/constant/relay_mode.go
@@ -0,0 +1,42 @@
+package constant
+
+import "strings"
+
+const (
+ RelayModeUnknown = iota
+ RelayModeChatCompletions
+ RelayModeCompletions
+ RelayModeEmbeddings
+ RelayModeModerations
+ RelayModeImagesGenerations
+ RelayModeEdits
+ RelayModeAudioSpeech
+ RelayModeAudioTranscription
+ RelayModeAudioTranslation
+)
+
+func Path2RelayMode(path string) int {
+ relayMode := RelayModeUnknown
+ if strings.HasPrefix(path, "/v1/chat/completions") {
+ relayMode = RelayModeChatCompletions
+ } else if strings.HasPrefix(path, "/v1/completions") {
+ relayMode = RelayModeCompletions
+ } else if strings.HasPrefix(path, "/v1/embeddings") {
+ relayMode = RelayModeEmbeddings
+ } else if strings.HasSuffix(path, "embeddings") {
+ relayMode = RelayModeEmbeddings
+ } else if strings.HasPrefix(path, "/v1/moderations") {
+ relayMode = RelayModeModerations
+ } else if strings.HasPrefix(path, "/v1/images/generations") {
+ relayMode = RelayModeImagesGenerations
+ } else if strings.HasPrefix(path, "/v1/edits") {
+ relayMode = RelayModeEdits
+ } else if strings.HasPrefix(path, "/v1/audio/speech") {
+ relayMode = RelayModeAudioSpeech
+ } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
+ relayMode = RelayModeAudioTranscription
+ } else if strings.HasPrefix(path, "/v1/audio/translations") {
+ relayMode = RelayModeAudioTranslation
+ }
+ return relayMode
+}
diff --git a/relay/controller/audio.go b/relay/controller/audio.go
index 08d9af2a..d8a896de 100644
--- a/relay/controller/audio.go
+++ b/relay/controller/audio.go
@@ -11,6 +11,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/model"
"one-api/relay/channel/openai"
"one-api/relay/constant"
@@ -102,7 +103,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
- apiVersion := util.GetAPIVersion(c)
+ apiVersion := util.GetAzureAPIVersion(c)
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
}
@@ -191,7 +192,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
// negative means add quota back for token & user
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
+ logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
}
}()
}(c.Request.Context())
diff --git a/relay/controller/image.go b/relay/controller/image.go
index be5fc3dd..9502a4d7 100644
--- a/relay/controller/image.go
+++ b/relay/controller/image.go
@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/model"
"one-api/relay/channel/openai"
"one-api/relay/util"
@@ -112,7 +113,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
- apiVersion := util.GetAPIVersion(c)
+ apiVersion := util.GetAzureAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
}
@@ -175,11 +176,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
}
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
+ logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
- common.SysError("error update user quota cache: " + err.Error())
+ logger.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
diff --git a/relay/controller/text.go b/relay/controller/text.go
index b17ff950..968cc751 100644
--- a/relay/controller/text.go
+++ b/relay/controller/text.go
@@ -1,206 +1,46 @@
package controller
import (
- "bytes"
"context"
- "encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
- "io"
"math"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/model"
- "one-api/relay/channel/aiproxy"
- "one-api/relay/channel/ali"
- "one-api/relay/channel/anthropic"
- "one-api/relay/channel/baidu"
- "one-api/relay/channel/google"
"one-api/relay/channel/openai"
- "one-api/relay/channel/tencent"
- "one-api/relay/channel/xunfei"
- "one-api/relay/channel/zhipu"
"one-api/relay/constant"
"one-api/relay/util"
"strings"
)
-const (
- APITypeOpenAI = iota
- APITypeClaude
- APITypePaLM
- APITypeBaidu
- APITypeZhipu
- APITypeAli
- APITypeXunfei
- APITypeAIProxyLibrary
- APITypeTencent
- APITypeGemini
-)
-
func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
- channelType := c.GetInt("channel")
- channelId := c.GetInt("channel_id")
- tokenId := c.GetInt("token_id")
- userId := c.GetInt("id")
- group := c.GetString("group")
+ ctx := c.Request.Context()
+ meta := util.GetRelayMeta(c)
var textRequest openai.GeneralOpenAIRequest
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
- if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
- return openai.ErrorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest)
- }
if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
- // request validation
- if textRequest.Model == "" {
- return openai.ErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
+ err = util.ValidateTextRequest(&textRequest, relayMode)
+ if err != nil {
+ return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
}
- switch relayMode {
- case constant.RelayModeCompletions:
- if textRequest.Prompt == "" {
- return openai.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
- }
- case constant.RelayModeChatCompletions:
- if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
- return openai.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
- }
- case constant.RelayModeEmbeddings:
- case constant.RelayModeModerations:
- if textRequest.Input == "" {
- return openai.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
- }
- case constant.RelayModeEdits:
- if textRequest.Instruction == "" {
- return openai.ErrorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
- }
- }
- // map model name
- modelMapping := c.GetString("model_mapping")
- isModelMapped := false
- if modelMapping != "" && modelMapping != "{}" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[textRequest.Model] != "" {
- textRequest.Model = modelMap[textRequest.Model]
- isModelMapped = true
- }
- }
- apiType := APITypeOpenAI
- switch channelType {
- case common.ChannelTypeAnthropic:
- apiType = APITypeClaude
- case common.ChannelTypeBaidu:
- apiType = APITypeBaidu
- case common.ChannelTypePaLM:
- apiType = APITypePaLM
- case common.ChannelTypeZhipu:
- apiType = APITypeZhipu
- case common.ChannelTypeAli:
- apiType = APITypeAli
- case common.ChannelTypeXunfei:
- apiType = APITypeXunfei
- case common.ChannelTypeAIProxyLibrary:
- apiType = APITypeAIProxyLibrary
- case common.ChannelTypeTencent:
- apiType = APITypeTencent
- case common.ChannelTypeGemini:
- apiType = APITypeGemini
- }
- baseURL := common.ChannelBaseURLs[channelType]
- requestURL := c.Request.URL.String()
- if c.GetString("base_url") != "" {
- baseURL = c.GetString("base_url")
- }
- fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
- switch apiType {
- case APITypeOpenAI:
- if channelType == common.ChannelTypeAzure {
- // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
- apiVersion := util.GetAPIVersion(c)
- requestURL := strings.Split(requestURL, "?")[0]
- requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
- baseURL = c.GetString("base_url")
- task := strings.TrimPrefix(requestURL, "/v1/")
- model_ := textRequest.Model
- model_ = strings.Replace(model_, ".", "", -1)
- // https://github.com/songquanpeng/one-api/issues/67
- model_ = strings.TrimSuffix(model_, "-0301")
- model_ = strings.TrimSuffix(model_, "-0314")
- model_ = strings.TrimSuffix(model_, "-0613")
-
- requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
- fullRequestURL = util.GetFullRequestURL(baseURL, requestURL, channelType)
- }
- case APITypeClaude:
- fullRequestURL = "https://api.anthropic.com/v1/complete"
- if baseURL != "" {
- fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
- }
- case APITypeBaidu:
- switch textRequest.Model {
- case "ERNIE-Bot":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
- case "ERNIE-Bot-turbo":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
- case "ERNIE-Bot-4":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
- case "BLOOMZ-7B":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
- case "Embedding-V1":
- fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
- }
- apiKey := c.Request.Header.Get("Authorization")
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
- var err error
- if apiKey, err = baidu.GetAccessToken(apiKey); err != nil {
- return openai.ErrorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
- }
- fullRequestURL += "?access_token=" + apiKey
- case APITypePaLM:
- fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
- if baseURL != "" {
- fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
- }
- case APITypeGemini:
- requestBaseURL := "https://generativelanguage.googleapis.com"
- if baseURL != "" {
- requestBaseURL = baseURL
- }
- version := "v1"
- if c.GetString("api_version") != "" {
- version = c.GetString("api_version")
- }
- action := "generateContent"
- if textRequest.Stream {
- action = "streamGenerateContent"
- }
- fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
- case APITypeZhipu:
- method := "invoke"
- if textRequest.Stream {
- method = "sse-invoke"
- }
- fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
- case APITypeAli:
- fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
- if relayMode == constant.RelayModeEmbeddings {
- fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
- }
- case APITypeTencent:
- fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
- case APITypeAIProxyLibrary:
- fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
+ var isModelMapped bool
+ textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
+ apiType := constant.ChannelType2APIType(meta.ChannelType)
+ fullRequestURL, err := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest)
+ if err != nil {
+ logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
+ return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
}
var promptTokens int
var completionTokens int
@@ -217,17 +57,17 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
modelRatio := common.GetModelRatio(textRequest.Model)
- groupRatio := common.GetGroupRatio(group)
+ groupRatio := common.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
- userQuota, err := model.CacheGetUserQuota(userId)
+ userQuota, err := model.CacheGetUserQuota(meta.UserId)
if err != nil {
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
- err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
+ err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
if err != nil {
return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
@@ -235,165 +75,28 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
- common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
+ logger.Info(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
}
if preConsumedQuota > 0 {
- err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+ err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
if err != nil {
return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
- var requestBody io.Reader
- if isModelMapped {
- jsonStr, err := json.Marshal(textRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- } else {
- requestBody = c.Request.Body
+ requestBody, err := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode)
+ if err != nil {
+ return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError)
}
- switch apiType {
- case APITypeClaude:
- claudeRequest := anthropic.ConvertRequest(textRequest)
- jsonStr, err := json.Marshal(claudeRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeBaidu:
- var jsonData []byte
- var err error
- switch relayMode {
- case constant.RelayModeEmbeddings:
- baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest)
- jsonData, err = json.Marshal(baiduEmbeddingRequest)
- default:
- baiduRequest := baidu.ConvertRequest(textRequest)
- jsonData, err = json.Marshal(baiduRequest)
- }
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonData)
- case APITypePaLM:
- palmRequest := google.ConvertPaLMRequest(textRequest)
- jsonStr, err := json.Marshal(palmRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeGemini:
- geminiChatRequest := google.ConvertGeminiRequest(textRequest)
- jsonStr, err := json.Marshal(geminiChatRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeZhipu:
- zhipuRequest := zhipu.ConvertRequest(textRequest)
- jsonStr, err := json.Marshal(zhipuRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeAli:
- var jsonStr []byte
- var err error
- switch relayMode {
- case constant.RelayModeEmbeddings:
- aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest)
- jsonStr, err = json.Marshal(aliEmbeddingRequest)
- default:
- aliRequest := ali.ConvertRequest(textRequest)
- jsonStr, err = json.Marshal(aliRequest)
- }
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeTencent:
- apiKey := c.Request.Header.Get("Authorization")
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
- appId, secretId, secretKey, err := tencent.ParseConfig(apiKey)
- if err != nil {
- return openai.ErrorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
- }
- tencentRequest := tencent.ConvertRequest(textRequest)
- tencentRequest.AppId = appId
- tencentRequest.SecretId = secretId
- jsonStr, err := json.Marshal(tencentRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- sign := tencent.GetSign(*tencentRequest, secretKey)
- c.Request.Header.Set("Authorization", sign)
- requestBody = bytes.NewBuffer(jsonStr)
- case APITypeAIProxyLibrary:
- aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest)
- aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
- jsonStr, err := json.Marshal(aiProxyLibraryRequest)
- if err != nil {
- return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonStr)
- }
-
var req *http.Request
var resp *http.Response
isStream := textRequest.Stream
- if apiType != APITypeXunfei { // cause xunfei use websocket
+ if apiType != constant.APITypeXunfei { // cause xunfei use websocket
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
- apiKey := c.Request.Header.Get("Authorization")
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
- switch apiType {
- case APITypeOpenAI:
- if channelType == common.ChannelTypeAzure {
- req.Header.Set("api-key", apiKey)
- } else {
- req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
- if channelType == common.ChannelTypeOpenRouter {
- req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
- req.Header.Set("X-Title", "One API")
- }
- }
- case APITypeClaude:
- req.Header.Set("x-api-key", apiKey)
- anthropicVersion := c.Request.Header.Get("anthropic-version")
- if anthropicVersion == "" {
- anthropicVersion = "2023-06-01"
- }
- req.Header.Set("anthropic-version", anthropicVersion)
- case APITypeZhipu:
- token := zhipu.GetToken(apiKey)
- req.Header.Set("Authorization", token)
- case APITypeAli:
- req.Header.Set("Authorization", "Bearer "+apiKey)
- if textRequest.Stream {
- req.Header.Set("X-DashScope-SSE", "enable")
- }
- if c.GetString("plugin") != "" {
- req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
- }
- case APITypeTencent:
- req.Header.Set("Authorization", apiKey)
- case APITypePaLM:
- req.Header.Set("x-goog-api-key", apiKey)
- case APITypeGemini:
- req.Header.Set("x-goog-api-key", apiKey)
- default:
- req.Header.Set("Authorization", "Bearer "+apiKey)
- }
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- if isStream && c.Request.Header.Get("Accept") == "" {
- req.Header.Set("Accept", "text/event-stream")
- }
- //req.Header.Set("Connection", c.Request.Header.Get("Connection"))
+ SetupRequestHeaders(c, req, apiType, meta, isStream)
resp, err = util.HTTPClient.Do(req)
if err != nil {
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
@@ -409,29 +112,31 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
- if preConsumedQuota != 0 {
- go func(ctx context.Context) {
- // return pre-consumed quota
- err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
- if err != nil {
- common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
- }
- }(c.Request.Context())
- }
+ util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return util.RelayErrorHandler(resp)
}
}
- var textResponse openai.SlimTextResponse
- tokenName := c.GetString("token_name")
+ var respErr *openai.ErrorWithStatusCode
+ var usage *openai.Usage
defer func(ctx context.Context) {
- // c.Writer.Flush()
+		// We use defer here because, if an error occurs, we must return the pre-consumed quota.
+ if respErr != nil {
+ logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
+ util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
+ return
+ }
+ if usage == nil {
+ logger.Error(ctx, "usage is nil, which is unexpected")
+ return
+ }
+
go func() {
quota := 0
completionRatio := common.GetCompletionRatio(textRequest.Model)
- promptTokens = textResponse.Usage.PromptTokens
- completionTokens = textResponse.Usage.CompletionTokens
+ promptTokens = usage.PromptTokens
+ completionTokens = usage.CompletionTokens
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
@@ -443,239 +148,25 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
quota = 0
}
quotaDelta := quota - preConsumedQuota
- err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+ err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
if err != nil {
- common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ logger.Error(ctx, "error consuming token remain quota: "+err.Error())
}
- err = model.CacheUpdateUserQuota(userId)
+ err = model.CacheUpdateUserQuota(meta.UserId)
if err != nil {
- common.LogError(ctx, "error update user quota cache: "+err.Error())
+ logger.Error(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- model.UpdateChannelUsedQuota(channelId, quota)
+ model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
+ model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
+ model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}
-
}()
- }(c.Request.Context())
- switch apiType {
- case APITypeOpenAI:
- if isStream {
- err, responseText := openai.StreamHandler(c, resp, relayMode)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := openai.Handler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeClaude:
- if isStream {
- err, responseText := anthropic.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := anthropic.Handler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeBaidu:
- if isStream {
- err, usage := baidu.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- } else {
- var err *openai.ErrorWithStatusCode
- var usage *openai.Usage
- switch relayMode {
- case constant.RelayModeEmbeddings:
- err, usage = baidu.EmbeddingHandler(c, resp)
- default:
- err, usage = baidu.Handler(c, resp)
- }
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypePaLM:
- if textRequest.Stream { // PaLM2 API does not support stream
- err, responseText := google.PaLMStreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeGemini:
- if textRequest.Stream {
- err, responseText := google.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeZhipu:
- if isStream {
- err, usage := zhipu.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- // zhipu's API does not return prompt tokens & completion tokens
- textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
- return nil
- } else {
- err, usage := zhipu.Handler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- // zhipu's API does not return prompt tokens & completion tokens
- textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
- return nil
- }
- case APITypeAli:
- if isStream {
- err, usage := ali.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- } else {
- var err *openai.ErrorWithStatusCode
- var usage *openai.Usage
- switch relayMode {
- case constant.RelayModeEmbeddings:
- err, usage = ali.EmbeddingHandler(c, resp)
- default:
- err, usage = ali.Handler(c, resp)
- }
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeXunfei:
- auth := c.Request.Header.Get("Authorization")
- auth = strings.TrimPrefix(auth, "Bearer ")
- splits := strings.Split(auth, "|")
- if len(splits) != 3 {
- return openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
- }
- var err *openai.ErrorWithStatusCode
- var usage *openai.Usage
- if isStream {
- err, usage = xunfei.StreamHandler(c, textRequest, splits[0], splits[1], splits[2])
- } else {
- err, usage = xunfei.Handler(c, textRequest, splits[0], splits[1], splits[2])
- }
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- case APITypeAIProxyLibrary:
- if isStream {
- err, usage := aiproxy.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- } else {
- err, usage := aiproxy.Handler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- case APITypeTencent:
- if isStream {
- err, responseText := tencent.StreamHandler(c, resp)
- if err != nil {
- return err
- }
- textResponse.Usage.PromptTokens = promptTokens
- textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
- return nil
- } else {
- err, usage := tencent.Handler(c, resp)
- if err != nil {
- return err
- }
- if usage != nil {
- textResponse.Usage = *usage
- }
- return nil
- }
- default:
- return openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
+ }(ctx)
+ usage, respErr = DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens)
+ if respErr != nil {
+ return respErr
}
+ return nil
}
diff --git a/relay/controller/util.go b/relay/controller/util.go
new file mode 100644
index 00000000..cdb10dbf
--- /dev/null
+++ b/relay/controller/util.go
@@ -0,0 +1,336 @@
+package controller
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/relay/channel/aiproxy"
+ "one-api/relay/channel/ali"
+ "one-api/relay/channel/anthropic"
+ "one-api/relay/channel/baidu"
+ "one-api/relay/channel/google"
+ "one-api/relay/channel/openai"
+ "one-api/relay/channel/tencent"
+ "one-api/relay/channel/xunfei"
+ "one-api/relay/channel/zhipu"
+ "one-api/relay/constant"
+ "one-api/relay/util"
+ "strings"
+)
+
+func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
+ fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
+ switch apiType {
+ case constant.APITypeOpenAI:
+ if meta.ChannelType == common.ChannelTypeAzure {
+ // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
+ requestURL := strings.Split(requestURL, "?")[0]
+ requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
+ task := strings.TrimPrefix(requestURL, "/v1/")
+ model_ := textRequest.Model
+ model_ = strings.Replace(model_, ".", "", -1)
+ // https://github.com/songquanpeng/one-api/issues/67
+ model_ = strings.TrimSuffix(model_, "-0301")
+ model_ = strings.TrimSuffix(model_, "-0314")
+ model_ = strings.TrimSuffix(model_, "-0613")
+
+ requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+ fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
+ }
+ case constant.APITypeClaude:
+ fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL)
+ case constant.APITypeBaidu:
+ switch textRequest.Model {
+ case "ERNIE-Bot":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
+ case "ERNIE-Bot-turbo":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
+ case "ERNIE-Bot-4":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
+ case "BLOOMZ-7B":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
+ case "Embedding-V1":
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
+ }
+ var accessToken string
+ var err error
+ if accessToken, err = baidu.GetAccessToken(meta.APIKey); err != nil {
+ return "", fmt.Errorf("failed to get baidu access token: %w", err)
+ }
+ fullRequestURL += "?access_token=" + accessToken
+ case constant.APITypePaLM:
+ fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL)
+ case constant.APITypeGemini:
+ version := common.AssignOrDefault(meta.APIVersion, "v1")
+ action := "generateContent"
+ if textRequest.Stream {
+ action = "streamGenerateContent"
+ }
+ fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, textRequest.Model, action)
+ case constant.APITypeZhipu:
+ method := "invoke"
+ if textRequest.Stream {
+ method = "sse-invoke"
+ }
+ fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
+ case constant.APITypeAli:
+ fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
+ if relayMode == constant.RelayModeEmbeddings {
+ fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
+ }
+ case constant.APITypeTencent:
+ fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
+ case constant.APITypeAIProxyLibrary:
+ fullRequestURL = fmt.Sprintf("%s/api/library/ask", meta.BaseURL)
+ }
+ return fullRequestURL, nil
+}
+
+func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
+ var requestBody io.Reader
+ if isModelMapped {
+ jsonStr, err := json.Marshal(textRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ } else {
+ requestBody = c.Request.Body
+ }
+ switch apiType {
+ case constant.APITypeClaude:
+ claudeRequest := anthropic.ConvertRequest(textRequest)
+ jsonStr, err := json.Marshal(claudeRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeBaidu:
+ var jsonData []byte
+ var err error
+ switch relayMode {
+ case constant.RelayModeEmbeddings:
+ baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest)
+ jsonData, err = json.Marshal(baiduEmbeddingRequest)
+ default:
+ baiduRequest := baidu.ConvertRequest(textRequest)
+ jsonData, err = json.Marshal(baiduRequest)
+ }
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonData)
+ case constant.APITypePaLM:
+ palmRequest := google.ConvertPaLMRequest(textRequest)
+ jsonStr, err := json.Marshal(palmRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeGemini:
+ geminiChatRequest := google.ConvertGeminiRequest(textRequest)
+ jsonStr, err := json.Marshal(geminiChatRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeZhipu:
+ zhipuRequest := zhipu.ConvertRequest(textRequest)
+ jsonStr, err := json.Marshal(zhipuRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeAli:
+ var jsonStr []byte
+ var err error
+ switch relayMode {
+ case constant.RelayModeEmbeddings:
+ aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest)
+ jsonStr, err = json.Marshal(aliEmbeddingRequest)
+ default:
+ aliRequest := ali.ConvertRequest(textRequest)
+ jsonStr, err = json.Marshal(aliRequest)
+ }
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeTencent:
+ apiKey := c.Request.Header.Get("Authorization")
+ apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+ appId, secretId, secretKey, err := tencent.ParseConfig(apiKey)
+ if err != nil {
+ return nil, err
+ }
+ tencentRequest := tencent.ConvertRequest(textRequest)
+ tencentRequest.AppId = appId
+ tencentRequest.SecretId = secretId
+ jsonStr, err := json.Marshal(tencentRequest)
+ if err != nil {
+ return nil, err
+ }
+ sign := tencent.GetSign(*tencentRequest, secretKey)
+ c.Request.Header.Set("Authorization", sign)
+ requestBody = bytes.NewBuffer(jsonStr)
+ case constant.APITypeAIProxyLibrary:
+ aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest)
+ aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
+ jsonStr, err := json.Marshal(aiProxyLibraryRequest)
+ if err != nil {
+ return nil, err
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ }
+ return requestBody, nil
+}
+
+func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
+ SetupAuthHeaders(c, req, apiType, meta, isStream)
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+ if isStream && c.Request.Header.Get("Accept") == "" {
+ req.Header.Set("Accept", "text/event-stream")
+ }
+}
+
+func SetupAuthHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
+ apiKey := meta.APIKey
+ switch apiType {
+ case constant.APITypeOpenAI:
+ if meta.ChannelType == common.ChannelTypeAzure {
+ req.Header.Set("api-key", apiKey)
+ } else {
+ req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+ if meta.ChannelType == common.ChannelTypeOpenRouter {
+ req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
+ req.Header.Set("X-Title", "One API")
+ }
+ }
+ case constant.APITypeClaude:
+ req.Header.Set("x-api-key", apiKey)
+ anthropicVersion := c.Request.Header.Get("anthropic-version")
+ if anthropicVersion == "" {
+ anthropicVersion = "2023-06-01"
+ }
+ req.Header.Set("anthropic-version", anthropicVersion)
+ case constant.APITypeZhipu:
+ token := zhipu.GetToken(apiKey)
+ req.Header.Set("Authorization", token)
+ case constant.APITypeAli:
+ req.Header.Set("Authorization", "Bearer "+apiKey)
+ if isStream {
+ req.Header.Set("X-DashScope-SSE", "enable")
+ }
+ if c.GetString("plugin") != "" {
+ req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
+ }
+ case constant.APITypeTencent:
+ req.Header.Set("Authorization", apiKey)
+ case constant.APITypePaLM:
+ req.Header.Set("x-goog-api-key", apiKey)
+ case constant.APITypeGemini:
+ req.Header.Set("x-goog-api-key", apiKey)
+ default:
+ req.Header.Set("Authorization", "Bearer "+apiKey)
+ }
+}
+
// DoResponse relays the upstream response back to the client and extracts the
// token usage for billing. Each provider exposes a streaming and/or blocking
// handler; several streaming handlers only yield the raw response text, in
// which case usage is reconstructed at the end from local token counts.
// On failure the returned error carries the HTTP status to surface.
func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) {
	var responseText string
	switch apiType {
	case constant.APITypeOpenAI:
		if isStream {
			err, responseText = openai.StreamHandler(c, resp, relayMode)
		} else {
			err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model)
		}
	case constant.APITypeClaude:
		if isStream {
			err, responseText = anthropic.StreamHandler(c, resp)
		} else {
			err, usage = anthropic.Handler(c, resp, promptTokens, textRequest.Model)
		}
	case constant.APITypeBaidu:
		if isStream {
			err, usage = baidu.StreamHandler(c, resp)
		} else {
			switch relayMode {
			case constant.RelayModeEmbeddings:
				err, usage = baidu.EmbeddingHandler(c, resp)
			default:
				err, usage = baidu.Handler(c, resp)
			}
		}
	case constant.APITypePaLM:
		if isStream { // PaLM2 API does not support stream
			err, responseText = google.PaLMStreamHandler(c, resp)
		} else {
			err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
		}
	case constant.APITypeGemini:
		if isStream {
			err, responseText = google.StreamHandler(c, resp)
		} else {
			err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
		}
	case constant.APITypeZhipu:
		if isStream {
			err, usage = zhipu.StreamHandler(c, resp)
		} else {
			err, usage = zhipu.Handler(c, resp)
		}
	case constant.APITypeAli:
		if isStream {
			err, usage = ali.StreamHandler(c, resp)
		} else {
			switch relayMode {
			case constant.RelayModeEmbeddings:
				err, usage = ali.EmbeddingHandler(c, resp)
			default:
				err, usage = ali.Handler(c, resp)
			}
		}
	case constant.APITypeXunfei:
		// Xunfei credentials arrive packed in the Authorization header as
		// three '|'-separated parts. NOTE(review): the meaning/order of the
		// parts is defined by the xunfei package — confirm there.
		auth := c.Request.Header.Get("Authorization")
		auth = strings.TrimPrefix(auth, "Bearer ")
		splits := strings.Split(auth, "|")
		if len(splits) != 3 {
			return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
		}
		if isStream {
			err, usage = xunfei.StreamHandler(c, *textRequest, splits[0], splits[1], splits[2])
		} else {
			err, usage = xunfei.Handler(c, *textRequest, splits[0], splits[1], splits[2])
		}
	case constant.APITypeAIProxyLibrary:
		if isStream {
			err, usage = aiproxy.StreamHandler(c, resp)
		} else {
			err, usage = aiproxy.Handler(c, resp)
		}
	case constant.APITypeTencent:
		if isStream {
			err, responseText = tencent.StreamHandler(c, resp)
		} else {
			err, usage = tencent.Handler(c, resp)
		}
	default:
		return nil, openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
	}
	if err != nil {
		return nil, err
	}
	// Handlers that only produced response text (stream paths) report no
	// usage; derive it by counting tokens on the returned text.
	if usage == nil && responseText != "" {
		usage = &openai.Usage{}
		usage.PromptTokens = promptTokens
		usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	}
	return usage, nil
}
diff --git a/relay/util/billing.go b/relay/util/billing.go
new file mode 100644
index 00000000..35fb28a4
--- /dev/null
+++ b/relay/util/billing.go
@@ -0,0 +1,19 @@
+package util
+
+import (
+ "context"
+ "one-api/common/logger"
+ "one-api/model"
+)
+
+func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) {
+ if preConsumedQuota != 0 {
+ go func(ctx context.Context) {
+ // return pre-consumed quota
+ err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
+ if err != nil {
+ logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
+ }
+ }(ctx)
+ }
+}
diff --git a/relay/util/common.go b/relay/util/common.go
index 9d13b12e..d7596188 100644
--- a/relay/util/common.go
+++ b/relay/util/common.go
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"one-api/model"
"one-api/relay/channel/openai"
"strconv"
@@ -138,11 +139,11 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
+ logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
- common.SysError("error update user quota cache: " + err.Error())
+ logger.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
@@ -152,11 +153,11 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
- common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
+ logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}
-func GetAPIVersion(c *gin.Context) string {
+func GetAzureAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
diff --git a/relay/util/model_mapping.go b/relay/util/model_mapping.go
new file mode 100644
index 00000000..39e062a1
--- /dev/null
+++ b/relay/util/model_mapping.go
@@ -0,0 +1,12 @@
+package util
+
// GetMappedModelName resolves modelName through the channel's model mapping.
// It returns the mapped name and true when a non-empty mapping entry exists;
// otherwise it returns the original name and false. A nil mapping is allowed.
func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
	if mapped, ok := mapping[modelName]; ok && mapped != "" {
		return mapped, true
	}
	return modelName, false
}
diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go
new file mode 100644
index 00000000..19936e49
--- /dev/null
+++ b/relay/util/relay_meta.go
@@ -0,0 +1,44 @@
+package util
+
+import (
+ "github.com/gin-gonic/gin"
+ "one-api/common"
+ "strings"
+)
+
// RelayMeta carries the per-request relay context extracted from the gin
// context (populated by upstream middleware) via GetRelayMeta.
type RelayMeta struct {
	ChannelType  int               // provider type; indexes common.ChannelBaseURLs for the default base URL
	ChannelId    int
	TokenId      int
	TokenName    string
	UserId       int
	Group        string
	ModelMapping map[string]string // optional model-name rewrites for this channel
	BaseURL      string            // channel override; defaults to common.ChannelBaseURLs[ChannelType] when empty
	APIVersion   string            // for Azure this is taken from the request's api-version query parameter
	APIKey       string            // bearer token from the inbound Authorization header
	Config       map[string]string // not populated yet; GetRelayMeta leaves it nil
}
+
+func GetRelayMeta(c *gin.Context) *RelayMeta {
+ meta := RelayMeta{
+ ChannelType: c.GetInt("channel"),
+ ChannelId: c.GetInt("channel_id"),
+ TokenId: c.GetInt("token_id"),
+ TokenName: c.GetString("token_name"),
+ UserId: c.GetInt("id"),
+ Group: c.GetString("group"),
+ ModelMapping: c.GetStringMapString("model_mapping"),
+ BaseURL: c.GetString("base_url"),
+ APIVersion: c.GetString("api_version"),
+ APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+ Config: nil,
+ }
+ if meta.ChannelType == common.ChannelTypeAzure {
+ meta.APIVersion = GetAzureAPIVersion(c)
+ }
+ if meta.BaseURL == "" {
+ meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
+ }
+ return &meta
+}
diff --git a/relay/util/validation.go b/relay/util/validation.go
new file mode 100644
index 00000000..48b42d94
--- /dev/null
+++ b/relay/util/validation.go
@@ -0,0 +1,37 @@
+package util
+
+import (
+ "errors"
+ "math"
+ "one-api/relay/channel/openai"
+ "one-api/relay/constant"
+)
+
+func ValidateTextRequest(textRequest *openai.GeneralOpenAIRequest, relayMode int) error {
+ if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
+ return errors.New("max_tokens is invalid")
+ }
+ if textRequest.Model == "" {
+ return errors.New("model is required")
+ }
+ switch relayMode {
+ case constant.RelayModeCompletions:
+ if textRequest.Prompt == "" {
+ return errors.New("field prompt is required")
+ }
+ case constant.RelayModeChatCompletions:
+ if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
+ return errors.New("field messages is required")
+ }
+ case constant.RelayModeEmbeddings:
+ case constant.RelayModeModerations:
+ if textRequest.Input == "" {
+ return errors.New("field input is required")
+ }
+ case constant.RelayModeEdits:
+ if textRequest.Instruction == "" {
+ return errors.New("field instruction is required")
+ }
+ }
+ return nil
+}
diff --git a/router/main.go b/router/main.go
index 85127a1a..733a1033 100644
--- a/router/main.go
+++ b/router/main.go
@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
+ "one-api/common/logger"
"os"
"strings"
)
@@ -17,7 +18,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS) {
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if common.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = ""
- common.SysLog("FRONTEND_BASE_URL is ignored on master node")
+ logger.SysLog("FRONTEND_BASE_URL is ignored on master node")
}
if frontendBaseUrl == "" {
SetWebRouter(router, buildFS)