diff --git a/Dockerfile b/Dockerfile
index 96def4b2..21d0f779 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,7 +12,9 @@
 WORKDIR /web/air
 RUN npm install
 RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
-FROM golang:1.22.2-bullseye AS builder2
+FROM golang:alpine AS builder2
+
+RUN apk add --no-cache g++
 
 ENV GO111MODULE=on \
     CGO_ENABLED=1 \
@@ -23,7 +25,7 @@
 ADD go.mod go.sum ./
 RUN go mod download
 COPY . .
 COPY --from=builder /web/build ./web/build
-RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
+RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
 
 FROM debian:bullseye
diff --git a/README.en.md b/README.en.md
index db96a858..c9fdbbc8 100644
--- a/README.en.md
+++ b/README.en.md
@@ -245,16 +245,41 @@ If the channel ID is not provided, load balancing will be used to distribute the
     + Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs`
 5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
     + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
-6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
+6. `MEMORY_CACHE_ENABLED`: Whether to enable memory caching. Enabling it may delay updates to user quotas. Valid values are `true` and `false`. If not set, it defaults to `false`.
+7. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
     + Example: `SYNC_FREQUENCY=60`
-7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
+8. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
     + Example: `NODE_TYPE=slave`
-8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
+9. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
     + Example: `CHANNEL_UPDATE_FREQUENCY=1440`
-9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
+10. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
     + Example: `CHANNEL_TEST_FREQUENCY=1440`
-10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
+11. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
     + Example: `POLLING_INTERVAL=5`
+12. `BATCH_UPDATE_ENABLED`: Enabling batched aggregation of database updates may delay updates to user quotas. Valid values are `true` and `false`. If not set, it defaults to `false`.
+    + Example: `BATCH_UPDATE_ENABLED=true`
+    + If you encounter an issue with too many database connections, you can try enabling this option.
+13. `BATCH_UPDATE_INTERVAL`: The time interval between batch update aggregations, in seconds. Defaults to `5`.
+    + Example: `BATCH_UPDATE_INTERVAL=5`
+14. Request rate limits:
+    + `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests per IP within three minutes. Defaults to `180`.
+    + `GLOBAL_WEB_RATE_LIMIT`: Global web rate limit, the maximum number of requests per IP within three minutes. Defaults to `60`.
+15. Encoder cache settings:
+    + `TIKTOKEN_CACHE_DIR`: By default, the program downloads the token encodings of some common models (such as `gpt-3.5-turbo`) online at startup. In unstable network environments or offline deployments this may cause startup problems; this directory can be configured to cache the data, and it can be migrated to an offline environment.
+    + `DATA_GYM_CACHE_DIR`: Currently this configuration serves the same purpose as `TIKTOKEN_CACHE_DIR`, but with lower priority.
+16. `RELAY_TIMEOUT`: Relay timeout, in seconds. No timeout is set by default.
+17. `RELAY_PROXY`: When set, this proxy is used to request APIs.
+18. `USER_CONTENT_REQUEST_TIMEOUT`: Timeout for downloading user-uploaded content, in seconds.
+19. `USER_CONTENT_REQUEST_PROXY`: When set, this proxy is used to request user-uploaded content, such as images.
+20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout, in milliseconds. Defaults to `3000`.
+21. `GEMINI_SAFETY_SETTING`: Gemini safety setting. Defaults to `BLOCK_NONE`.
+22. `GEMINI_VERSION`: The Gemini version used by One API. Defaults to `v1`.
+23. `THEME`: The theme of the system. Defaults to `default`. See [here](./web/README.md) for the available values.
+24. `ENABLE_METRIC`: Whether to disable channels based on request success rate. Disabled by default. Valid values are `true` and `false`.
+25. `METRIC_QUEUE_SIZE`: Queue size for request success rate statistics. Defaults to `10`.
+26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold. Defaults to `0.8`.
+27. `INITIAL_ROOT_TOKEN`: If set, a root user token with the value of this environment variable will be created automatically when the system starts for the first time.
+28. `INITIAL_ROOT_ACCESS_TOKEN`: If set, a system management token with the value of this environment variable will be created automatically for the root user when the system starts for the first time.
 
 ### Command Line Parameters
 
 1. `--port `: Specifies the port number on which the server listens. Defaults to `3000`.
diff --git a/README.md b/README.md
index 78cd1353..466aa54a 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,7 @@ docker image: `ppcelery/one-api:latest`
     + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/)
     + [x] [DeepL](https://www.deepl.com/)
     + [x] [together.ai](https://www.together.ai/)
+    + [x] [novita.ai](https://www.novita.ai/)
 2. Supports configuring mirrors and numerous [third-party proxy services](https://iamazing.cn/page/openai-api-third-party-services).
 3. Supports accessing multiple channels via **load balancing**.
 4. Supports **stream mode**, enabling a typewriter effect through streaming.
@@ -365,33 +366,34 @@ graph LR
     + Example: `NODE_TYPE=slave`
 9. `CHANNEL_UPDATE_FREQUENCY`: When set, channel balances are updated periodically, in minutes. If not set, no update happens.
     + Example: `CHANNEL_UPDATE_FREQUENCY=1440`
-10. `CHANNEL_TEST_FREQUENCY`: When set, channels are tested periodically, in minutes. If not set, no test happens.
-11. Example: `CHANNEL_TEST_FREQUENCY=1440`
-12. `POLLING_INTERVAL`: The request interval when batch-updating channel balances and testing availability, in seconds. No interval by default.
+10. `CHANNEL_TEST_FREQUENCY`: When set, channels are tested periodically, in minutes. If not set, no test happens.
+    + Example: `CHANNEL_TEST_FREQUENCY=1440`
+11. `POLLING_INTERVAL`: The request interval when batch-updating channel balances and testing availability, in seconds. No interval by default.
     + Example: `POLLING_INTERVAL=5`
-13. `BATCH_UPDATE_ENABLED`: Enabling batched aggregation of database updates may delay updates to user quotas. Valid values are `true` and `false`. If not set, it defaults to `false`.
+12. `BATCH_UPDATE_ENABLED`: Enabling batched aggregation of database updates may delay updates to user quotas. Valid values are `true` and `false`. If not set, it defaults to `false`.
     + Example: `BATCH_UPDATE_ENABLED=true`
     + If you encounter an issue with too many database connections, you can try enabling this option.
-14. `BATCH_UPDATE_INTERVAL=5`: The time interval between batch update aggregations, in seconds. Defaults to `5`.
+13. `BATCH_UPDATE_INTERVAL`: The time interval between batch update aggregations, in seconds. Defaults to `5`.
     + Example: `BATCH_UPDATE_INTERVAL=5`
-15. Request rate limits:
+14. Request rate limits:
     + `GLOBAL_API_RATE_LIMIT`: Global API rate limit (excluding relay requests), the maximum number of requests per IP within three minutes. Defaults to `180`.
     + `GLOBAL_WEB_RATE_LIMIT`: Global web rate limit, the maximum number of requests per IP within three minutes. Defaults to `60`.
-16. Encoder cache settings:
+15. Encoder cache settings:
     + `TIKTOKEN_CACHE_DIR`: By default, the program downloads the token encodings of some common models (such as `gpt-3.5-turbo`) online at startup. In unstable network environments or offline deployments this may cause startup problems; this directory can be configured to cache the data, and it can be migrated to an offline environment.
     + `DATA_GYM_CACHE_DIR`: Currently this configuration serves the same purpose as `TIKTOKEN_CACHE_DIR`, but with lower priority.
-17. `RELAY_TIMEOUT`: Relay timeout, in seconds. No timeout is set by default.
-18. `RELAY_PROXY`: When set, this proxy is used to request APIs.
-19. `USER_CONTENT_REQUEST_TIMEOUT`: Timeout for downloading user-uploaded content, in seconds.
-20. `USER_CONTENT_REQUEST_PROXY`: When set, this proxy is used to request user-uploaded content, such as images.
-21. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout, in milliseconds. Defaults to `3000`.
-22. `GEMINI_SAFETY_SETTING`: Gemini safety setting. Defaults to `BLOCK_NONE`.
-23. `GEMINI_VERSION`: The Gemini version used by One API. Defaults to `v1`.
-24. `THEME`: The theme of the system. Defaults to `default`. See [here](./web/README.md) for the available values.
-25. `ENABLE_METRIC`: Whether to disable channels based on request success rate. Disabled by default. Valid values are `true` and `false`.
-26. `METRIC_QUEUE_SIZE`: Queue size for request success rate statistics. Defaults to `10`.
-27. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold. Defaults to `0.8`.
-28. `INITIAL_ROOT_TOKEN`: If set, a root user token with the value of this environment variable will be created automatically when the system starts for the first time.
+16. `RELAY_TIMEOUT`: Relay timeout, in seconds. No timeout is set by default.
+17. `RELAY_PROXY`: When set, this proxy is used to request APIs.
+18. `USER_CONTENT_REQUEST_TIMEOUT`: Timeout for downloading user-uploaded content, in seconds.
+19. `USER_CONTENT_REQUEST_PROXY`: When set, this proxy is used to request user-uploaded content, such as images.
+20. `SQLITE_BUSY_TIMEOUT`: SQLite lock wait timeout, in milliseconds. Defaults to `3000`.
+21. `GEMINI_SAFETY_SETTING`: Gemini safety setting. Defaults to `BLOCK_NONE`.
+22. `GEMINI_VERSION`: The Gemini version used by One API. Defaults to `v1`.
+23. `THEME`: The theme of the system. Defaults to `default`. See [here](./web/README.md) for the available values.
+24. `ENABLE_METRIC`: Whether to disable channels based on request success rate. Disabled by default. Valid values are `true` and `false`.
+25. `METRIC_QUEUE_SIZE`: Queue size for request success rate statistics. Defaults to `10`.
+26. `METRIC_SUCCESS_RATE_THRESHOLD`: Request success rate threshold. Defaults to `0.8`.
+27. `INITIAL_ROOT_TOKEN`: If set, a root user token with the value of this environment variable will be created automatically when the system starts for the first time.
+28. `INITIAL_ROOT_ACCESS_TOKEN`: If set, a system management token with the value of this environment variable will be created automatically for the root user when the system starts for the first time.
 
 ### Command Line Parameters
 
 1. `--port `: Specifies the port number the server listens on. Defaults to `3000`.
diff --git a/common/config/config.go b/common/config/config.go
index a596664d..eb38af41 100644
--- a/common/config/config.go
+++ b/common/config/config.go
@@ -158,8 +158,12 @@ var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
 
 var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
 
+var InitialRootAccessToken = os.Getenv("INITIAL_ROOT_ACCESS_TOKEN")
+
 var GeminiVersion = env.String("GEMINI_VERSION", "v1")
 
+var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
+
 var RelayProxy = env.String("RELAY_PROXY", "")
 var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
 var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
diff --git a/common/logger/logger.go b/common/logger/logger.go
index f725c619..d1022932 100644
--- a/common/logger/logger.go
+++ b/common/logger/logger.go
@@ -27,7 +27,12 @@ var setupLogOnce sync.Once
 func SetupLogger() {
 	setupLogOnce.Do(func() {
 		if LogDir != "" {
-			logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
+			var logPath string
+			if config.OnlyOneLogFile {
+				logPath = filepath.Join(LogDir, "oneapi.log")
+			} else {
+				logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
+			}
 			fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 			if err != nil {
 				log.Fatal("failed to open log file")
diff --git a/common/message/email.go b/common/message/email.go
index 585aa37a..f7e99da6 100644
--- a/common/message/email.go
+++ b/common/message/email.go
@@ -5,14 +5,18 @@ import (
 	"crypto/tls"
 	"encoding/base64"
 	"fmt"
+	"github.com/Laisky/errors/v2"
+	"github.com/songquanpeng/one-api/common/config"
+	"net"
 	"net/smtp"
 	"strings"
 	"time"
-
-	"github.com/Laisky/errors/v2"
-	"github.com/songquanpeng/one-api/common/config"
 )
 
+func shouldAuth() bool {
+	return config.SMTPAccount != "" || config.SMTPToken != ""
+}
+
 func SendEmail(subject string, receiver string, content string) error {
 	if receiver == "" {
 		return errors.Errorf("receiver is empty")
@@ -43,16 +47,24 @@ func SendEmail(subject string, receiver string, content string) error {
 		"Date: %s\r\n"+
 		"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
 		receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
+
 	auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
 	addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
 	to := strings.Split(receiver, ";")
 
-	if config.SMTPPort == 465 {
-		tlsConfig := &tls.Config{
-			InsecureSkipVerify: false,
-			ServerName:         config.SMTPServer,
+	if config.SMTPPort == 465 || !shouldAuth() {
+		// need a raw SMTP client to control TLS and authentication explicitly
+		var conn net.Conn
+		var err error
+		if config.SMTPPort == 465 {
+			tlsConfig := &tls.Config{
+				InsecureSkipVerify: false,
+				ServerName:         config.SMTPServer,
+			}
+			conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
+		} else {
+			conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort))
 		}
-		conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
 		if err != nil {
 			return err
 		}
@@ -61,8 +73,10 @@ func SendEmail(subject string, receiver string, content string) error {
 			return err
 		}
 		defer client.Close()
-		if err = client.Auth(auth); err != nil {
-			return err
+		if shouldAuth() {
+			if err = client.Auth(auth); err != nil {
+				return err
+			}
 		}
 		if err = client.Mail(config.SMTPFrom); err != nil {
 			return err
diff --git a/controller/channel-test.go b/controller/channel-test.go
index ad74a798..9d403440 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -29,11 +29,13 @@ import (
 	"time"
 )
 
-func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
+func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
+	if model == "" {
+		model = "gpt-3.5-turbo"
+	}
 	testRequest := &relaymodel.GeneralOpenAIRequest{
 		MaxTokens: 2,
-		Stream:    false,
-		Model:     "gpt-3.5-turbo",
+		Model:     model,
 	}
 	testMessage := relaymodel.Message{
 		Role:    "user",
@@ -43,7 +45,7 @@ func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
 	return testRequest
 }
 
-func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) {
+func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
 	w := httptest.NewRecorder()
 	c, _ := gin.CreateTestContext(w)
 	c.Request = &http.Request{
@@ -66,12 +68,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
 		return errors.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 	}
 	adaptor.Init(meta)
-	var modelName string
-	modelList := adaptor.GetModelList()
+	modelName := request.Model
 	modelMap := channel.GetModelMapping()
-	if len(modelList) != 0 {
-		modelName = modelList[0]
-	}
 	if modelName == "" || !strings.Contains(channel.Models, modelName) {
 		modelNames := strings.Split(channel.Models, ",")
 		if len(modelNames) > 0 {
@@ -81,9 +79,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
 			modelName = modelMap[modelName]
 		}
 	}
-	request := buildTestRequest()
+	meta.OriginModelName, meta.ActualModelName = request.Model, modelName
 	request.Model = modelName
-	meta.OriginModelName, meta.ActualModelName = modelName, modelName
 	convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
 	if err != nil {
 		return err, nil
@@ -137,10 +134,15 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
+	model := c.Query("model")
+	testRequest := buildTestRequest(model)
 	tik := time.Now()
-	err, _ = testChannel(channel)
+	err, _ = testChannel(channel, testRequest)
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
+	if err != nil {
+		milliseconds = 0
+	}
 	go channel.UpdateResponseTime(milliseconds)
 	consumedTime := float64(milliseconds) / 1000.0
 	if err != nil {
@@ -148,6 +150,7 @@ func TestChannel(c *gin.Context) {
 			"success": false,
 			"message": err.Error(),
 			"time":    consumedTime,
+			"model":   model,
 		})
 		return
 	}
@@ -155,6 +158,7 @@ func TestChannel(c *gin.Context) {
 		"success": true,
 		"message": "",
 		"time":    consumedTime,
+		"model":   model,
 	})
 	return
 }
@@ -185,11 +189,12 @@ func testChannels(notify bool, scope string) error {
 	for _, channel := range channels {
 		isChannelEnabled := channel.Status == model.ChannelStatusEnabled
 		tik := time.Now()
-		err, openaiErr := testChannel(channel)
+		testRequest := buildTestRequest("")
+		err, openaiErr := testChannel(channel, testRequest)
 		tok := time.Now()
 		milliseconds := tok.Sub(tik).Milliseconds()
 		if isChannelEnabled && milliseconds > disableThreshold {
-			err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+			err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
 			if config.AutomaticDisableChannelEnabled {
 				monitor.DisableChannel(channel.Id, channel.Name, err.Error())
 			} else {
diff --git a/main.go b/main.go
index 1161d00b..4f88607f 100644
--- a/main.go
+++ b/main.go
@@ -29,27 +29,19 @@ func main() {
 	common.Init()
 	logger.SetupLogger()
 	logger.SysLogf("One API %s started", common.Version)
-	if os.Getenv("GIN_MODE") != "debug" {
+
+	if os.Getenv("GIN_MODE") != gin.DebugMode {
 		gin.SetMode(gin.ReleaseMode)
 	}
 	if config.DebugEnabled {
 		logger.SysLog("running in debug mode")
 	}
-	var err error
+
 	// Initialize SQL Database
-	model.DB, err = model.InitDB("SQL_DSN")
-	if err != nil {
-		logger.FatalLog("failed to initialize database: " + err.Error())
-	}
-	if os.Getenv("LOG_SQL_DSN") != "" {
-		logger.SysLog("using secondary database for table logs")
-		model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
-		if err != nil {
-			logger.FatalLog("failed to initialize secondary database: " + err.Error())
-		}
-	} else {
-		model.LOG_DB = model.DB
-	}
+	model.InitDB()
+	model.InitLogDB()
+
+	var err error
 	err = model.CreateRootAccountIfNeed()
 	if err != nil {
 		logger.FatalLog("database init error: " + err.Error())
diff --git a/model/main.go b/model/main.go
index cab9e5ee..92e861c7 100644
--- a/model/main.go
+++ b/model/main.go
@@ -1,6 +1,7 @@
 package model
 
 import (
+	"database/sql"
 	"fmt"
 	"os"
 	"strings"
@@ -31,13 +32,17 @@ func CreateRootAccountIfNeed() error {
 		if err != nil {
 			return errors.WithStack(err)
 		}
+		accessToken := random.GetUUID()
+		if config.InitialRootAccessToken != "" {
+			accessToken = config.InitialRootAccessToken
+		}
 		rootUser := User{
 			Username:    "root",
 			Password:    hashedPassword,
 			Role:        RoleRootUser,
 			Status:      UserStatusEnabled,
 			DisplayName: "Root User",
-			AccessToken: random.GetUUID(),
+			AccessToken: accessToken,
 			Quota:       500000000000000,
 		}
 		DB.Create(&rootUser)
@@ -62,99 +67,156 @@ func CreateRootAccountIfNeed() error {
 }
 
 func chooseDB(envName string) (*gorm.DB, error) {
-	if os.Getenv(envName) != "" {
-		dsn := os.Getenv(envName)
-		if strings.HasPrefix(dsn, "postgres://") {
-			// Use PostgreSQL
-			logger.SysLog("using PostgreSQL as database")
-			common.UsingPostgreSQL = true
-			return gorm.Open(postgres.New(postgres.Config{
-				DSN:                  dsn,
-				PreferSimpleProtocol: true, // disables implicit prepared statement usage
-			}), &gorm.Config{
-				PrepareStmt: true, // precompile SQL
-			})
-		}
+	dsn := os.Getenv(envName)
+
+	switch {
+	case strings.HasPrefix(dsn, "postgres://"):
+		// Use PostgreSQL
+		return openPostgreSQL(dsn)
+	case dsn != "":
 		// Use MySQL
-		logger.SysLog("using MySQL as database")
-		common.UsingMySQL = true
-		return gorm.Open(mysql.Open(dsn), &gorm.Config{
-			PrepareStmt: true, // precompile SQL
-		})
+		return openMySQL(dsn)
+	default:
+		// Use SQLite
+		return openSQLite()
 	}
-	// Use SQLite
-	logger.SysLog("SQL_DSN not set, using SQLite as database")
-	common.UsingSQLite = true
-	config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
-	return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
+}
+
+func openPostgreSQL(dsn string) (*gorm.DB, error) {
+	logger.SysLog("using PostgreSQL as database")
+	common.UsingPostgreSQL = true
+	return gorm.Open(postgres.New(postgres.Config{
+		DSN:                  dsn,
+		PreferSimpleProtocol: true, // disables implicit prepared statement usage
+	}), &gorm.Config{
 		PrepareStmt: true, // precompile SQL
 	})
 }
 
-func InitDB(envName string) (db *gorm.DB, err error) {
-	db, err = chooseDB(envName)
+func openMySQL(dsn string) (*gorm.DB, error) {
+	logger.SysLog("using MySQL as database")
+	common.UsingMySQL = true
+	return gorm.Open(mysql.Open(dsn), &gorm.Config{
+		PrepareStmt: true, // precompile SQL
+	})
+}
 
-	if config.DebugEnabled {
+func openSQLite() (*gorm.DB, error) {
+	logger.SysLog("SQL_DSN not set, using SQLite as database")
+	common.UsingSQLite = true
+	dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout)
+	return gorm.Open(sqlite.Open(dsn), &gorm.Config{
+		PrepareStmt: true, // precompile SQL
+	})
+}
+
+func InitDB() {
+	var err error
+	DB, err = chooseDB("SQL_DSN")
+	if err != nil {
+		logger.FatalLog("failed to initialize database: " + err.Error())
+		return
+	}
+
+	sqlDB := setDBConns(DB)
+
+	if !config.IsMasterNode {
+		return
+	}
+
+	if common.UsingMySQL {
+		_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
+	}
+
+	logger.SysLog("database migration started")
+	if err = migrateDB(); err != nil {
+		logger.FatalLog("failed to migrate database: " + err.Error())
+		return
+	}
+	logger.SysLog("database migrated")
+}
+
+func migrateDB() error {
+	var err error
+	if err = DB.AutoMigrate(&Channel{}); err != nil {
+		return err
+	}
+	if err = DB.AutoMigrate(&Token{}); err != nil {
+		return err
+	}
+	if err = DB.AutoMigrate(&User{}); err != nil {
+		return err
+	}
+	if err = DB.AutoMigrate(&Option{}); err != nil {
+		return err
+	}
+	if err = DB.AutoMigrate(&Redemption{}); err != nil {
+		return err
+	}
+	if err = DB.AutoMigrate(&Ability{}); err != nil {
+		return err
+	}
+	if err = DB.AutoMigrate(&Log{}); err != nil {
+		return err
+	}
+	if err = DB.AutoMigrate(&UserRequestCost{}); err != nil {
+		return err
+	}
+	return nil
+}
+
+func InitLogDB() {
+	if os.Getenv("LOG_SQL_DSN") == "" {
+		LOG_DB = DB
+		return
+	}
+
+	logger.SysLog("using secondary database for table logs")
+	var err error
+	LOG_DB, err = chooseDB("LOG_SQL_DSN")
+	if err != nil {
+		logger.FatalLog("failed to initialize secondary database: " + err.Error())
+		return
+	}
+
+	setDBConns(LOG_DB)
+
+	if !config.IsMasterNode {
+		return
+	}
+
+	logger.SysLog("secondary database migration started")
+	err = migrateLOGDB()
+	if err != nil {
+		logger.FatalLog("failed to migrate secondary database: " + err.Error())
+		return
+	}
+	logger.SysLog("secondary database migrated")
+}
+
+func migrateLOGDB() error {
+	var err error
+	if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
+		return err
+	}
+	return nil
+}
+
+func setDBConns(db *gorm.DB) *sql.DB {
+	if config.DebugSQLEnabled {
 		db = db.Debug()
 	}
 
-	if err == nil {
-		if config.DebugSQLEnabled {
-			db = db.Debug()
-		}
-		sqlDB, err := db.DB()
-		if err != nil {
-			return nil, err
-		}
-		sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
-		sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
-		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
-
-		if !config.IsMasterNode {
-			return db, err
-		}
-		if common.UsingMySQL {
-			_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
-		}
-		logger.SysLog("database migration started")
-		err = db.AutoMigrate(&Channel{})
-		if err != nil {
-			return nil, err
-		}
-		err = db.AutoMigrate(&UserRequestCost{})
-		if err != nil {
-			return nil, err
-		}
-		err = db.AutoMigrate(&Token{})
-		if err != nil {
-			return nil, err
-		}
-		err = db.AutoMigrate(&User{})
-		if err != nil {
-			return nil, err
-		}
-		err = db.AutoMigrate(&Option{})
-		if err != nil {
-			return nil, err
-		}
-		err = db.AutoMigrate(&Redemption{})
-		if err != nil {
-			return nil, err
-		}
-		err = db.AutoMigrate(&Ability{})
-		if err != nil {
-			return nil, err
-		}
-		err = db.AutoMigrate(&Log{})
-		if err != nil {
-			return nil, err
-		}
-		logger.SysLog("database migrated")
-		return db, err
-	} else {
-		logger.FatalLog(err)
+	sqlDB, err := db.DB()
+	if err != nil {
logger.FatalLog("failed to connect database: " + err.Error()) + return nil } - return db, err + + sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) + return sqlDB } func closeDB(db *gorm.DB) error { diff --git a/relay/adaptor/aws/adapter.go b/relay/adaptor/aws/adaptor.go similarity index 71% rename from relay/adaptor/aws/adapter.go rename to relay/adaptor/aws/adaptor.go index 4c05781b..e1bf3b30 100644 --- a/relay/adaptor/aws/adapter.go +++ b/relay/adaptor/aws/adaptor.go @@ -4,14 +4,13 @@ import ( "io" "net/http" + "github.com/Laisky/errors/v2" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/gin-gonic/gin" - "github.com/pkg/errors" - "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/relay/adaptor" - "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" ) @@ -19,18 +18,52 @@ import ( var _ adaptor.Adaptor = new(Adaptor) type Adaptor struct { - meta *meta.Meta - awsClient *bedrockruntime.Client + awsAdapter utils.AwsAdapter + + Meta *meta.Meta + AwsClient *bedrockruntime.Client } func (a *Adaptor) Init(meta *meta.Meta) { - a.meta = meta - a.awsClient = bedrockruntime.New(bedrockruntime.Options{ + a.Meta = meta + a.AwsClient = bedrockruntime.New(bedrockruntime.Options{ Region: meta.Config.Region, Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), }) } +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + adaptor := GetAdaptor(request.Model) + if adaptor == nil { + return nil, errors.New("adaptor not found") + } + + a.awsAdapter = adaptor + return adaptor.ConvertRequest(c, relayMode, request) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if a.awsAdapter == nil { + return nil, utils.WrapErr(errors.New("awsAdapter is nil")) + } + return a.awsAdapter.DoResponse(c, a.AwsClient, meta) +} + +func (a *Adaptor) GetModelList() (models []string) { + for model := range adaptors { + models = append(models, model) + } + return +} + +func (a *Adaptor) GetChannelName() string { + return "aws" +} + func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return "", nil } @@ -39,17 +72,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") - } - - claudeReq := anthropic.ConvertRequest(*request) - c.Set(ctxkey.RequestModel, request.Model) - c.Set(ctxkey.ConvertedRequest, claudeReq) - return claudeReq, nil -} - func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") @@ -60,23 +82,3 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) 
 	return nil, nil
 }
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
-	if meta.IsStream {
-		err, usage = StreamHandler(c, a.awsClient)
-	} else {
-		err, usage = Handler(c, a.awsClient, meta.ActualModelName)
-	}
-	return
-}
-
-func (a *Adaptor) GetModelList() (models []string) {
-	for n := range awsModelIDMap {
-		models = append(models, n)
-	}
-	return
-}
-
-func (a *Adaptor) GetChannelName() string {
-	return "aws"
-}
diff --git a/relay/adaptor/aws/claude/adapter.go b/relay/adaptor/aws/claude/adapter.go
new file mode 100644
index 00000000..eb3c9fb8
--- /dev/null
+++ b/relay/adaptor/aws/claude/adapter.go
@@ -0,0 +1,37 @@
+package aws
+
+import (
+	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
+	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
+	"github.com/songquanpeng/one-api/common/ctxkey"
+	"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
+	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
+	"github.com/songquanpeng/one-api/relay/meta"
+	"github.com/songquanpeng/one-api/relay/model"
+)
+
+var _ utils.AwsAdapter = new(Adaptor)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+
+	claudeReq := anthropic.ConvertRequest(*request)
+	c.Set(ctxkey.RequestModel, request.Model)
+	c.Set(ctxkey.ConvertedRequest, claudeReq)
+	return claudeReq, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+	if meta.IsStream {
+		err, usage = StreamHandler(c, awsCli)
+	} else {
+		err, usage = Handler(c, awsCli, meta.ActualModelName)
+	}
+	return
+}
diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/claude/main.go
similarity index 86%
rename from relay/adaptor/aws/main.go
rename to relay/adaptor/aws/claude/main.go
index 2bb939bf..7142e46f 100644
--- a/relay/adaptor/aws/main.go
+++ b/relay/adaptor/aws/claude/main.go
@@ -19,21 +19,13 @@ import (
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
+	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
 )
 
-func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
-	return &relaymodel.ErrorWithStatusCode{
-		StatusCode: http.StatusInternalServerError,
-		Error: relaymodel.Error{
-			Message: fmt.Sprintf("%s", err.Error()),
-		},
-	}
-}
-
 // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
-var awsModelIDMap = map[string]string{
+var AwsModelIDMap = map[string]string{
 	"claude-instant-1.2": "anthropic.claude-instant-v1",
 	"claude-2.0":         "anthropic.claude-v2",
 	"claude-2.1":         "anthropic.claude-v2:1",
@@ -44,7 +36,7 @@ var awsModelIDMap = map[string]string{
 }
 
 func awsModelID(requestModel string) (string, error) {
-	if awsModelID, ok := awsModelIDMap[requestModel]; ok {
+	if awsModelID, ok := AwsModelIDMap[requestModel]; ok {
 		return awsModelID, nil
 	}
 
@@ -54,7 +46,7 @@ func awsModelID(requestModel string) (string, error) {
 func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
 	if err != nil {
-		return wrapErr(errors.Wrap(err, "awsModelID")), nil
+		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
 
 	awsReq := &bedrockruntime.InvokeModelInput{
@@ -65,30 +57,30 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*
 
 	claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
 	if !ok {
-		return wrapErr(errors.New("request not found")), nil
+		return utils.WrapErr(errors.New("request not found")), nil
 	}
 	claudeReq := claudeReq_.(*anthropic.Request)
 	awsClaudeReq := &Request{
 		AnthropicVersion: "bedrock-2023-05-31",
 	}
 	if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
-		return wrapErr(errors.Wrap(err, "copy request")), nil
+		return utils.WrapErr(errors.Wrap(err, "copy request")), nil
 	}
 
 	awsReq.Body, err = json.Marshal(awsClaudeReq)
 	if err != nil {
-		return wrapErr(errors.Wrap(err, "marshal request")), nil
+		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
 	}
 
 	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
 	if err != nil {
-		return wrapErr(errors.Wrap(err, "InvokeModel")), nil
+		return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
 	}
 
 	claudeResponse := new(anthropic.Response)
 	err = json.Unmarshal(awsResp.Body, claudeResponse)
 	if err != nil {
-		return wrapErr(errors.Wrap(err, "unmarshal response")), nil
+		return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
 	}
 
 	openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
@@ -108,7 +100,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 	createdTime := helper.GetTimestamp()
 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
 	if err != nil {
-		return wrapErr(errors.Wrap(err, "awsModelID")), nil
+		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
 
 	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
@@ -119,7 +111,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 
 	claudeReq_, ok := c.Get(ctxkey.ConvertedRequest)
 	if !ok {
-		return wrapErr(errors.New("request not found")), nil
+		return utils.WrapErr(errors.New("request not found")), nil
 	}
 
 	claudeReq := claudeReq_.(*anthropic.Request)
@@ -127,16 +119,16 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 		AnthropicVersion: "bedrock-2023-05-31",
 	}
 	if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
-		return wrapErr(errors.Wrap(err, "copy request")), nil
+		return utils.WrapErr(errors.Wrap(err, "copy request")), nil
 	}
 	awsReq.Body, err = json.Marshal(awsClaudeReq)
 	if err != nil {
-		return wrapErr(errors.Wrap(err, "marshal request")), nil
+		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
 	}
 
 	awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
 	if err != nil {
-		return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
+		return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
 	}
 	stream := awsResp.GetStream()
 	defer stream.Close()
diff --git a/relay/adaptor/aws/model.go b/relay/adaptor/aws/claude/model.go
similarity index 100%
rename from relay/adaptor/aws/model.go
rename to relay/adaptor/aws/claude/model.go
diff --git a/relay/adaptor/aws/llama3/adapter.go b/relay/adaptor/aws/llama3/adapter.go
new file mode 100644
index 00000000..83edbc9d
--- /dev/null
+++ b/relay/adaptor/aws/llama3/adapter.go
@@ -0,0 +1,37 @@
+package aws
+
+import (
+	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
+	"github.com/songquanpeng/one-api/common/ctxkey"
+
+	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var _ utils.AwsAdapter = new(Adaptor) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + llamaReq := ConvertRequest(*request) + c.Set(ctxkey.RequestModel, request.Model) + c.Set(ctxkey.ConvertedRequest, llamaReq) + return llamaReq, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, awsCli) + } else { + err, usage = Handler(c, awsCli, meta.ActualModelName) + } + return +} diff --git a/relay/adaptor/aws/llama3/main.go b/relay/adaptor/aws/llama3/main.go new file mode 100644 index 00000000..e5fcd89f --- /dev/null +++ b/relay/adaptor/aws/llama3/main.go @@ -0,0 +1,231 @@ +// Package aws provides the AWS adaptor for the relay service. +package aws + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "text/template" + + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/random" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +// Only support llama-3-8b and llama-3-70b instruction models +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html +var AwsModelIDMap = map[string]string{ + "llama3-8b-8192": "meta.llama3-8b-instruct-v1:0", + "llama3-70b-8192": "meta.llama3-70b-instruct-v1:0", +} + +func awsModelID(requestModel string) (string, error) { + if awsModelID, ok := AwsModelIDMap[requestModel]; ok { + return awsModelID, nil + } + + return "", errors.Errorf("model %s not found", requestModel) +} + +// promptTemplate with range +const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|> +` + +var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate)) + +func RenderPrompt(messages []relaymodel.Message) string { + var buf bytes.Buffer + err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages}) + if err != nil { + logger.SysError("error rendering prompt messages: " + err.Error()) + } + return buf.String() +} + +func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request { + llamaRequest := Request{ + MaxGenLen: textRequest.MaxTokens, + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + } + if llamaRequest.MaxGenLen == 0 { + llamaRequest.MaxGenLen = 2048 + } + prompt := RenderPrompt(textRequest.Messages) + llamaRequest.Prompt = prompt + return &llamaRequest +} + +func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + if err != nil { 
+		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
+	}
+
+	awsReq := &bedrockruntime.InvokeModelInput{
+		ModelId:     aws.String(awsModelId),
+		Accept:      aws.String("application/json"),
+		ContentType: aws.String("application/json"),
+	}
+
+	llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
+	if !ok {
+		return utils.WrapErr(errors.New("request not found")), nil
+	}
+
+	awsReq.Body, err = json.Marshal(llamaReq)
+	if err != nil {
+		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
+	}
+
+	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
+	if err != nil {
+		return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil
+	}
+
+	var llamaResponse Response
+	err = json.Unmarshal(awsResp.Body, &llamaResponse)
+	if err != nil {
+		return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
+	}
+
+	openaiResp := ResponseLlama2OpenAI(&llamaResponse)
+	openaiResp.Model = modelName
+	usage := relaymodel.Usage{
+		PromptTokens:     llamaResponse.PromptTokenCount,
+		CompletionTokens: llamaResponse.GenerationTokenCount,
+		TotalTokens:      llamaResponse.PromptTokenCount + llamaResponse.GenerationTokenCount,
+	}
+	openaiResp.Usage = usage
+
+	c.JSON(http.StatusOK, openaiResp)
+	return nil, &usage
+}
+
+func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse {
+	var responseText string
+	if len(llamaResponse.Generation) > 0 {
+		responseText = llamaResponse.Generation
+	}
+	choice := openai.TextResponseChoice{
+		Index: 0,
+		Message: relaymodel.Message{
+			Role:    "assistant",
+			Content: responseText,
+			Name:    nil,
+		},
+		FinishReason: llamaResponse.StopReason,
+	}
+	fullTextResponse := openai.TextResponse{
+		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
+		Object:  "chat.completion",
+		Created: helper.GetTimestamp(),
+		Choices: []openai.TextResponseChoice{choice},
+	}
+	return &fullTextResponse
+}
+
+func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
+	createdTime := helper.GetTimestamp()
+	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
+	if err != nil {
+		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
+	}
+
+	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
+		ModelId:     aws.String(awsModelId),
+		Accept:      aws.String("application/json"),
+		ContentType: aws.String("application/json"),
+	}
+
+	llamaReq, ok := c.Get(ctxkey.ConvertedRequest)
+	if !ok {
+		return utils.WrapErr(errors.New("request not found")), nil
+	}
+
+	awsReq.Body, err = json.Marshal(llamaReq)
+	if err != nil {
+		return utils.WrapErr(errors.Wrap(err, "marshal request")), nil
+	}
+
+	awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
+	if err != nil {
+		return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
+	}
+	stream := awsResp.GetStream()
+	defer stream.Close()
+
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	var usage relaymodel.Usage
+	c.Stream(func(w io.Writer) bool {
+		event, ok := <-stream.Events()
+		if !ok {
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+
+		switch v := event.(type) {
+		case *types.ResponseStreamMemberChunk:
+			var llamaResp StreamResponse
+			err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp)
+			if err != nil {
+				logger.SysError("error unmarshalling stream response: " + err.Error())
+				return false
+			}
+
+			if llamaResp.PromptTokenCount > 0 {
+				usage.PromptTokens = llamaResp.PromptTokenCount
+			}
+			if llamaResp.StopReason == "stop" {
+				usage.CompletionTokens = llamaResp.GenerationTokenCount
+				usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+			}
+
+			response := StreamResponseLlama2OpenAI(&llamaResp)
+			response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
+			response.Model = c.GetString(ctxkey.OriginalModel)
+			response.Created = createdTime
+			jsonStr, err := json.Marshal(response)
+			if err != nil {
+				logger.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
+			return true
+		case *types.UnknownUnionMember:
+			fmt.Println("unknown tag:", v.Tag)
+			return false
+		default:
+			fmt.Println("union is nil or unknown type")
+			return false
+		}
+	})
+
+	return nil, &usage
+}
+
+func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
+	var choice openai.ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = llamaResponse.Generation
+	choice.Delta.Role = "assistant"
+	finishReason := llamaResponse.StopReason
+	if finishReason != "null" {
+		choice.FinishReason = &finishReason
+	}
+	var openaiResponse openai.ChatCompletionsStreamResponse
+	openaiResponse.Object = "chat.completion.chunk"
+	openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
+	return &openaiResponse
+}
diff --git a/relay/adaptor/aws/llama3/main_test.go b/relay/adaptor/aws/llama3/main_test.go
new file mode 100644
index 00000000..d539eee8
--- /dev/null
+++ b/relay/adaptor/aws/llama3/main_test.go
@@ -0,0 +1,45 @@
+package aws_test
+
+import (
+	"testing"
+
+	aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
+	relaymodel "github.com/songquanpeng/one-api/relay/model"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRenderPrompt(t *testing.T) {
+	messages := []relaymodel.Message{
+		{
+			Role:    "user",
+			Content: "What's your name?",
+		},
+	}
+	prompt := aws.RenderPrompt(messages)
+	expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+`
+	assert.Equal(t, expected, prompt)
+
+	messages = []relaymodel.Message{
+		{
+			Role:    "system",
+			Content: "Your name is Kat. You are a detective.",
+		},
+		{
+			Role:    "user",
+			Content: "What's your name?",
+		},
+		{
+			Role:    "assistant",
+			Content: "Kat",
+		},
+		{
+			Role:    "user",
+			Content: "What's your job?",
+		},
+	}
+	prompt = aws.RenderPrompt(messages)
+	expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+`
+	assert.Equal(t, expected, prompt)
+}
diff --git a/relay/adaptor/aws/llama3/model.go b/relay/adaptor/aws/llama3/model.go
new file mode 100644
index 00000000..7b86c3b8
--- /dev/null
+++ b/relay/adaptor/aws/llama3/model.go
@@ -0,0 +1,29 @@
+package aws
+
+// Request is the request to AWS Llama3
+//
+// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
+type Request struct {
+	Prompt      string  `json:"prompt"`
+	MaxGenLen   int     `json:"max_gen_len,omitempty"`
+	Temperature float64 `json:"temperature,omitempty"`
+	TopP        float64 `json:"top_p,omitempty"`
+}
+
+// Response is the response from AWS Llama3
+//
+// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
+type Response struct {
+	Generation           string `json:"generation"`
+	PromptTokenCount     int    `json:"prompt_token_count"`
+	GenerationTokenCount int    `json:"generation_token_count"`
+	StopReason           string `json:"stop_reason"`
+}
+
+// {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None}
+type StreamResponse struct {
+	Generation           string `json:"generation"`
+	PromptTokenCount     int    `json:"prompt_token_count"`
+	GenerationTokenCount int    `json:"generation_token_count"`
+	StopReason           string `json:"stop_reason"`
+}
diff --git a/relay/adaptor/aws/registry.go b/relay/adaptor/aws/registry.go
new file mode 100644
index 00000000..5f655480
--- /dev/null
+++ b/relay/adaptor/aws/registry.go
@@ -0,0 +1,39 @@
+package aws
+
+import (
+	claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude"
+	llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3"
+	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
+)
+
+type AwsModelType int
+
+const (
+	AwsClaude AwsModelType = iota + 1
+	AwsLlama3
+)
+
+var (
+	adaptors = map[string]AwsModelType{}
+)
+
+func init() {
+	for model := range claude.AwsModelIDMap {
+		adaptors[model] = AwsClaude
+	}
+	for model := range llama3.AwsModelIDMap {
+		adaptors[model] = AwsLlama3
+	}
+}
+
+func GetAdaptor(model string) utils.AwsAdapter {
+	adaptorType := adaptors[model]
+	switch adaptorType {
+	case AwsClaude:
+		return &claude.Adaptor{}
+	case AwsLlama3:
+		return &llama3.Adaptor{}
+	default:
+		return nil
+	}
+}
diff --git a/relay/adaptor/aws/utils/adaptor.go b/relay/adaptor/aws/utils/adaptor.go
new file mode 100644
index 00000000..4cb880f2
--- /dev/null
+++ b/relay/adaptor/aws/utils/adaptor.go
@@ -0,0 +1,51 @@
+package utils
+
+import (
+	"errors"
+	"io"
+	"net/http"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/credentials"
+	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
+	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/relay/meta"
+	"github.com/songquanpeng/one-api/relay/model"
+)
+
+type AwsAdapter interface {
+	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
+	DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
+}
+
+type Adaptor struct {
+	Meta      *meta.Meta
+	AwsClient *bedrockruntime.Client
+}
+
+func (a *Adaptor) Init(meta *meta.Meta) {
+	a.Meta = meta
+	a.AwsClient = bedrockruntime.New(bedrockruntime.Options{
+		Region:      meta.Config.Region,
+		Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")),
+	})
+}
+
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
+	return "", nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+	return nil
+}
+
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+	return nil, nil
+}
diff --git a/relay/adaptor/aws/utils/utils.go b/relay/adaptor/aws/utils/utils.go
new file mode 100644
index 00000000..669dc628
--- /dev/null
+++ b/relay/adaptor/aws/utils/utils.go
@@ -0,0 +1,16 @@
+package utils
+
+import (
+	"net/http"
+
+	relaymodel "github.com/songquanpeng/one-api/relay/model"
+)
+
+func WrapErr(err error) *relaymodel.ErrorWithStatusCode {
+	return &relaymodel.ErrorWithStatusCode{
+		StatusCode: http.StatusInternalServerError,
+		Error: relaymodel.Error{
+			Message: err.Error(),
+		},
+	}
+}
diff --git a/relay/adaptor/cloudflare/adaptor.go b/relay/adaptor/cloudflare/adaptor.go
index 6ff6b0d3..be2fb4ab 100644
--- a/relay/adaptor/cloudflare/adaptor.go
+++ b/relay/adaptor/cloudflare/adaptor.go
@@ -10,6 +10,7 @@ import (
 	"github.com/songquanpeng/one-api/relay/adaptor"
 	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/relaymode"
 )
 
 type Adaptor struct {
@@ -28,7 +29,14 @@ func (a *Adaptor) Init(meta *meta.Meta) {
 }
 
 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
-	return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
+	switch meta.Mode {
+	case relaymode.ChatCompletions:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", meta.BaseURL, meta.Config.UserID), nil
+	case relaymode.Embeddings:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", meta.BaseURL, meta.Config.UserID), nil
+	default:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil
+	}
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
@@ -41,7 +49,14 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return ConvertRequest(*request), nil
+	switch relayMode {
+	case relaymode.Completions:
+		return ConvertCompletionsRequest(*request), nil
+	case relaymode.ChatCompletions, relaymode.Embeddings:
+		return request, nil
+	default:
+		return nil, errors.New("not implemented")
+	}
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
diff --git a/relay/adaptor/cloudflare/main.go b/relay/adaptor/cloudflare/main.go
index c76520a2..980a2891 100644
--- a/relay/adaptor/cloudflare/main.go
+++ b/relay/adaptor/cloudflare/main.go
@@ -3,11 +3,13 @@ package cloudflare
 import (
 	"bufio"
 	"encoding/json"
-	"github.com/songquanpeng/one-api/common/render"
 	"io"
 	"net/http"
 	"strings"
 
+	"github.com/songquanpeng/one-api/common/ctxkey"
+	"github.com/songquanpeng/one-api/common/render"
+
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/helper"
@@ -16,57 +18,23 @@ import (
 	"github.com/songquanpeng/one-api/relay/model"
 )
 
-func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
-	var promptBuilder strings.Builder
-	for _, message := range textRequest.Messages {
-		promptBuilder.WriteString(message.StringContent())
-		promptBuilder.WriteString("\n") // add a newline to separate messages
-	}
-
+func ConvertCompletionsRequest(textRequest model.GeneralOpenAIRequest) *Request {
+	p, _ := textRequest.Prompt.(string)
 	return &Request{
+		Prompt:      p,
 		MaxTokens:   textRequest.MaxTokens,
-		Prompt:      promptBuilder.String(),
 		Stream:      textRequest.Stream,
 		Temperature: textRequest.Temperature,
 	}
 }
 
-func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
-	choice := openai.TextResponseChoice{
-		Index: 0,
-		Message: model.Message{
-			Role:    "assistant",
-			Content: cloudflareResponse.Result.Response,
-		},
-		FinishReason: "stop",
-	}
-	fullTextResponse := openai.TextResponse{
-		Object:  "chat.completion",
-		Created: helper.GetTimestamp(),
-		Choices: []openai.TextResponseChoice{choice},
-	}
-	return &fullTextResponse
-}
-
-func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse {
-	var choice openai.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = cloudflareResponse.Response
-	choice.Delta.Role = "assistant"
-	openaiResponse := openai.ChatCompletionsStreamResponse{
-		Object:  "chat.completion.chunk",
-		Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
-		Created: helper.GetTimestamp(),
-	}
-	return &openaiResponse
-}
-
 func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(bufio.ScanLines)
 
 	common.SetEventStreamHeaders(c)
 	id := helper.GetResponseID(c)
-	responseModel := c.GetString("original_model")
+	responseModel := c.GetString(ctxkey.OriginalModel)
 	var responseText string
 
 	for scanner.Scan() {
@@ -77,22 +45,22 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN
 		data = strings.TrimPrefix(data, "data: ")
 		data = strings.TrimSuffix(data, "\r")
 
-		var cloudflareResponse StreamResponse
-		err := json.Unmarshal([]byte(data), &cloudflareResponse)
+		if data == "[DONE]" {
+			break
+		}
+
+		var response openai.ChatCompletionsStreamResponse
+		err := json.Unmarshal([]byte(data), &response)
 		if err != nil {
 			logger.SysError("error unmarshalling stream response: " + err.Error())
 			continue
 		}
-
-		response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
-		if response == nil {
-			continue
+		for _, v := range response.Choices {
+			v.Delta.Role = "assistant"
+			responseText += v.Delta.StringContent()
 		}
-
-		responseText += cloudflareResponse.Response
 		response.Id = id
-		response.Model = responseModel
-
+		response.Model = modelName
 		err = render.ObjectData(c, response)
 		if err != nil {
 			logger.SysError(err.Error())
@@ -123,22 +91,25 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	var cloudflareResponse Response
-	err = json.Unmarshal(responseBody, &cloudflareResponse)
+	var response openai.TextResponse
+	err = json.Unmarshal(responseBody, &response)
 	if err != nil {
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse)
-	fullTextResponse.Model = modelName
-	usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens)
-	fullTextResponse.Usage = *usage
-	fullTextResponse.Id = helper.GetResponseID(c)
-	jsonResponse, err := json.Marshal(fullTextResponse)
+	response.Model = modelName
+	var responseText string
+	for _, v := range response.Choices {
+		responseText += v.Message.Content.(string)
+	}
+	usage := openai.ResponseText2Usage(responseText, modelName, promptTokens)
+	response.Usage = *usage
+	response.Id = helper.GetResponseID(c)
+	jsonResponse, err := json.Marshal(response)
 	if err != nil {
 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(jsonResponse)
+	_, _ = c.Writer.Write(jsonResponse)
 	return nil, usage
 }
diff --git a/relay/adaptor/cloudflare/model.go b/relay/adaptor/cloudflare/model.go
index 0664ecd1..0d3bafe0 100644
--- a/relay/adaptor/cloudflare/model.go
+++ b/relay/adaptor/cloudflare/model.go
@@ -1,25 +1,13 @@
 package cloudflare
 
+import "github.com/songquanpeng/one-api/relay/model"
+
 type Request struct {
-	Lora        string  `json:"lora,omitempty"`
-	MaxTokens   int     `json:"max_tokens,omitempty"`
-	Prompt      string  `json:"prompt,omitempty"`
-	Raw         bool    `json:"raw,omitempty"`
-	Stream      bool    `json:"stream,omitempty"`
-	Temperature float64 `json:"temperature,omitempty"`
-}
-
-type Result struct {
-	Response string `json:"response"`
-}
-
-type Response struct {
-	Result   Result   `json:"result"`
-	Success  bool     `json:"success"`
-	Errors   []string `json:"errors"`
-	Messages []string `json:"messages"`
-}
-
-type StreamResponse struct {
-	Response string `json:"response"`
+	Messages    []model.Message `json:"messages,omitempty"`
+	Lora        string          `json:"lora,omitempty"`
+	MaxTokens   int             `json:"max_tokens,omitempty"`
+	Prompt      string          `json:"prompt,omitempty"`
+	Raw         bool            `json:"raw,omitempty"`
+	Stream      bool            `json:"stream,omitempty"`
+	Temperature float64         `json:"temperature,omitempty"`
 }
diff --git a/relay/adaptor/novita/constants.go b/relay/adaptor/novita/constants.go
new file mode 100644
index 00000000..c6618308
--- /dev/null
+++ b/relay/adaptor/novita/constants.go
@@ -0,0 +1,19 @@
+package novita
+
+// https://novita.ai/llm-api
+
+var ModelList = []string{
+	"meta-llama/llama-3-8b-instruct",
+	"meta-llama/llama-3-70b-instruct",
+	"nousresearch/hermes-2-pro-llama-3-8b",
+	"nousresearch/nous-hermes-llama2-13b",
+	"mistralai/mistral-7b-instruct",
+	"cognitivecomputations/dolphin-mixtral-8x22b",
+	"sao10k/l3-70b-euryale-v2.1",
+	"sophosympatheia/midnight-rose-70b",
+	"gryphe/mythomax-l2-13b",
+	"Nous-Hermes-2-Mixtral-8x7B-DPO",
+	"lzlv_70b",
+	"teknium/openhermes-2.5-mistral-7b",
+	"microsoft/wizardlm-2-8x22b",
+}
diff --git a/relay/adaptor/novita/main.go b/relay/adaptor/novita/main.go
new file mode 100644
index 00000000..80efa412
--- /dev/null
+++ b/relay/adaptor/novita/main.go
@@ -0,0 +1,15 @@
+package novita
+
+import (
+	"fmt"
+
+	"github.com/songquanpeng/one-api/relay/meta"
+	"github.com/songquanpeng/one-api/relay/relaymode"
+)
+
+func GetRequestURL(meta *meta.Meta) (string, error) {
+	if meta.Mode == relaymode.ChatCompletions {
+		return fmt.Sprintf("%s/chat/completions", meta.BaseURL), nil
+	}
+	return "", fmt.Errorf("unsupported relay mode %d for novita", meta.Mode)
+}
diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go
index caef6976..f507630b 100644
--- a/relay/adaptor/openai/adaptor.go
+++ b/relay/adaptor/openai/adaptor.go
@@ -2,18 +2,20 @@ package openai
 
 import (
 	"fmt"
+	"io"
+	"net/http"
"strings" + "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/doubao" "github.com/songquanpeng/one-api/relay/adaptor/minimax" + "github.com/songquanpeng/one-api/relay/adaptor/novita" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "io" - "net/http" - "strings" ) type Adaptor struct { @@ -48,6 +50,8 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return minimax.GetRequestURL(meta) case channeltype.Doubao: return doubao.GetRequestURL(meta) + case channeltype.Novita: + return novita.GetRequestURL(meta) default: return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil } diff --git a/relay/adaptor/openai/compatible.go b/relay/adaptor/openai/compatible.go index 5d5b4008..3445249c 100644 --- a/relay/adaptor/openai/compatible.go +++ b/relay/adaptor/openai/compatible.go @@ -10,6 +10,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/minimax" "github.com/songquanpeng/one-api/relay/adaptor/mistral" "github.com/songquanpeng/one-api/relay/adaptor/moonshot" + "github.com/songquanpeng/one-api/relay/adaptor/novita" "github.com/songquanpeng/one-api/relay/adaptor/stepfun" "github.com/songquanpeng/one-api/relay/adaptor/togetherai" "github.com/songquanpeng/one-api/relay/channeltype" @@ -28,6 +29,7 @@ var CompatibleChannels = []int{ channeltype.StepFun, channeltype.DeepSeek, channeltype.TogetherAI, + channeltype.Novita, } func GetCompatibleChannelMeta(channelType int) (string, []string) { @@ -56,6 +58,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { return "together.ai", togetherai.ModelList case channeltype.Doubao: return "doubao", doubao.ModelList + case channeltype.Novita: + return "novita", novita.ModelList default: return "openai", ModelList } diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 5f1ae323..8a7d5743 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -2,6 +2,7 @@ package ratio import ( "encoding/json" + "fmt" "strings" "github.com/songquanpeng/one-api/common/logger" @@ -169,6 +170,9 @@ var ModelRatio = map[string]float64{ "step-1v-32k": 0.024 * RMB, "step-1-32k": 0.024 * RMB, "step-1-200k": 0.15 * RMB, + // aws llama3 https://aws.amazon.com/cn/bedrock/pricing/ + "llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens + "llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens // https://cohere.com/pricing "command": 0.5, "command-nightly": 0.5, @@ -185,7 +189,11 @@ var ModelRatio = map[string]float64{ "deepl-ja": 25.0 / 1000 * USD, } -var CompletionRatio = map[string]float64{} +var CompletionRatio = map[string]float64{ + // aws llama3 + "llama3-8b-8192(33)": 0.0006 / 0.0003, + "llama3-70b-8192(33)": 0.0035 / 0.00265, +} var DefaultModelRatio map[string]float64 var DefaultCompletionRatio map[string]float64 @@ -234,22 +242,28 @@ func UpdateModelRatioByJSONString(jsonStr string) error { return json.Unmarshal([]byte(jsonStr), &ModelRatio) } -func GetModelRatio(name string) float64 { +func GetModelRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } - ratio, ok := 
diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go
index 5f1ae323..8a7d5743 100644
--- a/relay/billing/ratio/model.go
+++ b/relay/billing/ratio/model.go
@@ -2,6 +2,7 @@ package ratio

 import (
 	"encoding/json"
+	"fmt"
 	"strings"

 	"github.com/songquanpeng/one-api/common/logger"
@@ -169,6 +170,9 @@ var ModelRatio = map[string]float64{
 	"step-1v-32k": 0.024 * RMB,
 	"step-1-32k":  0.024 * RMB,
 	"step-1-200k": 0.15 * RMB,
+	// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
+	"llama3-8b-8192(33)":  0.0003 / 0.002,  // $0.0003 / 1K tokens
+	"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
 	// https://cohere.com/pricing
 	"command":         0.5,
 	"command-nightly": 0.5,
@@ -185,7 +189,11 @@ var ModelRatio = map[string]float64{
 	"deepl-ja": 25.0 / 1000 * USD,
 }

-var CompletionRatio = map[string]float64{}
+var CompletionRatio = map[string]float64{
+	// aws llama3
+	"llama3-8b-8192(33)":  0.0006 / 0.0003,
+	"llama3-70b-8192(33)": 0.0035 / 0.00265,
+}

 var DefaultModelRatio map[string]float64
 var DefaultCompletionRatio map[string]float64
@@ -234,22 +242,28 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
 	return json.Unmarshal([]byte(jsonStr), &ModelRatio)
 }

-func GetModelRatio(name string) float64 {
+func GetModelRatio(name string, channelType int) float64 {
 	if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
 		name = strings.TrimSuffix(name, "-internet")
 	}
 	if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") {
 		name = strings.TrimSuffix(name, "-internet")
 	}
-	ratio, ok := ModelRatio[name]
-	if !ok {
-		ratio, ok = DefaultModelRatio[name]
+	model := fmt.Sprintf("%s(%d)", name, channelType)
+	if ratio, ok := ModelRatio[model]; ok {
+		return ratio
 	}
-	if !ok {
-		logger.SysError("model ratio not found: " + name)
-		return 30
+	if ratio, ok := DefaultModelRatio[model]; ok {
+		return ratio
 	}
-	return ratio
+	if ratio, ok := ModelRatio[name]; ok {
+		return ratio
+	}
+	if ratio, ok := DefaultModelRatio[name]; ok {
+		return ratio
+	}
+	logger.SysError("model ratio not found: " + name)
+	return 30
 }

 func CompletionRatio2JSONString() string {
@@ -265,10 +279,17 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
 	return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
 }

-// GetCompletionRatio returns the completion ratio of a model
-//
-// completion ratio is the ratio comparing to the ratio of prompt
-func GetCompletionRatio(name string) float64 {
+func GetCompletionRatio(name string, channelType int) float64 {
+	if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
+		name = strings.TrimSuffix(name, "-internet")
+	}
+	model := fmt.Sprintf("%s(%d)", name, channelType)
+	if ratio, ok := CompletionRatio[model]; ok {
+		return ratio
+	}
+	if ratio, ok := DefaultCompletionRatio[model]; ok {
+		return ratio
+	}
 	if ratio, ok := CompletionRatio[name]; ok {
 		return ratio
 	}
diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go
index d8885ae9..d3891c16 100644
--- a/relay/channeltype/define.go
+++ b/relay/channeltype/define.go
@@ -42,5 +42,6 @@ const (
 	DeepL
 	TogetherAI
 	Doubao
+	Novita
 	Dummy
 )
diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go
index 513d183b..5177333b 100644
--- a/relay/channeltype/url.go
+++ b/relay/channeltype/url.go
@@ -42,6 +42,7 @@ var ChannelBaseURLs = []string{
 	"https://api-free.deepl.com",        // 38
 	"https://api.together.xyz",          // 39
 	"https://ark.cn-beijing.volces.com", // 40
+	"https://api.novita.ai/v3/openai",   // 41
 }

 func init() {
diff --git a/relay/controller/audio.go b/relay/controller/audio.go
index dd44c09f..e1e88978 100644
--- a/relay/controller/audio.go
+++ b/relay/controller/audio.go
@@ -54,10 +54,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 		}
 	}

-	modelRatio := billingratio.GetModelRatio(audioModel)
+	modelRatio := billingratio.GetModelRatio(audioModel, channelType)
 	// groupRatio := billingratio.GetGroupRatio(group)
-	groupRatio := c.GetFloat64(ctxkey.ChannelRatio) // get minimal ratio from multiple groups
-
+	groupRatio := c.GetFloat64(ctxkey.ChannelRatio)
 	ratio := modelRatio * groupRatio
 	var quota int64
 	var preConsumedQuota int64
diff --git a/relay/controller/helper.go b/relay/controller/helper.go
index b2016030..cac11525 100644
--- a/relay/controller/helper.go
+++ b/relay/controller/helper.go
@@ -94,7 +94,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
 		logger.Error(ctx, "usage is nil, which is unexpected")
 		return
 	}
-	completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
+
+	completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
 	promptTokens := usage.PromptTokens
 	completionTokens := usage.CompletionTokens
 	quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
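Both lookups now try a channel-qualified key of the form `name(channelType)` before falling back to the bare model name, then to the defaults, and finally to the hard-coded 30. That is what lets the AWS-specific llama3 entries above (channel type 33) coexist with identically named models on other channels. A worked sketch combining the lookup with the quota formula from postConsumeQuota (group ratio and token counts are illustrative):

    package main

    import (
        "fmt"
        "math"

        billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
    )

    func main() {
        const awsChannelType = 33 // matches the "(33)" suffix in the ratio maps

        // Resolves via the channel-qualified keys added above.
        modelRatio := billingratio.GetModelRatio("llama3-8b-8192", awsChannelType)           // 0.0003 / 0.002 = 0.15
        completionRatio := billingratio.GetCompletionRatio("llama3-8b-8192", awsChannelType) // 0.0006 / 0.0003 = 2

        // Same formula as postConsumeQuota; groupRatio and token counts are made up.
        groupRatio := 1.0
        promptTokens, completionTokens := 1000.0, 500.0
        quota := int64(math.Ceil((promptTokens + completionTokens*completionRatio) * modelRatio * groupRatio))
        fmt.Println(quota) // (1000 + 500*2) * 0.15 = 300
    }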
"github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" @@ -21,6 +17,9 @@ import ( "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" ) func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { @@ -169,9 +168,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus requestBody = bytes.NewBuffer(jsonStr) } - modelRatio := billingratio.GetModelRatio(imageModel) + modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType) // groupRatio := billingratio.GetGroupRatio(meta.Group) - groupRatio := c.GetFloat64(ctxkey.ChannelRatio) // pre-selected cheapest channel ratio + groupRatio := c.GetFloat64(ctxkey.ChannelRatio) ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) diff --git a/relay/controller/text.go b/relay/controller/text.go index 0952b63f..bff2273b 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -38,9 +38,9 @@ func RelayTextHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode { textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model // get model ratio & group ratio - modelRatio := billingratio.GetModelRatio(textRequest.Model) + modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) // groupRatio := billingratio.GetGroupRatio(meta.Group) - groupRatio := meta.ChannelRatio + groupRatio := c.GetFloat64(ctxkey.ChannelRatio) ratio := modelRatio * groupRatio // pre-consume quota diff --git a/web/air/src/components/PersonalSetting.js b/web/air/src/components/PersonalSetting.js index 45a5b776..ef4acf14 100644 --- a/web/air/src/components/PersonalSetting.js +++ b/web/air/src/components/PersonalSetting.js @@ -47,7 +47,7 @@ const PersonalSetting = () => { const [countdown, setCountdown] = useState(30); const [affLink, setAffLink] = useState(''); const [systemToken, setSystemToken] = useState(''); - // const [models, setModels] = useState([]); + const [models, setModels] = useState([]); const [openTransfer, setOpenTransfer] = useState(false); const [transferAmount, setTransferAmount] = useState(0); @@ -72,7 +72,7 @@ const PersonalSetting = () => { console.log(userState); } ); - // loadModels().then(); + loadModels().then(); getAffLink().then(); setTransferAmount(getQuotaPerUnit()); }, []); @@ -127,16 +127,16 @@ const PersonalSetting = () => { } }; - // const loadModels = async () => { - // let res = await API.get(`/api/user/models`); - // const { success, message, data } = res.data; - // if (success) { - // setModels(data); - // console.log(data); - // } else { - // showError(message); - // } - // }; + const loadModels = async () => { + let res = await API.get(`/api/user/available_models`); + const { success, message, data } = res.data; + if (success) { + setModels(data); + console.log(data); + } else { + showError(message); + } + }; const handleAffLinkClick = async (e) => { e.target.select(); @@ -344,7 +344,7 @@ const PersonalSetting = () => { } > 调用信息 - {/* 可用模型 +

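With these changes the audio, image, and text controllers all read the group ratio from the request context under ctxkey.ChannelRatio, so whatever middleware selects the channel is the single writer of that value. A hedged sketch of that contract (the middleware below is illustrative, not the project's actual distributor, and the ctxkey import path is assumed from the project layout):

    package main

    import (
        "net/http"

        "github.com/gin-gonic/gin"
        "github.com/songquanpeng/one-api/common/ctxkey"
    )

    // setChannelRatio stands in for the real channel-selection middleware:
    // whatever picks the channel is expected to stash its group ratio here.
    func setChannelRatio(ratio float64) gin.HandlerFunc {
        return func(c *gin.Context) {
            c.Set(ctxkey.ChannelRatio, ratio)
            c.Next()
        }
    }

    func main() {
        r := gin.New()
        r.Use(setChannelRatio(1.5)) // illustrative ratio
        r.POST("/v1/chat/completions", func(c *gin.Context) {
            groupRatio := c.GetFloat64(ctxkey.ChannelRatio) // same read as audio.go/image.go/text.go
            c.JSON(http.StatusOK, gin.H{"group_ratio": groupRatio})
        })
        _ = r.Run(":3000")
    }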
diff --git a/web/air/src/components/PersonalSetting.js b/web/air/src/components/PersonalSetting.js
index 45a5b776..ef4acf14 100644
--- a/web/air/src/components/PersonalSetting.js
+++ b/web/air/src/components/PersonalSetting.js
@@ -47,7 +47,7 @@ const PersonalSetting = () => {
   const [countdown, setCountdown] = useState(30);
   const [affLink, setAffLink] = useState('');
   const [systemToken, setSystemToken] = useState('');
-  // const [models, setModels] = useState([]);
+  const [models, setModels] = useState([]);
   const [openTransfer, setOpenTransfer] = useState(false);
   const [transferAmount, setTransferAmount] = useState(0);

@@ -72,7 +72,7 @@ const PersonalSetting = () => {
       console.log(userState);
     }
     );
-    // loadModels().then();
+    loadModels().then();
     getAffLink().then();
     setTransferAmount(getQuotaPerUnit());
   }, []);

@@ -127,16 +127,16 @@ const PersonalSetting = () => {
     }
   };

-  // const loadModels = async () => {
-  //   let res = await API.get(`/api/user/models`);
-  //   const { success, message, data } = res.data;
-  //   if (success) {
-  //     setModels(data);
-  //     console.log(data);
-  //   } else {
-  //     showError(message);
-  //   }
-  // };
+  const loadModels = async () => {
+    let res = await API.get(`/api/user/available_models`);
+    const { success, message, data } = res.data;
+    if (success) {
+      setModels(data);
+      console.log(data);
+    } else {
+      showError(message);
+    }
+  };

   const handleAffLinkClick = async (e) => {
     e.target.select();
@@ -344,7 +344,7 @@ const PersonalSetting = () => {
             }
           >
             调用信息
-          {/* <p>可用模型</p>
+          <p>可用模型(可点击复制)</p>
           <Space wrap>
             {models.map((model) => (
@@ -355,7 +355,7 @@ const PersonalSetting = () => {
             ))}
           </Space>
-          */}
+          {/*
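The re-enabled loadModels now targets /api/user/available_models and destructures the project's usual {success, message, data} envelope. A sketch of the same call from Go, assuming data is the list of model names and that the deployment expects a bearer token (both are assumptions):

    package main

    import (
        "encoding/json"
        "fmt"
        "net/http"
    )

    // availableModelsResp mirrors the {success, message, data} shape the frontend destructures.
    type availableModelsResp struct {
        Success bool     `json:"success"`
        Message string   `json:"message"`
        Data    []string `json:"data"`
    }

    func main() {
        // Illustrative only: base URL and auth depend on the deployment.
        req, _ := http.NewRequest(http.MethodGet, "http://localhost:3000/api/user/available_models", nil)
        req.Header.Set("Authorization", "Bearer <user access token>") // assumed header
        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()

        var out availableModelsResp
        if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
            panic(err)
        }
        fmt.Println(out.Success, out.Data)
    }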
diff --git a/web/default/src/components/ChannelsTable.js b/web/default/src/components/ChannelsTable.js
--- a/web/default/src/components/ChannelsTable.js
+++ b/web/default/src/components/ChannelsTable.js
@@ ... @@ const ChannelsTable = () => {
   const loadChannels = async (startIdx) => {
     const res = await API.get(`/api/channel/?p=${startIdx}`);
     const { success, message, data } = res.data;
     if (success) {
-      if (startIdx === 0) {
-        setChannels(data);
-      } else {
-        let newChannels = [...channels];
-        newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
-        setChannels(newChannels);
-      }
+      let localChannels = data.map((channel) => {
+        if (channel.models === '') {
+          channel.models = [];
+          channel.test_model = "";
+        } else {
+          channel.models = channel.models.split(',');
+          if (channel.models.length > 0) {
+            channel.test_model = channel.models[0];
+          }
+          channel.model_options = channel.models.map((model) => {
+            return {
+              key: model,
+              text: model,
+              value: model,
+            }
+          })
+          console.log('channel', channel)
+        }
+        return channel;
+      });
+      if (startIdx === 0) {
+        setChannels(localChannels);
+      } else {
+        let newChannels = [...channels];
+        newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...localChannels);
+        setChannels(newChannels);
+      }
     } else {
       showError(message);
     }
@@ -225,19 +245,31 @@ const ChannelsTable = () => {
     setSearching(false);
   };

-  const testChannel = async (id, name, idx) => {
-    const res = await API.get(`/api/channel/test/${id}/`);
-    const { success, message, time } = res.data;
+  const switchTestModel = async (idx, model) => {
+    let newChannels = [...channels];
+    let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
+    newChannels[realIdx].test_model = model;
+    setChannels(newChannels);
+  };
+
+  const testChannel = async (id, name, idx, m) => {
+    const res = await API.get(`/api/channel/test/${id}?model=${m}`);
+    const { success, message, time, model } = res.data;
     if (success) {
       let newChannels = [...channels];
       let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
       newChannels[realIdx].response_time = time * 1000;
       newChannels[realIdx].test_time = Date.now() / 1000;
       setChannels(newChannels);
-      showInfo(`渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
+      showInfo(`渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed(2)} 秒。`);
     } else {
       showError(message);
     }
+    let newChannels = [...channels];
+    let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
+    newChannels[realIdx].response_time = time * 1000;
+    newChannels[realIdx].test_time = Date.now() / 1000;
+    setChannels(newChannels);
   };

   const testChannels = async (scope) => {
@@ -405,6 +437,7 @@ const ChannelsTable = () => {
             >
               优先级
             </Table.HeaderCell>
+            <Table.HeaderCell>测试模型</Table.HeaderCell>
             <Table.HeaderCell>操作</Table.HeaderCell>
@@ -459,13 +492,24 @@ const ChannelsTable = () => {
               basic
             />
+            <Table.Cell>
+              <Dropdown
+                placeholder='请选择测试模型'
+                selection
+                options={channel.model_options}
+                defaultValue={channel.test_model}
+                onChange={(event, data) => {
+                  switchTestModel(idx, data.value);
+                }}
+              />
+            </Table.Cell>
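The channel test endpoint now takes the model to probe as a query parameter and reports the tested model back alongside the latency, which is what the new per-row dropdown feeds into testChannel. A sketch of calling it directly (the bearer token header and the numeric channel id are assumptions; the response fields mirror the destructuring above):

    package main

    import (
        "encoding/json"
        "fmt"
        "net/http"
        "net/url"
    )

    // testChannelResp mirrors the { success, message, time, model } destructuring in testChannel.
    type testChannelResp struct {
        Success bool    `json:"success"`
        Message string  `json:"message"`
        Time    float64 `json:"time"`
        Model   string  `json:"model"`
    }

    func main() {
        channelID := 42 // illustrative channel id
        u := fmt.Sprintf("http://localhost:3000/api/channel/test/%d?model=%s",
            channelID, url.QueryEscape("meta-llama/llama-3-8b-instruct"))

        req, _ := http.NewRequest(http.MethodGet, u, nil)
        req.Header.Set("Authorization", "Bearer <admin access token>") // assumed; channel APIs require admin auth
        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()

        var out testChannelResp
        if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
            panic(err)
        }
        fmt.Printf("success=%v model=%s time=%.2fs\n", out.Success, out.Model, out.Time)
    }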