diff --git a/controller/misc.go b/controller/misc.go index 0a4f1d8..0259426 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -11,6 +11,22 @@ import ( "github.com/gin-gonic/gin" ) +func TestStatus(c *gin.Context) { + err := model.PingDB() + if err != nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "success": false, + "message": "数据库连接失败", + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "Server is running", + }) + return +} + func GetStatus(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/model/main.go b/model/main.go index ade3cc1..bb91255 100644 --- a/model/main.go +++ b/model/main.go @@ -5,9 +5,11 @@ import ( "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" + "log" "one-api/common" "os" "strings" + "sync" "time" ) @@ -148,3 +150,33 @@ func CloseDB() error { err = sqlDB.Close() return err } + +var ( + lastPingTime time.Time + pingMutex sync.Mutex +) + +func PingDB() error { + pingMutex.Lock() + defer pingMutex.Unlock() + + if time.Since(lastPingTime) < time.Second*10 { + return nil + } + + sqlDB, err := DB.DB() + if err != nil { + log.Printf("Error getting sql.DB from GORM: %v", err) + return err + } + + err = sqlDB.Ping() + if err != nil { + log.Printf("Error pinging DB: %v", err) + return err + } + + lastPingTime = time.Now() + common.SysLog("Database pinged successfully") + return nil +} diff --git a/router/api-router.go b/router/api-router.go index 1683a4f..592e8ed 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.Use(middleware.GlobalAPIRateLimit()) { apiRouter.GET("/status", controller.GetStatus) + apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus) apiRouter.GET("/notice", controller.GetNotice) apiRouter.GET("/about", controller.GetAbout) apiRouter.GET("/midjourney", controller.GetMidjourney)