From 274fcf3d76299e1e56a670a4c203e130d3561a0e Mon Sep 17 00:00:00 2001 From: igophper <34326532+igophper@users.noreply.github.com> Date: Wed, 3 Jul 2024 20:50:40 +0800 Subject: [PATCH 01/11] refactor: init db (#1590) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 江杭辉 --- main.go | 22 ++--- model/main.go | 219 ++++++++++++++++++++++++++++++++------------------ 2 files changed, 150 insertions(+), 91 deletions(-) diff --git a/main.go b/main.go index 4afbe5dd..67a3cd95 100644 --- a/main.go +++ b/main.go @@ -27,27 +27,19 @@ func main() { common.Init() logger.SetupLogger() logger.SysLogf("One API %s started", common.Version) - if os.Getenv("GIN_MODE") != "debug" { + + if os.Getenv("GIN_MODE") != gin.DebugMode { gin.SetMode(gin.ReleaseMode) } if config.DebugEnabled { logger.SysLog("running in debug mode") } - var err error + // Initialize SQL Database - model.DB, err = model.InitDB("SQL_DSN") - if err != nil { - logger.FatalLog("failed to initialize database: " + err.Error()) - } - if os.Getenv("LOG_SQL_DSN") != "" { - logger.SysLog("using secondary database for table logs") - model.LOG_DB, err = model.InitDB("LOG_SQL_DSN") - if err != nil { - logger.FatalLog("failed to initialize secondary database: " + err.Error()) - } - } else { - model.LOG_DB = model.DB - } + model.InitDB() + model.InitLogDB() + + var err error err = model.CreateRootAccountIfNeed() if err != nil { logger.FatalLog("database init error: " + err.Error()) diff --git a/model/main.go b/model/main.go index 4b5323c4..11752404 100644 --- a/model/main.go +++ b/model/main.go @@ -1,6 +1,7 @@ package model import ( + "database/sql" "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" @@ -60,90 +61,156 @@ func CreateRootAccountIfNeed() error { } func chooseDB(envName string) (*gorm.DB, error) { - if os.Getenv(envName) != "" { - dsn := os.Getenv(envName) - if strings.HasPrefix(dsn, "postgres://") { - // Use PostgreSQL - logger.SysLog("using PostgreSQL as database") - common.UsingPostgreSQL = true - return gorm.Open(postgres.New(postgres.Config{ - DSN: dsn, - PreferSimpleProtocol: true, // disables implicit prepared statement usage - }), &gorm.Config{ - PrepareStmt: true, // precompile SQL - }) - } + dsn := os.Getenv(envName) + + switch { + case strings.HasPrefix(dsn, "postgres://"): + // Use PostgreSQL + return openPostgreSQL(dsn) + case dsn != "": // Use MySQL - logger.SysLog("using MySQL as database") - common.UsingMySQL = true - return gorm.Open(mysql.Open(dsn), &gorm.Config{ - PrepareStmt: true, // precompile SQL - }) + return openMySQL(dsn) + default: + // Use SQLite + return openSQLite() } - // Use SQLite - logger.SysLog("SQL_DSN not set, using SQLite as database") - common.UsingSQLite = true - config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) - return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ +} + +func openPostgreSQL(dsn string) (*gorm.DB, error) { + logger.SysLog("using PostgreSQL as database") + common.UsingPostgreSQL = true + return gorm.Open(postgres.New(postgres.Config{ + DSN: dsn, + PreferSimpleProtocol: true, // disables implicit prepared statement usage + }), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } -func InitDB(envName string) (db *gorm.DB, err error) { - db, err = chooseDB(envName) - if err == nil { - if config.DebugSQLEnabled { - db = db.Debug() - } - sqlDB, err := db.DB() - if err != nil { - return nil, err - } - sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) +func openMySQL(dsn string) (*gorm.DB, error) { + logger.SysLog("using MySQL as database") + common.UsingMySQL = true + return gorm.Open(mysql.Open(dsn), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) +} - if !config.IsMasterNode { - return db, err - } - if common.UsingMySQL { - _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded - } - logger.SysLog("database migration started") - err = db.AutoMigrate(&Channel{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Token{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&User{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Option{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Redemption{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Ability{}) - if err != nil { - return nil, err - } - err = db.AutoMigrate(&Log{}) - if err != nil { - return nil, err - } - logger.SysLog("database migrated") - return db, err - } else { - logger.FatalLog(err) +func openSQLite() (*gorm.DB, error) { + logger.SysLog("SQL_DSN not set, using SQLite as database") + common.UsingSQLite = true + dsn := fmt.Sprintf("%s?_busy_timeout=%d", common.SQLitePath, common.SQLiteBusyTimeout) + return gorm.Open(sqlite.Open(dsn), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) +} + +func InitDB() { + var err error + DB, err = chooseDB("SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize database: " + err.Error()) + return } - return db, err + + sqlDB := setDBConns(DB) + + if !config.IsMasterNode { + return + } + + if common.UsingMySQL { + _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded + } + + logger.SysLog("database migration started") + if err = migrateDB(); err != nil { + logger.FatalLog("failed to migrate database: " + err.Error()) + return + } + logger.SysLog("database migrated") +} + +func migrateDB() error { + var err error + if err = DB.AutoMigrate(&Channel{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Token{}); err != nil { + return err + } + if err = DB.AutoMigrate(&User{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Option{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Redemption{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Ability{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Log{}); err != nil { + return err + } + if err = DB.AutoMigrate(&Channel{}); err != nil { + return err + } + return nil +} + +func InitLogDB() { + if os.Getenv("LOG_SQL_DSN") == "" { + LOG_DB = DB + return + } + + logger.SysLog("using secondary database for table logs") + var err error + LOG_DB, err = chooseDB("LOG_SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize secondary database: " + err.Error()) + return + } + + setDBConns(LOG_DB) + + if !config.IsMasterNode { + return + } + + logger.SysLog("secondary database migration started") + err = migrateLOGDB() + if err != nil { + logger.FatalLog("failed to migrate secondary database: " + err.Error()) + return + } + logger.SysLog("secondary database migrated") +} + +func migrateLOGDB() error { + var err error + if err = LOG_DB.AutoMigrate(&Log{}); err != nil { + return err + } + return nil +} + +func setDBConns(db *gorm.DB) *sql.DB { + if config.DebugSQLEnabled { + db = db.Debug() + } + + sqlDB, err := db.DB() + if err != nil { + logger.FatalLog("failed to connect database: " + err.Error()) + return nil + } + + sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) + return sqlDB } func closeDB(db *gorm.DB) error { From c4fe57c16512372b84f0765c78d3e0b2d1eef912 Mon Sep 17 00:00:00 2001 From: LinZeliang Date: Wed, 3 Jul 2024 20:53:29 +0800 Subject: [PATCH 02/11] feat: support one or more log file (#1400) Co-authored-by: Laisky.Cai --- common/config/config.go | 3 +++ common/logger/logger.go | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/common/config/config.go b/common/config/config.go index 4f1c25b6..3f321c87 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -145,6 +145,9 @@ var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN") var GeminiVersion = env.String("GEMINI_VERSION", "v1") + +var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) + var RelayProxy = env.String("RELAY_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) diff --git a/common/logger/logger.go b/common/logger/logger.go index f725c619..d1022932 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -27,7 +27,12 @@ var setupLogOnce sync.Once func SetupLogger() { setupLogOnce.Do(func() { if LogDir != "" { - logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + var logPath string + if config.OnlyOneLogFile { + logPath = filepath.Join(LogDir, "oneapi.log") + } else { + logPath = filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + } fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") From ec6ad248104045d7b67effc72867d9f4a31e55fe Mon Sep 17 00:00:00 2001 From: Leo Q Date: Wed, 3 Jul 2024 22:23:49 +0800 Subject: [PATCH 03/11] feat: support smtp without auth (#1101) --- common/message/email.go | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/common/message/email.go b/common/message/email.go index b06782db..187ac8c3 100644 --- a/common/message/email.go +++ b/common/message/email.go @@ -6,11 +6,16 @@ import ( "encoding/base64" "fmt" "github.com/songquanpeng/one-api/common/config" + "net" "net/smtp" "strings" "time" ) +func shouldAuth() bool { + return config.SMTPAccount != "" || config.SMTPToken != "" +} + func SendEmail(subject string, receiver string, content string) error { if receiver == "" { return fmt.Errorf("receiver is empty") @@ -41,16 +46,24 @@ func SendEmail(subject string, receiver string, content string) error { "Date: %s\r\n"+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) + auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) to := strings.Split(receiver, ";") - if config.SMTPPort == 465 { - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - ServerName: config.SMTPServer, + if config.SMTPPort == 465 || !shouldAuth() { + // need advanced client + var conn net.Conn + var err error + if config.SMTPPort == 465 { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: config.SMTPServer, + } + conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) + } else { + conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)) } - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) if err != nil { return err } @@ -59,8 +72,10 @@ func SendEmail(subject string, receiver string, content string) error { return err } defer client.Close() - if err = client.Auth(auth); err != nil { - return err + if shouldAuth() { + if err = client.Auth(auth); err != nil { + return err + } } if err = client.Mail(config.SMTPFrom); err != nil { return err From 273be557975b758c4e6ee36165daeab772895b58 Mon Sep 17 00:00:00 2001 From: Leo Q Date: Thu, 4 Jul 2024 08:35:41 +0800 Subject: [PATCH 04/11] feat(ui): show available models for air theme (#1595) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(ui): air 主题显示可用模型 * chore: 改为全角括号 --- web/air/src/components/PersonalSetting.js | 28 +++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/web/air/src/components/PersonalSetting.js b/web/air/src/components/PersonalSetting.js index 45a5b776..ef4acf14 100644 --- a/web/air/src/components/PersonalSetting.js +++ b/web/air/src/components/PersonalSetting.js @@ -47,7 +47,7 @@ const PersonalSetting = () => { const [countdown, setCountdown] = useState(30); const [affLink, setAffLink] = useState(''); const [systemToken, setSystemToken] = useState(''); - // const [models, setModels] = useState([]); + const [models, setModels] = useState([]); const [openTransfer, setOpenTransfer] = useState(false); const [transferAmount, setTransferAmount] = useState(0); @@ -72,7 +72,7 @@ const PersonalSetting = () => { console.log(userState); } ); - // loadModels().then(); + loadModels().then(); getAffLink().then(); setTransferAmount(getQuotaPerUnit()); }, []); @@ -127,16 +127,16 @@ const PersonalSetting = () => { } }; - // const loadModels = async () => { - // let res = await API.get(`/api/user/models`); - // const { success, message, data } = res.data; - // if (success) { - // setModels(data); - // console.log(data); - // } else { - // showError(message); - // } - // }; + const loadModels = async () => { + let res = await API.get(`/api/user/available_models`); + const { success, message, data } = res.data; + if (success) { + setModels(data); + console.log(data); + } else { + showError(message); + } + }; const handleAffLinkClick = async (e) => { e.target.select(); @@ -344,7 +344,7 @@ const PersonalSetting = () => { } > 调用信息 - {/* 可用模型 +

可用模型(可点击复制)

{models.map((model) => ( @@ -355,7 +355,7 @@ const PersonalSetting = () => { ))} -
*/} + {/* Date: Fri, 5 Jul 2024 18:05:16 +0800 Subject: [PATCH 05/11] feat: support test specific model (#1600) --- controller/channel-test.go | 36 ++++++----- web/default/src/components/ChannelsTable.js | 70 +++++++++++++++++---- 2 files changed, 77 insertions(+), 29 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index b8c41819..f8327284 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" @@ -27,15 +28,15 @@ import ( "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - - "github.com/gin-gonic/gin" ) -func buildTestRequest() *relaymodel.GeneralOpenAIRequest { +func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest { + if model == "" { + model = "gpt-3.5-turbo" + } testRequest := &relaymodel.GeneralOpenAIRequest{ MaxTokens: 2, - Stream: false, - Model: "gpt-3.5-turbo", + Model: model, } testMessage := relaymodel.Message{ Role: "user", @@ -45,7 +46,7 @@ func buildTestRequest() *relaymodel.GeneralOpenAIRequest { return testRequest } -func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) { +func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = &http.Request{ @@ -68,12 +69,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil } adaptor.Init(meta) - var modelName string - modelList := adaptor.GetModelList() + modelName := request.Model modelMap := channel.GetModelMapping() - if len(modelList) != 0 { - modelName = modelList[0] - } if modelName == "" || !strings.Contains(channel.Models, modelName) { modelNames := strings.Split(channel.Models, ",") if len(modelNames) > 0 { @@ -83,9 +80,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error modelName = modelMap[modelName] } } - request := buildTestRequest() + meta.OriginModelName, meta.ActualModelName = request.Model, modelName request.Model = modelName - meta.OriginModelName, meta.ActualModelName = modelName, modelName convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) if err != nil { return err, nil @@ -139,10 +135,15 @@ func TestChannel(c *gin.Context) { }) return } + model := c.Query("model") + testRequest := buildTestRequest(model) tik := time.Now() - err, _ = testChannel(channel) + err, _ = testChannel(channel, testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() + if err != nil { + milliseconds = 0 + } go channel.UpdateResponseTime(milliseconds) consumedTime := float64(milliseconds) / 1000.0 if err != nil { @@ -150,6 +151,7 @@ func TestChannel(c *gin.Context) { "success": false, "message": err.Error(), "time": consumedTime, + "model": model, }) return } @@ -157,6 +159,7 @@ func TestChannel(c *gin.Context) { "success": true, "message": "", "time": consumedTime, + "model": model, }) return } @@ -187,11 +190,12 @@ func testChannels(notify bool, scope string) error { for _, channel := range channels { isChannelEnabled := channel.Status == model.ChannelStatusEnabled tik := time.Now() - err, openaiErr := testChannel(channel) + testRequest := buildTestRequest("") + err, openaiErr := testChannel(channel, testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() if isChannelEnabled && milliseconds > disableThreshold { - err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) if config.AutomaticDisableChannelEnabled { monitor.DisableChannel(channel.Id, channel.Name, err.Error()) } else { diff --git a/web/default/src/components/ChannelsTable.js b/web/default/src/components/ChannelsTable.js index 1258ca5a..6025b7d9 100644 --- a/web/default/src/components/ChannelsTable.js +++ b/web/default/src/components/ChannelsTable.js @@ -1,5 +1,5 @@ import React, { useEffect, useState } from 'react'; -import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; +import { Button, Dropdown, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; import { Link } from 'react-router-dom'; import { API, @@ -70,13 +70,33 @@ const ChannelsTable = () => { const res = await API.get(`/api/channel/?p=${startIdx}`); const { success, message, data } = res.data; if (success) { - if (startIdx === 0) { - setChannels(data); - } else { - let newChannels = [...channels]; - newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); - setChannels(newChannels); - } + let localChannels = data.map((channel) => { + if (channel.models === '') { + channel.models = []; + channel.test_model = ""; + } else { + channel.models = channel.models.split(','); + if (channel.models.length > 0) { + channel.test_model = channel.models[0]; + } + channel.model_options = channel.models.map((model) => { + return { + key: model, + text: model, + value: model, + } + }) + console.log('channel', channel) + } + return channel; + }); + if (startIdx === 0) { + setChannels(localChannels); + } else { + let newChannels = [...channels]; + newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...localChannels); + setChannels(newChannels); + } } else { showError(message); } @@ -225,19 +245,31 @@ const ChannelsTable = () => { setSearching(false); }; - const testChannel = async (id, name, idx) => { - const res = await API.get(`/api/channel/test/${id}/`); - const { success, message, time } = res.data; + const switchTestModel = async (idx, model) => { + let newChannels = [...channels]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + newChannels[realIdx].test_model = model; + setChannels(newChannels); + }; + + const testChannel = async (id, name, idx, m) => { + const res = await API.get(`/api/channel/test/${id}?model=${m}`); + const { success, message, time, model } = res.data; if (success) { let newChannels = [...channels]; let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; newChannels[realIdx].response_time = time * 1000; newChannels[realIdx].test_time = Date.now() / 1000; setChannels(newChannels); - showInfo(`渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); + showInfo(`渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed(2)} 秒。`); } else { showError(message); } + let newChannels = [...channels]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + newChannels[realIdx].response_time = time * 1000; + newChannels[realIdx].test_time = Date.now() / 1000; + setChannels(newChannels); }; const testChannels = async (scope) => { @@ -405,6 +437,7 @@ const ChannelsTable = () => { > 优先级 + 测试模型 操作 @@ -459,13 +492,24 @@ const ChannelsTable = () => { basic /> + + { + switchTestModel(idx, data.value); + }} + /> +