mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-21 02:24:32 +08:00
merge v4.2.5
This commit is contained in:
@@ -34,6 +34,50 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AuthConfig 定义授权配置
|
||||
type AuthConfig struct {
|
||||
ExactPaths map[string]bool // 精确匹配的路径
|
||||
PrefixPaths map[string]bool // 前缀匹配的路径
|
||||
}
|
||||
|
||||
var authConfig = &AuthConfig{
|
||||
ExactPaths: map[string]bool{
|
||||
"/api/user/login": false,
|
||||
"/api/user/logout": false,
|
||||
"/api/user/resetPass": false,
|
||||
"/api/user/register": false,
|
||||
"/api/admin/login": false,
|
||||
"/api/admin/logout": false,
|
||||
"/api/admin/login/captcha": false,
|
||||
"/api/app/list": false,
|
||||
"/api/app/type/list": false,
|
||||
"/api/app/list/user": false,
|
||||
"/api/model/list": false,
|
||||
"/api/mj/imgWall": false,
|
||||
"/api/mj/notify": false,
|
||||
"/api/invite/hits": false,
|
||||
"/api/sd/imgWall": false,
|
||||
"/api/dall/imgWall": false,
|
||||
"/api/product/list": false,
|
||||
"/api/menu/list": false,
|
||||
"/api/markMap/client": false,
|
||||
"/api/payment/doPay": false,
|
||||
"/api/payment/payWays": false,
|
||||
"/api/download": false,
|
||||
"/api/dall/models": false,
|
||||
},
|
||||
PrefixPaths: map[string]bool{
|
||||
"/api/test/": false,
|
||||
"/api/payment/notify/": false,
|
||||
"/api/user/clogin": false,
|
||||
"/api/config/": false,
|
||||
"/api/function/": false,
|
||||
"/api/sms/": false,
|
||||
"/api/captcha/": false,
|
||||
"/static/": false,
|
||||
},
|
||||
}
|
||||
|
||||
type AppServer struct {
|
||||
Config *types.AppConfig
|
||||
Engine *gin.Engine
|
||||
@@ -61,13 +105,28 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
|
||||
}
|
||||
|
||||
func (s *AppServer) Run(db *gorm.DB) error {
|
||||
|
||||
// 重命名 config 表字段
|
||||
if db.Migrator().HasColumn(&model.Config{}, "config_json") {
|
||||
db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
|
||||
}
|
||||
if db.Migrator().HasColumn(&model.Config{}, "marker") {
|
||||
db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
|
||||
}
|
||||
if db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
|
||||
db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
|
||||
}
|
||||
if db.Migrator().HasIndex(&model.Config{}, "marker") {
|
||||
db.Migrator().DropIndex(&model.Config{}, "marker")
|
||||
}
|
||||
|
||||
// load system configs
|
||||
var sysConfig model.Config
|
||||
err := db.Where("marker", "system").First(&sysConfig).Error
|
||||
err := db.Where("name", "system").First(&sysConfig).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load system config: %v", err)
|
||||
}
|
||||
err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
|
||||
err = utils.JsonDecode(sysConfig.Value, &s.SysConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode system config: %v", err)
|
||||
}
|
||||
@@ -99,6 +158,7 @@ func (s *AppServer) Run(db *gorm.DB) error {
|
||||
&model.MidJourneyJob{},
|
||||
&model.UserLoginLog{},
|
||||
&model.DallJob{},
|
||||
&model.JimengJob{},
|
||||
)
|
||||
// 手动删除字段
|
||||
if db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
|
||||
@@ -216,6 +276,11 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
// 用户授权验证
|
||||
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !needLogin(c) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
var tokenString string
|
||||
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
|
||||
@@ -234,18 +299,13 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
if tokenString == "" {
|
||||
if needLogin(c) {
|
||||
resp.NotAuth(c, "You should put Authorization in request headers")
|
||||
c.Abort()
|
||||
return
|
||||
} else { // 直接放行
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
resp.NotAuth(c, "You should put Authorization in request headers")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
if isAdminApi {
|
||||
@@ -256,21 +316,21 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
|
||||
})
|
||||
|
||||
if err != nil && needLogin(c) {
|
||||
if err != nil {
|
||||
resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid && needLogin(c) {
|
||||
if !ok || !token.Valid {
|
||||
resp.NotAuth(c, "Token is invalid")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
|
||||
if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) {
|
||||
if expr > 0 && int64(expr) < time.Now().Unix() {
|
||||
resp.NotAuth(c, "Token is expired")
|
||||
c.Abort()
|
||||
return
|
||||
@@ -280,57 +340,48 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
if isAdminApi {
|
||||
key = fmt.Sprintf("admin/%v", claims["user_id"])
|
||||
}
|
||||
if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) {
|
||||
if _, err := client.Get(context.Background(), key).Result(); err != nil {
|
||||
resp.NotAuth(c, "Token is not found in redis")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set(types.LoginUserID, claims["user_id"])
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func needLogin(c *gin.Context) bool {
|
||||
if c.Request.URL.Path == "/api/user/login" ||
|
||||
c.Request.URL.Path == "/api/user/logout" ||
|
||||
c.Request.URL.Path == "/api/user/resetPass" ||
|
||||
c.Request.URL.Path == "/api/admin/login" ||
|
||||
c.Request.URL.Path == "/api/admin/logout" ||
|
||||
c.Request.URL.Path == "/api/admin/login/captcha" ||
|
||||
c.Request.URL.Path == "/api/user/register" ||
|
||||
c.Request.URL.Path == "/api/chat/history" ||
|
||||
c.Request.URL.Path == "/api/chat/detail" ||
|
||||
c.Request.URL.Path == "/api/chat/list" ||
|
||||
c.Request.URL.Path == "/api/app/list" ||
|
||||
c.Request.URL.Path == "/api/app/type/list" ||
|
||||
c.Request.URL.Path == "/api/app/list/user" ||
|
||||
c.Request.URL.Path == "/api/model/list" ||
|
||||
c.Request.URL.Path == "/api/mj/imgWall" ||
|
||||
c.Request.URL.Path == "/api/mj/notify" ||
|
||||
c.Request.URL.Path == "/api/invite/hits" ||
|
||||
c.Request.URL.Path == "/api/sd/imgWall" ||
|
||||
c.Request.URL.Path == "/api/dall/imgWall" ||
|
||||
c.Request.URL.Path == "/api/product/list" ||
|
||||
c.Request.URL.Path == "/api/menu/list" ||
|
||||
c.Request.URL.Path == "/api/markMap/client" ||
|
||||
c.Request.URL.Path == "/api/payment/doPay" ||
|
||||
c.Request.URL.Path == "/api/payment/payWays" ||
|
||||
c.Request.URL.Path == "/api/suno/detail" ||
|
||||
c.Request.URL.Path == "/api/suno/play" ||
|
||||
c.Request.URL.Path == "/api/download" ||
|
||||
c.Request.URL.Path == "/api/dall/models" ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/static/") {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// 如果不是 API 路径,不需要登录
|
||||
if !strings.HasPrefix(path, "/api") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查精确匹配的路径
|
||||
if skip, exists := authConfig.ExactPaths[path]; exists {
|
||||
return skip
|
||||
}
|
||||
|
||||
// 检查前缀匹配的路径
|
||||
for prefix, skip := range authConfig.PrefixPaths {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return skip
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// 跳过授权
|
||||
func (s *AppServer) SkipAuth(url string, prefix bool) {
|
||||
if prefix {
|
||||
authConfig.PrefixPaths[url] = false
|
||||
} else {
|
||||
authConfig.ExactPaths[url] = false
|
||||
}
|
||||
}
|
||||
|
||||
// 统一参数处理
|
||||
func parameterHandlerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
@@ -43,9 +43,10 @@ type SmtpConfig struct {
|
||||
}
|
||||
|
||||
type ApiConfig struct {
|
||||
ApiURL string
|
||||
AppId string
|
||||
Token string
|
||||
ApiURL string
|
||||
AppId string
|
||||
Token string
|
||||
JimengConfig JimengConfig // 即梦AI配置
|
||||
}
|
||||
|
||||
type AlipayConfig struct {
|
||||
@@ -170,7 +171,7 @@ type SystemConfig struct {
|
||||
|
||||
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
|
||||
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
|
||||
TranslateModelId int `json:"translate_model_id"` // 用来做提示词翻译的大模型 id
|
||||
AssistantModelId int `json:"assistant_model_id"` // 用来做提示词,翻译的AI模型 id
|
||||
MaxFileSize int `json:"max_file_size"` // 最大文件大小,单位:MB
|
||||
|
||||
}
|
||||
|
||||
18
api/core/types/jimeng.go
Normal file
18
api/core/types/jimeng.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package types
|
||||
|
||||
// JimengConfig 即梦AI配置
|
||||
type JimengConfig struct {
|
||||
AccessKey string `json:"access_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Power JimengPower `json:"power"`
|
||||
}
|
||||
|
||||
// JimengPower 即梦AI算力配置
|
||||
type JimengPower struct {
|
||||
TextToImage int `json:"text_to_image"`
|
||||
ImageToImage int `json:"image_to_image"`
|
||||
ImageEdit int `json:"image_edit"`
|
||||
ImageEffects int `json:"image_effects"`
|
||||
TextToVideo int `json:"text_to_video"`
|
||||
ImageToVideo int `json:"image_to_video"`
|
||||
}
|
||||
@@ -18,6 +18,7 @@ require (
|
||||
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480
|
||||
github.com/qiniu/go-sdk/v7 v7.17.1
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23
|
||||
go.uber.org/zap v1.23.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gorm.io/driver/mysql v1.4.7
|
||||
@@ -45,7 +46,7 @@ require (
|
||||
github.com/go-pay/util v0.0.2 // indirect
|
||||
github.com/go-pay/xlog v0.0.2 // indirect
|
||||
github.com/go-pay/xtime v0.0.2 // indirect
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/gorilla/css v1.0.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.13 // indirect
|
||||
github.com/tklauser/numcpus v0.7.0 // indirect
|
||||
@@ -78,7 +79,7 @@ require (
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/klauspost/compress v1.16.7 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
@@ -120,7 +121,7 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/fx v1.19.3
|
||||
go.uber.org/multierr v1.6.0 // indirect
|
||||
go.uber.org/multierr v1.7.0 // indirect
|
||||
golang.org/x/crypto v0.23.0
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
gorm.io/gorm v1.25.1
|
||||
|
||||
79
api/go.sum
79
api/go.sum
@@ -1,3 +1,5 @@
|
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/toml v1.1.0 h1:ksErzDEI1khOiGPgpwuI7x2ebx/uXQNw7xJpn9Eq1+I=
|
||||
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405 h1:cKNFQmeCQFN0WNfjScKoVrGi7vXxTVbkCvCqSrOf+P4=
|
||||
@@ -6,6 +8,7 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible h1:Sg/2xHwDrioHpxTN6WMiw
|
||||
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
|
||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
||||
@@ -13,11 +16,13 @@ github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZx
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
@@ -28,6 +33,8 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0
|
||||
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
@@ -84,11 +91,27 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
||||
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
|
||||
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
|
||||
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
|
||||
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
|
||||
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
|
||||
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
|
||||
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
|
||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-tika v0.3.1 h1:l+jr10hDhZjcgxFRfcQChRLo1bPXQeLFluMyvDhXTTA=
|
||||
@@ -115,8 +138,11 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
|
||||
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
|
||||
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
|
||||
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
|
||||
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
|
||||
github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
@@ -127,6 +153,7 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02
|
||||
github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
@@ -179,6 +206,7 @@ github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 h1:IFhPCcB0/H
|
||||
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480/go.mod h1:BijIqAP84FMYC4XbdJgjyMpiSjusU8x0Y0W9K2t0QtU=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/qiniu/dyn v1.3.0/go.mod h1:E8oERcm8TtwJiZvkQPbcAh0RL8jO1G0VXJMW3FAWdkk=
|
||||
github.com/qiniu/go-sdk/v7 v7.17.1 h1:UoQv7fBKtzAiD1qZPIvTy62Se48YLKxcCYP9nAwWMa0=
|
||||
github.com/qiniu/go-sdk/v7 v7.17.1/go.mod h1:nqoYCNo53ZlGA521RvRethvxUDvXKt4gtYXOwye868w=
|
||||
@@ -208,6 +236,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
@@ -230,6 +259,8 @@ github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVK
|
||||
github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
||||
github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ=
|
||||
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
|
||||
github.com/ysmood/fetchup v0.3.0 h1:UhYz9xnLEVn2ukSuK3KCgcznWpHMdrmbsPpllcylyu8=
|
||||
@@ -260,8 +291,8 @@ go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI=
|
||||
go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
|
||||
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
|
||||
go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec=
|
||||
go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak=
|
||||
go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY=
|
||||
go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
@@ -275,15 +306,23 @@ golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDf
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
@@ -294,12 +333,15 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -338,16 +380,39 @@ golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
|
||||
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
|
||||
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
|
||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
@@ -360,6 +425,10 @@ gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYs
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
@@ -369,4 +438,6 @@ gorm.io/driver/mysql v1.4.7/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8o
|
||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
||||
gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64=
|
||||
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
||||
@@ -66,15 +66,15 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
value := utils.JsonEncode(&data.Config)
|
||||
config := model.Config{Key: data.Key, Config: value}
|
||||
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
|
||||
config := model.Config{Name: data.Key, Value: value}
|
||||
res := h.DB.FirstOrCreate(&config, model.Config{Name: data.Key})
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if config.Id > 0 {
|
||||
config.Config = value
|
||||
config.Value = value
|
||||
res := h.DB.Updates(&config)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
@@ -83,16 +83,16 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
|
||||
// update config cache for AppServer
|
||||
var cfg model.Config
|
||||
h.DB.Where("marker", data.Key).First(&cfg)
|
||||
h.DB.Where("name", data.Key).First(&cfg)
|
||||
var err error
|
||||
if data.Key == "system" {
|
||||
err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
|
||||
err = utils.JsonDecode(cfg.Value, &h.App.SysConfig)
|
||||
}
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Failed to update config cache: "+err.Error())
|
||||
return
|
||||
}
|
||||
logger.Infof("Update AppServer's config successfully: %v", config.Config)
|
||||
logger.Infof("Update AppServer's config successfully: %v", config.Value)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, config)
|
||||
@@ -102,14 +102,14 @@ func (h *ConfigHandler) Update(c *gin.Context) {
|
||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
key := c.Query("key")
|
||||
var config model.Config
|
||||
res := h.DB.Where("marker", key).First(&config)
|
||||
res := h.DB.Where("name", key).First(&config)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var value map[string]interface{}
|
||||
err := utils.JsonDecode(config.Config, &value)
|
||||
err := utils.JsonDecode(config.Value, &value)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
||||
@@ -194,7 +194,6 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
case "sd":
|
||||
var job model.SdJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
@@ -210,7 +209,6 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
case "dall":
|
||||
var job model.DallJob
|
||||
if res := h.DB.Where("id", id).First(&job); res.Error != nil {
|
||||
@@ -226,7 +224,6 @@ func (h *ImageHandler) Remove(c *gin.Context) {
|
||||
remark = fmt.Sprintf("任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
imgURL = job.ImgURL
|
||||
break
|
||||
default:
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
|
||||
296
api/handler/admin/jimeng_handler.go
Normal file
296
api/handler/admin/jimeng_handler.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/service/oss"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AdminJimengHandler 管理后台即梦AI处理器
|
||||
type AdminJimengHandler struct {
|
||||
handler.BaseHandler
|
||||
jimengService *jimeng.Service
|
||||
userService *service.UserService
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
// NewAdminJimengHandler 创建管理后台即梦AI处理器
|
||||
func NewAdminJimengHandler(app *core.AppServer, db *gorm.DB, jimengService *jimeng.Service, userService *service.UserService, uploader *oss.UploaderManager) *AdminJimengHandler {
|
||||
return &AdminJimengHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
jimengService: jimengService,
|
||||
userService: userService,
|
||||
uploader: uploader,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册即梦AI管理后台路由
|
||||
func (h *AdminJimengHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/admin/jimeng/")
|
||||
rg.GET("/jobs", h.Jobs)
|
||||
rg.GET("/jobs/:id", h.JobDetail)
|
||||
rg.POST("/jobs/remove", h.BatchRemove)
|
||||
rg.GET("/stats", h.Stats)
|
||||
rg.GET("/config", h.GetConfig)
|
||||
rg.POST("/config/update", h.UpdateConfig)
|
||||
}
|
||||
|
||||
// Jobs 获取任务列表
|
||||
func (h *AdminJimengHandler) Jobs(c *gin.Context) {
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
taskType := h.GetTrim(c, "type")
|
||||
status := h.GetTrim(c, "status")
|
||||
|
||||
var tasks []model.JimengJob
|
||||
var total int64
|
||||
|
||||
session := h.DB.Model(&model.JimengJob{})
|
||||
|
||||
// 构建查询条件
|
||||
if userId > 0 {
|
||||
session = session.Where("user_id = ?", userId)
|
||||
}
|
||||
if taskType != "" {
|
||||
session = session.Where("type = ?", taskType)
|
||||
}
|
||||
if status != "" {
|
||||
session = session.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := session.Count(&total).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取任务数量失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = session.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&tasks).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取任务列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"jobs": tasks,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// JobDetail 获取任务详情
|
||||
func (h *AdminJimengHandler) JobDetail(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
jobId, err := strconv.ParseUint(idStr, 10, 32)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
var job model.JimengJob
|
||||
err = h.DB.Where("id = ?", jobId).First(&job).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// BatchRemove 批量删除任务
|
||||
func (h *AdminJimengHandler) BatchRemove(c *gin.Context) {
|
||||
var req struct {
|
||||
JobIds []uint `json:"job_ids" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
var deletedCount int64 = 0
|
||||
for _, jobId := range req.JobIds {
|
||||
var job model.JimengJob
|
||||
err := h.DB.Where("id = ?", jobId).First(&job).Error
|
||||
if err != nil {
|
||||
continue // 跳过不存在的
|
||||
}
|
||||
tx := h.DB.Begin()
|
||||
if job.Status != model.JMTaskStatusSuccess && job.Power > 0 {
|
||||
remark := fmt.Sprintf("任务未成功,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
err = h.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "jimeng",
|
||||
Remark: remark,
|
||||
})
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
continue
|
||||
}
|
||||
}
|
||||
err = tx.Where("id = ?", jobId).Delete(&model.JimengJob{}).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
continue
|
||||
}
|
||||
tx.Commit()
|
||||
deletedCount++
|
||||
if job.ImgURL != "" {
|
||||
err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
|
||||
if err != nil {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
}
|
||||
if job.VideoURL != "" {
|
||||
err = h.uploader.GetUploadHandler().Delete(job.VideoURL)
|
||||
if err != nil {
|
||||
logger.Error("remove video failed: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"message": "批量删除成功",
|
||||
"deleted_count": deletedCount,
|
||||
})
|
||||
}
|
||||
|
||||
// Stats 获取统计信息
|
||||
func (h *AdminJimengHandler) Stats(c *gin.Context) {
|
||||
type StatResult struct {
|
||||
Status model.JMTaskStatus `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
var stats []StatResult
|
||||
err := h.DB.Model(&model.JimengJob{}).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Find(&stats).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "获取统计信息失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 整理统计数据
|
||||
result := gin.H{
|
||||
"totalTasks": int64(0),
|
||||
"completedTasks": int64(0),
|
||||
"processingTasks": int64(0),
|
||||
"failedTasks": int64(0),
|
||||
"pendingTasks": int64(0),
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
result["totalTasks"] = result["totalTasks"].(int64) + stat.Count
|
||||
switch stat.Status {
|
||||
case model.JMTaskStatusInQueue:
|
||||
result["pendingTasks"] = stat.Count
|
||||
case model.JMTaskStatusSuccess:
|
||||
result["completedTasks"] = stat.Count
|
||||
case model.JMTaskStatusGenerating:
|
||||
result["processingTasks"] = stat.Count
|
||||
case model.JMTaskStatusFailed:
|
||||
result["failedTasks"] = stat.Count
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, result)
|
||||
}
|
||||
|
||||
// GetConfig 获取即梦AI配置
|
||||
func (h *AdminJimengHandler) GetConfig(c *gin.Context) {
|
||||
jimengConfig := h.jimengService.GetConfig()
|
||||
resp.SUCCESS(c, jimengConfig)
|
||||
}
|
||||
|
||||
// UpdateConfig 更新即梦AI配置
|
||||
func (h *AdminJimengHandler) UpdateConfig(c *gin.Context) {
|
||||
var req types.JimengConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if req.AccessKey == "" {
|
||||
resp.ERROR(c, "AccessKey不能为空")
|
||||
return
|
||||
}
|
||||
if req.SecretKey == "" {
|
||||
resp.ERROR(c, "SecretKey不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证算力配置
|
||||
if req.Power.TextToImage <= 0 {
|
||||
resp.ERROR(c, "文生图算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageToImage <= 0 {
|
||||
resp.ERROR(c, "图生图算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageEdit <= 0 {
|
||||
resp.ERROR(c, "图片编辑算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageEffects <= 0 {
|
||||
resp.ERROR(c, "图片特效算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.TextToVideo <= 0 {
|
||||
resp.ERROR(c, "文生视频算力必须大于0")
|
||||
return
|
||||
}
|
||||
if req.Power.ImageToVideo <= 0 {
|
||||
resp.ERROR(c, "图生视频算力必须大于0")
|
||||
return
|
||||
}
|
||||
|
||||
// 保存配置
|
||||
tx := h.DB.Begin()
|
||||
value := utils.JsonEncode(&req)
|
||||
config := model.Config{Name: "jimeng", Value: value}
|
||||
|
||||
err := tx.FirstOrCreate(&config, model.Config{Name: "jimeng"}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "保存配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if config.Id > 0 {
|
||||
config.Value = value
|
||||
err = tx.Updates(&config).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, "更新配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 更新服务中的客户端配置
|
||||
updateErr := h.jimengService.UpdateClientConfig(req.AccessKey, req.SecretKey)
|
||||
if updateErr != nil {
|
||||
resp.ERROR(c, updateErr.Error())
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
resp.SUCCESS(c, gin.H{"message": "配置更新成功"})
|
||||
}
|
||||
@@ -154,7 +154,6 @@ func (h *MediaHandler) Remove(c *gin.Context) {
|
||||
remark = fmt.Sprintf("SUNO 任务失败,退回算力。任务ID:%d,Err: %s", job.Id, job.ErrMsg)
|
||||
progress = job.Progress
|
||||
fileURL = job.AudioURL
|
||||
break
|
||||
case "luma":
|
||||
case "keling":
|
||||
var job model.VideoJob
|
||||
@@ -174,7 +173,6 @@ func (h *MediaHandler) Remove(c *gin.Context) {
|
||||
if fileURL == "" {
|
||||
fileURL = job.WaterURL
|
||||
}
|
||||
break
|
||||
default:
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
|
||||
@@ -95,6 +95,15 @@ func (h *ChatHandler) Chat(c *gin.Context) {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
// 这里做个全局的异常处理,防止整个请求异常,导致 SSE 连接断开
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Errorf("chat handler error: %v", err)
|
||||
pushMessage(c, ChatEventError, err)
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
|
||||
// 使用旧的聊天数据覆盖模型和角色ID
|
||||
var chat model.ChatItem
|
||||
h.DB.Where("chat_id", input.ChatId).First(&chat)
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
// List 获取会话列表
|
||||
func (h *ChatHandler) List(c *gin.Context) {
|
||||
logger.Info(h.GetLoginUserId(c))
|
||||
if !h.IsLogin(c) {
|
||||
resp.SUCCESS(c)
|
||||
return
|
||||
@@ -28,7 +29,7 @@ func (h *ChatHandler) List(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
var items = make([]vo.ChatItem, 0)
|
||||
var chats []model.ChatItem
|
||||
h.DB.Where("user_id", userId).Order("id DESC").Find(&chats)
|
||||
h.DB.Debug().Where("user_id", userId).Order("id DESC").Find(&chats)
|
||||
if len(chats) == 0 {
|
||||
resp.SUCCESS(c, items)
|
||||
return
|
||||
|
||||
@@ -31,14 +31,14 @@ func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.
|
||||
func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
key := c.Query("key")
|
||||
var config model.Config
|
||||
res := h.DB.Where("marker", key).First(&config)
|
||||
res := h.DB.Where("name", key).First(&config)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var value map[string]interface{}
|
||||
err := utils.JsonDecode(config.Config, &value)
|
||||
var value map[string]any
|
||||
err := utils.JsonDecode(config.Value, &value)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
|
||||
@@ -77,7 +77,7 @@ func (h *DallJobHandler) Image(c *gin.Context) {
|
||||
Quality: data.Quality,
|
||||
Size: data.Size,
|
||||
Style: data.Style,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
Power: chatModel.Power,
|
||||
}
|
||||
job := model.DallJob{
|
||||
|
||||
@@ -213,7 +213,7 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
|
||||
Prompt: prompt,
|
||||
ModelId: 0,
|
||||
ModelName: "dall-e-3",
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
N: 1,
|
||||
Quality: "standard",
|
||||
Size: "1024x1024",
|
||||
@@ -265,27 +265,27 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 从参数中获取搜索关键词
|
||||
keyword, ok := params["keyword"].(string)
|
||||
if !ok || keyword == "" {
|
||||
resp.ERROR(c, "搜索关键词不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 从参数中获取最大页数,默认为1页
|
||||
maxPages := 1
|
||||
if pages, ok := params["max_pages"].(float64); ok {
|
||||
maxPages = int(pages)
|
||||
}
|
||||
|
||||
|
||||
// 获取用户ID
|
||||
userID, ok := params["user_id"].(float64)
|
||||
if !ok {
|
||||
resp.ERROR(c, "用户ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 查询用户信息
|
||||
var user model.User
|
||||
res := h.DB.Where("id = ?", int(userID)).First(&user)
|
||||
@@ -293,21 +293,21 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) {
|
||||
resp.ERROR(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 检查用户算力是否足够
|
||||
searchPower := 1 // 每次搜索消耗1点算力
|
||||
if user.Power < searchPower {
|
||||
resp.ERROR(c, "算力不足,无法执行网络搜索")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 执行网络搜索
|
||||
searchResults, err := crawler.SearchWeb(keyword, maxPages)
|
||||
if err != nil {
|
||||
resp.ERROR(c, fmt.Sprintf("搜索失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 扣减用户算力
|
||||
err = h.userService.DecreasePower(user.Id, searchPower, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
@@ -318,7 +318,7 @@ func (h *FunctionHandler) WebSearch(c *gin.Context) {
|
||||
resp.ERROR(c, "扣减算力失败:"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 返回搜索结果
|
||||
resp.SUCCESS(c, searchResults)
|
||||
}
|
||||
|
||||
442
api/handler/jimeng_handler.go
Normal file
442
api/handler/jimeng_handler.go
Normal file
@@ -0,0 +1,442 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/service"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/store/model"
|
||||
"geekai/store/vo"
|
||||
"geekai/utils"
|
||||
"geekai/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// JimengHandler 即梦AI处理器
|
||||
type JimengHandler struct {
|
||||
BaseHandler
|
||||
jimengService *jimeng.Service
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewJimengHandler 创建即梦AI处理器
|
||||
func NewJimengHandler(app *core.AppServer, jimengService *jimeng.Service, db *gorm.DB, userService *service.UserService) *JimengHandler {
|
||||
return &JimengHandler{
|
||||
BaseHandler: BaseHandler{App: app, DB: db},
|
||||
jimengService: jimengService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由,新增统一任务接口
|
||||
func (h *JimengHandler) RegisterRoutes() {
|
||||
rg := h.App.Engine.Group("/api/jimeng")
|
||||
rg.POST("task", h.CreateTask) // 只保留统一任务接口
|
||||
rg.GET("power-config", h.GetPowerConfig) // 新增算力配置接口
|
||||
rg.POST("jobs", h.Jobs)
|
||||
rg.GET("remove", h.Remove)
|
||||
rg.GET("retry", h.Retry)
|
||||
}
|
||||
|
||||
// JimengTaskRequest 统一任务请求结构体
|
||||
// 支持所有生图和生成视频类型
|
||||
type JimengTaskRequest struct {
|
||||
TaskType string `json:"task_type" binding:"required"`
|
||||
Prompt string `json:"prompt"`
|
||||
ImageInput string `json:"image_input"`
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
Scale float64 `json:"scale"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Gpen float64 `json:"gpen"`
|
||||
Skin float64 `json:"skin"`
|
||||
SkinUnifi float64 `json:"skin_unifi"`
|
||||
GenMode string `json:"gen_mode"`
|
||||
Seed int64 `json:"seed"`
|
||||
UsePreLLM bool `json:"use_pre_llm"`
|
||||
TemplateId string `json:"template_id"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
}
|
||||
|
||||
// CreateTask 统一任务创建接口
|
||||
func (h *JimengHandler) CreateTask(c *gin.Context) {
|
||||
var req JimengTaskRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
// 新增:除图像特效外,其他任务类型必须有提示词
|
||||
if req.TaskType != "image_effects" && req.Prompt == "" {
|
||||
resp.ERROR(c, "提示词不能为空")
|
||||
return
|
||||
}
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
|
||||
var powerCost int
|
||||
var taskType model.JMTaskType
|
||||
var params map[string]any
|
||||
var reqKey string
|
||||
var modelName string
|
||||
|
||||
switch req.TaskType {
|
||||
case "text_to_image":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeTextToImage)
|
||||
taskType = model.JMTaskTypeTextToImage
|
||||
reqKey = jimeng.ReqKeyTextToImage
|
||||
modelName = "即梦文生图"
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 2.5
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"use_pre_llm": req.UsePreLLM,
|
||||
}
|
||||
case "image_to_image":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageToImage)
|
||||
taskType = model.JMTaskTypeImageToImage
|
||||
reqKey = jimeng.ReqKeyImageToImagePortrait
|
||||
modelName = "即梦图生图"
|
||||
if req.Gpen == 0 {
|
||||
req.Gpen = 0.4
|
||||
}
|
||||
if req.Skin == 0 {
|
||||
req.Skin = 0.3
|
||||
}
|
||||
if req.GenMode == "" {
|
||||
if req.Prompt != "" {
|
||||
req.GenMode = jimeng.GenModeCreative
|
||||
} else {
|
||||
req.GenMode = jimeng.GenModeReference
|
||||
}
|
||||
}
|
||||
params = map[string]any{
|
||||
"image_input": req.ImageInput,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
"gpen": req.Gpen,
|
||||
"skin": req.Skin,
|
||||
"skin_unifi": req.SkinUnifi,
|
||||
"gen_mode": req.GenMode,
|
||||
"seed": req.Seed,
|
||||
}
|
||||
case "image_edit":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEdit)
|
||||
taskType = model.JMTaskTypeImageEdit
|
||||
reqKey = jimeng.ReqKeyImageEdit
|
||||
modelName = "即梦图像编辑"
|
||||
if req.Scale == 0 {
|
||||
req.Scale = 0.5
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"scale": req.Scale,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
case "image_effects":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageEffects)
|
||||
taskType = model.JMTaskTypeImageEffects
|
||||
reqKey = jimeng.ReqKeyImageEffects
|
||||
modelName = "即梦图像特效"
|
||||
if req.Width == 0 {
|
||||
req.Width = 1328
|
||||
}
|
||||
if req.Height == 0 {
|
||||
req.Height = 1328
|
||||
}
|
||||
params = map[string]any{
|
||||
"image_input1": req.ImageInput,
|
||||
"template_id": req.TemplateId,
|
||||
"width": req.Width,
|
||||
"height": req.Height,
|
||||
}
|
||||
case "text_to_video":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeTextToVideo)
|
||||
taskType = model.JMTaskTypeTextToVideo
|
||||
reqKey = jimeng.ReqKeyTextToVideo
|
||||
modelName = "即梦文生视频"
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
if req.AspectRatio == "" {
|
||||
req.AspectRatio = jimeng.AspectRatio16_9
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
case "image_to_video":
|
||||
powerCost = h.getPowerFromConfig(model.JMTaskTypeImageToVideo)
|
||||
taskType = model.JMTaskTypeImageToVideo
|
||||
reqKey = jimeng.ReqKeyImageToVideo
|
||||
modelName = "即梦图生视频"
|
||||
if req.Seed == 0 {
|
||||
req.Seed = -1
|
||||
}
|
||||
params = map[string]any{
|
||||
"seed": req.Seed,
|
||||
"aspect_ratio": req.AspectRatio,
|
||||
}
|
||||
if len(req.ImageUrls) > 0 {
|
||||
params["image_urls"] = req.ImageUrls
|
||||
}
|
||||
if len(req.BinaryDataBase64) > 0 {
|
||||
params["binary_data_base64"] = req.BinaryDataBase64
|
||||
}
|
||||
default:
|
||||
resp.ERROR(c, "不支持的任务类型")
|
||||
return
|
||||
}
|
||||
|
||||
if user.Power < powerCost {
|
||||
resp.ERROR(c, fmt.Sprintf("算力不足,需要%d算力", powerCost))
|
||||
return
|
||||
}
|
||||
|
||||
taskReq := &jimeng.CreateTaskRequest{
|
||||
Type: taskType,
|
||||
Prompt: req.Prompt,
|
||||
Params: params,
|
||||
ReqKey: reqKey,
|
||||
Power: powerCost,
|
||||
}
|
||||
|
||||
job, err := h.jimengService.CreateTask(user.Id, taskReq)
|
||||
if err != nil {
|
||||
logger.Errorf("create jimeng task failed: %v", err)
|
||||
resp.ERROR(c, "创建任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.userService.DecreasePower(user.Id, powerCost, model.PowerLog{
|
||||
Type: types.PowerConsume,
|
||||
Model: "jimeng",
|
||||
Remark: fmt.Sprintf("%s,任务ID:%d", modelName, job.Id),
|
||||
})
|
||||
|
||||
resp.SUCCESS(c, job)
|
||||
}
|
||||
|
||||
// Jobs 获取任务列表
|
||||
func (h *JimengHandler) Jobs(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
|
||||
var req struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Filter string `json:"filter"`
|
||||
Ids []uint `json:"ids"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
var jobs []model.JimengJob
|
||||
var total int64
|
||||
query := h.DB.Model(&model.JimengJob{}).Where("user_id = ?", userId)
|
||||
|
||||
switch req.Filter {
|
||||
case "image":
|
||||
query = query.Where("type IN (?)", []model.JMTaskType{
|
||||
model.JMTaskTypeTextToImage,
|
||||
model.JMTaskTypeImageToImage,
|
||||
model.JMTaskTypeImageEdit,
|
||||
model.JMTaskTypeImageEffects,
|
||||
})
|
||||
case "video":
|
||||
query = query.Where("type IN (?)", []model.JMTaskType{
|
||||
model.JMTaskTypeTextToVideo,
|
||||
model.JMTaskTypeImageToVideo,
|
||||
})
|
||||
}
|
||||
|
||||
if len(req.Ids) > 0 {
|
||||
query = query.Where("id IN (?)", req.Ids)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
if err := query.Order("updated_at DESC").Offset(offset).Limit(req.PageSize).Find(&jobs).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 填充 VO
|
||||
var jobVos []vo.JimengJob
|
||||
for _, job := range jobs {
|
||||
var jobVo vo.JimengJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
jobVo.CreatedAt = job.CreatedAt.Unix()
|
||||
jobVos = append(jobVos, jobVo)
|
||||
}
|
||||
resp.SUCCESS(c, vo.NewPage(total, req.Page, req.PageSize, jobVos))
|
||||
}
|
||||
|
||||
// Remove 删除任务
|
||||
func (h *JimengHandler) Remove(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
|
||||
jobId := h.GetInt(c, "id", 0)
|
||||
if jobId == 0 {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取任务,判断状态
|
||||
job, err := h.jimengService.GetJob(uint(jobId))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
if job.UserId != user.Id {
|
||||
resp.ERROR(c, "无权限操作")
|
||||
return
|
||||
}
|
||||
if job.Status != model.JMTaskStatusFailed {
|
||||
resp.ERROR(c, "只有失败的任务才能删除")
|
||||
return
|
||||
}
|
||||
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Where("id = ? AND user_id = ?", jobId, user.Id).Delete(&model.JimengJob{}).Error; err != nil {
|
||||
logger.Errorf("delete jimeng job failed: %v", err)
|
||||
resp.ERROR(c, "删除任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 退回算力
|
||||
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "jimeng",
|
||||
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, "退回算力失败")
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
resp.SUCCESS(c, gin.H{})
|
||||
}
|
||||
|
||||
// Retry 重试任务
|
||||
func (h *JimengHandler) Retry(c *gin.Context) {
|
||||
userId := h.GetLoginUserId(c)
|
||||
|
||||
jobId := h.GetInt(c, "id", 0)
|
||||
if jobId == 0 {
|
||||
resp.ERROR(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查任务是否存在且属于当前用户
|
||||
job, err := h.jimengService.GetJob(uint(jobId))
|
||||
if err != nil {
|
||||
resp.ERROR(c, "任务不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if job.UserId != userId {
|
||||
resp.ERROR(c, "无权限操作")
|
||||
return
|
||||
}
|
||||
|
||||
// 只有失败的任务才能重试
|
||||
if job.Status != model.JMTaskStatusFailed {
|
||||
resp.ERROR(c, "只有失败的任务才能重试")
|
||||
return
|
||||
}
|
||||
|
||||
// 重置任务状态
|
||||
if err := h.jimengService.UpdateJobStatus(uint(jobId), model.JMTaskStatusInQueue, ""); err != nil {
|
||||
logger.Errorf("reset job status failed: %v", err)
|
||||
resp.ERROR(c, "重置任务状态失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 重新推送到队列
|
||||
if err := h.jimengService.PushTaskToQueue(uint(jobId)); err != nil {
|
||||
logger.Errorf("push retry task to queue failed: %v", err)
|
||||
resp.ERROR(c, "推送重试任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, gin.H{"message": "重试任务已提交"})
|
||||
}
|
||||
|
||||
// getPowerFromConfig 从配置中获取指定类型的算力消耗
|
||||
func (h *JimengHandler) getPowerFromConfig(taskType model.JMTaskType) int {
|
||||
config := h.jimengService.GetConfig()
|
||||
|
||||
switch taskType {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
return config.Power.TextToImage
|
||||
case model.JMTaskTypeImageToImage:
|
||||
return config.Power.ImageToImage
|
||||
case model.JMTaskTypeImageEdit:
|
||||
return config.Power.ImageEdit
|
||||
case model.JMTaskTypeImageEffects:
|
||||
return config.Power.ImageEffects
|
||||
case model.JMTaskTypeTextToVideo:
|
||||
return config.Power.TextToVideo
|
||||
case model.JMTaskTypeImageToVideo:
|
||||
return config.Power.ImageToVideo
|
||||
default:
|
||||
return 10
|
||||
}
|
||||
}
|
||||
|
||||
// GetPowerConfig 获取即梦各任务类型算力消耗配置
|
||||
func (h *JimengHandler) GetPowerConfig(c *gin.Context) {
|
||||
config := h.jimengService.GetConfig()
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"text_to_image": config.Power.TextToImage,
|
||||
"image_to_image": config.Power.ImageToImage,
|
||||
"image_edit": config.Power.ImageEdit,
|
||||
"image_effects": config.Power.ImageEffects,
|
||||
"text_to_video": config.Power.TextToVideo,
|
||||
"image_to_video": config.Power.ImageToVideo,
|
||||
})
|
||||
}
|
||||
@@ -160,7 +160,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
UserId: userId,
|
||||
ImgArr: data.ImgArr,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
}
|
||||
job := model.MidJourneyJob{
|
||||
Type: data.TaskType,
|
||||
|
||||
@@ -144,7 +144,15 @@ func (h *NetHandler) Download(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 使用http.Get下载文件
|
||||
r, err := http.Get(fileUrl)
|
||||
req, err := http.NewRequest("GET", fileUrl, nil)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 模拟浏览器 UA
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36")
|
||||
client := &http.Client{}
|
||||
r, err := client.Do(req)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -157,6 +165,5 @@ func (h *NetHandler) Download(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
// 将下载的文件内容写入响应
|
||||
_, _ = io.Copy(c.Writer, r.Body)
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func (h *PromptHandler) Lyric(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.LyricPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -79,7 +79,7 @@ func (h *PromptHandler) Image(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.ImagePromptOptimizeTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -108,7 +108,7 @@ func (h *PromptHandler) Video(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.TranslateModelId)
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(service.VideoPromptTemplate, data.Prompt), h.App.SysConfig.AssistantModelId)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -158,9 +158,9 @@ func (h *PromptHandler) MetaPrompt(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *PromptHandler) getPromptModel() string {
|
||||
if h.App.SysConfig.TranslateModelId > 0 {
|
||||
if h.App.SysConfig.AssistantModelId > 0 {
|
||||
var chatModel model.ChatModel
|
||||
h.DB.Where("id", h.App.SysConfig.TranslateModelId).First(&chatModel)
|
||||
h.DB.Where("id", h.App.SysConfig.AssistantModelId).First(&chatModel)
|
||||
return chatModel.Value
|
||||
}
|
||||
return "gpt-4o"
|
||||
|
||||
@@ -131,7 +131,7 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
HdSteps: data.HdSteps,
|
||||
},
|
||||
UserId: userId,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
}
|
||||
|
||||
job := model.SdJob{
|
||||
|
||||
@@ -85,7 +85,7 @@ func (h *VideoHandler) LumaCreate(c *gin.Context) {
|
||||
Type: types.VideoLuma,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
}
|
||||
// 插入数据库
|
||||
job := model.VideoJob{
|
||||
@@ -181,7 +181,7 @@ func (h *VideoHandler) KeLingCreate(c *gin.Context) {
|
||||
Type: types.VideoKeLing,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
TranslateModelId: h.App.SysConfig.TranslateModelId,
|
||||
TranslateModelId: h.App.SysConfig.AssistantModelId,
|
||||
Channel: data.Channel,
|
||||
}
|
||||
// 插入数据库
|
||||
|
||||
17
api/main.go
17
api/main.go
@@ -17,6 +17,7 @@ import (
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service"
|
||||
"geekai/service/dalle"
|
||||
"geekai/service/jimeng"
|
||||
"geekai/service/mj"
|
||||
"geekai/service/oss"
|
||||
"geekai/service/payment"
|
||||
@@ -140,6 +141,7 @@ func main() {
|
||||
fx.Provide(handler.NewProductHandler),
|
||||
fx.Provide(handler.NewConfigHandler),
|
||||
fx.Provide(handler.NewPowerLogHandler),
|
||||
fx.Provide(handler.NewJimengHandler),
|
||||
|
||||
fx.Provide(admin.NewConfigHandler),
|
||||
fx.Provide(admin.NewAdminHandler),
|
||||
@@ -153,6 +155,7 @@ func main() {
|
||||
fx.Provide(admin.NewOrderHandler),
|
||||
fx.Provide(admin.NewChatHandler),
|
||||
fx.Provide(admin.NewPowerLogHandler),
|
||||
fx.Provide(admin.NewAdminJimengHandler),
|
||||
|
||||
// 创建服务
|
||||
fx.Provide(sms.NewSendServiceManager),
|
||||
@@ -208,6 +211,12 @@ func main() {
|
||||
s.SyncTaskProgress()
|
||||
s.DownloadFiles()
|
||||
}),
|
||||
|
||||
// 即梦AI 服务
|
||||
fx.Provide(jimeng.NewService),
|
||||
fx.Invoke(func(service *jimeng.Service) {
|
||||
service.Start()
|
||||
}),
|
||||
fx.Provide(service.NewUserService),
|
||||
fx.Provide(payment.NewAlipayService),
|
||||
fx.Provide(payment.NewHuPiPay),
|
||||
@@ -501,6 +510,14 @@ func main() {
|
||||
group.GET("remove", h.Remove)
|
||||
group.GET("publish", h.Publish)
|
||||
}),
|
||||
|
||||
// 即梦AI 路由
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.JimengHandler) {
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.AdminJimengHandler) {
|
||||
h.RegisterRoutes()
|
||||
}),
|
||||
fx.Provide(admin.NewChatAppTypeHandler),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ChatAppTypeHandler) {
|
||||
group := s.Engine.Group("/api/admin/app/type")
|
||||
|
||||
@@ -49,7 +49,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (s *Service) PushTask(task types.DallTask) {
|
||||
logger.Infof("add a new DALL-E task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push dall-e task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
@@ -291,7 +293,7 @@ func (s *Service) DownloadImages() {
|
||||
|
||||
func (s *Service) downloadImage(jobId uint, orgURL string) (string, error) {
|
||||
// sava image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, ".png", false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
139
api/service/jimeng/client.go
Normal file
139
api/service/jimeng/client.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
)
|
||||
|
||||
// Client 即梦API客户端
|
||||
type Client struct {
|
||||
visual *visual.Visual
|
||||
}
|
||||
|
||||
// NewClient 创建即梦API客户端
|
||||
func NewClient(accessKey, secretKey string) *Client {
|
||||
// 使用官方SDK的visual实例
|
||||
visualInstance := visual.NewInstance()
|
||||
visualInstance.Client.SetAccessKey(accessKey)
|
||||
visualInstance.Client.SetSecretKey(secretKey)
|
||||
|
||||
// 添加即梦AI专有的API配置
|
||||
jimengApis := map[string]*base.ApiInfo{
|
||||
"CVSync2AsyncSubmitTask": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
Query: url.Values{
|
||||
"Action": []string{"CVSync2AsyncSubmitTask"},
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
"CVSync2AsyncGetResult": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
Query: url.Values{
|
||||
"Action": []string{"CVSync2AsyncGetResult"},
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
"CVProcess": {
|
||||
Method: http.MethodPost,
|
||||
Path: "/",
|
||||
Query: url.Values{
|
||||
"Action": []string{"CVProcess"},
|
||||
"Version": []string{"2022-08-31"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 将即梦API添加到现有的ApiInfoList中
|
||||
for name, info := range jimengApis {
|
||||
visualInstance.Client.ApiInfoList[name] = info
|
||||
}
|
||||
|
||||
return &Client{
|
||||
visual: visualInstance,
|
||||
}
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
func (c *Client) SubmitTask(req *SubmitTaskRequest) (*SubmitTaskResponse, error) {
|
||||
// 直接将请求转为map[string]interface{}
|
||||
reqBodyBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 直接使用序列化后的字节
|
||||
jsonBody := reqBodyBytes
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVSync2AsyncSubmitTask", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("submit task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng SubmitTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result SubmitTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// QueryTask 查询任务结果
|
||||
func (c *Client) QueryTask(req *QueryTaskRequest) (*QueryTaskResponse, error) {
|
||||
// 序列化请求
|
||||
jsonBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVSync2AsyncGetResult", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng QueryTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应
|
||||
var result QueryTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// SubmitSyncTask 提交同步任务(仅用于文生图)
|
||||
func (c *Client) SubmitSyncTask(req *SubmitTaskRequest) (*QueryTaskResponse, error) {
|
||||
// 序列化请求
|
||||
jsonBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
// 调用SDK的JSON方法
|
||||
respBody, statusCode, err := c.visual.Client.Json("CVProcess", nil, string(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("submit sync task failed (status: %d): %w", statusCode, err)
|
||||
}
|
||||
|
||||
logger.Infof("Jimeng SubmitSyncTask Response: %s", string(respBody))
|
||||
|
||||
// 解析响应,同步任务直接返回结果
|
||||
var result QueryTaskResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
600
api/service/jimeng/service.go
Normal file
600
api/service/jimeng/service.go
Normal file
@@ -0,0 +1,600 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
|
||||
"geekai/core/types"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
// Service 即梦服务(合并了消费者功能)
|
||||
type Service struct {
|
||||
db *gorm.DB
|
||||
redis *redis.Client
|
||||
taskQueue *store.RedisQueue
|
||||
client *Client
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
running bool
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
// NewService 创建即梦服务
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
|
||||
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
||||
// 从数据库加载配置
|
||||
var config model.Config
|
||||
db.Where("name = ?", "Jimeng").First(&config)
|
||||
var jimengConfig types.JimengConfig
|
||||
if config.Id > 0 {
|
||||
_ = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
}
|
||||
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Service{
|
||||
db: db,
|
||||
redis: redisCli,
|
||||
taskQueue: taskQueue,
|
||||
client: client,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
running: false,
|
||||
uploader: uploader,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动服务(包含消费者)
|
||||
func (s *Service) Start() {
|
||||
if s.running {
|
||||
return
|
||||
}
|
||||
logger.Info("Starting Jimeng service and task consumer...")
|
||||
s.running = true
|
||||
go s.consumeTasks()
|
||||
go s.pollTaskStatus()
|
||||
}
|
||||
|
||||
// Stop 停止服务
|
||||
func (s *Service) Stop() {
|
||||
if !s.running {
|
||||
return
|
||||
}
|
||||
logger.Info("Stopping Jimeng service and task consumer...")
|
||||
s.running = false
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
// consumeTasks 消费任务
|
||||
func (s *Service) consumeTasks() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
logger.Info("Jimeng task consumer stopped")
|
||||
return
|
||||
default:
|
||||
s.processNextTask()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processNextTask 处理下一个任务
|
||||
func (s *Service) processNextTask() {
|
||||
var jobId uint
|
||||
if err := s.taskQueue.LPop(&jobId); err != nil {
|
||||
// 队列为空,等待1秒后重试
|
||||
time.Sleep(time.Second)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("Processing Jimeng task: job_id=%d", jobId)
|
||||
|
||||
if err := s.ProcessTask(jobId); err != nil {
|
||||
logger.Errorf("process jimeng task failed: job_id=%d, error=%v", jobId, err)
|
||||
s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, err.Error())
|
||||
} else {
|
||||
logger.Infof("Jimeng task processed successfully: job_id=%d", jobId)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTask 创建任务
|
||||
func (s *Service) CreateTask(userId uint, req *CreateTaskRequest) (*model.JimengJob, error) {
|
||||
// 生成任务ID
|
||||
taskId := utils.RandString(20)
|
||||
|
||||
// 序列化任务参数
|
||||
paramsJson, err := json.Marshal(req.Params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 创建任务记录
|
||||
job := &model.JimengJob{
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
Type: req.Type,
|
||||
ReqKey: req.ReqKey,
|
||||
Prompt: req.Prompt,
|
||||
TaskParams: string(paramsJson),
|
||||
Status: model.JMTaskStatusInQueue,
|
||||
Power: req.Power,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
if err := s.db.Create(job).Error; err != nil {
|
||||
return nil, fmt.Errorf("create jimeng job failed: %w", err)
|
||||
}
|
||||
|
||||
// 推送到任务队列
|
||||
if err := s.taskQueue.RPush(job.Id); err != nil {
|
||||
return nil, fmt.Errorf("push jimeng task to queue failed: %w", err)
|
||||
}
|
||||
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// ProcessTask 处理任务
|
||||
func (s *Service) ProcessTask(jobId uint) error {
|
||||
// 获取任务记录
|
||||
var job model.JimengJob
|
||||
if err := s.db.First(&job, jobId).Error; err != nil {
|
||||
return fmt.Errorf("get jimeng job failed: %w", err)
|
||||
}
|
||||
|
||||
// 更新任务状态为处理中
|
||||
if err := s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, ""); err != nil {
|
||||
return fmt.Errorf("update job status failed: %w", err)
|
||||
}
|
||||
|
||||
// 构建请求并提交任务
|
||||
req, err := s.buildTaskRequest(&job)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
|
||||
}
|
||||
|
||||
logger.Infof("提交即梦任务: %+v", req)
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
if err != nil {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %v", err))
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("submit task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
// 更新任务ID和原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
if err := s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(map[string]any{
|
||||
"task_id": resp.Data.TaskId,
|
||||
"raw_data": string(rawData),
|
||||
"updated_at": time.Now(),
|
||||
}).Error; err != nil {
|
||||
logger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildTaskRequest 构建任务请求(统一的参数解析)
|
||||
func (s *Service) buildTaskRequest(job *model.JimengJob) (*SubmitTaskRequest, error) {
|
||||
// 解析任务参数
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal([]byte(job.TaskParams), ¶ms); err != nil {
|
||||
return nil, fmt.Errorf("parse task params failed: %w", err)
|
||||
}
|
||||
|
||||
// 构建基础请求
|
||||
req := &SubmitTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
Prompt: job.Prompt,
|
||||
}
|
||||
|
||||
// 根据任务类型设置特定参数
|
||||
switch job.Type {
|
||||
case model.JMTaskTypeTextToImage:
|
||||
s.setTextToImageParams(req, params)
|
||||
case model.JMTaskTypeImageToImage:
|
||||
s.setImageToImageParams(req, params)
|
||||
case model.JMTaskTypeImageEdit:
|
||||
s.setImageEditParams(req, params)
|
||||
case model.JMTaskTypeImageEffects:
|
||||
s.setImageEffectsParams(req, params)
|
||||
case model.JMTaskTypeTextToVideo:
|
||||
s.setTextToVideoParams(req, params)
|
||||
case model.JMTaskTypeImageToVideo:
|
||||
s.setImageToVideoParams(req, params)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported task type: %s", job.Type)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// setTextToImageParams 设置文生图参数
|
||||
func (s *Service) setTextToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
if usePreLlm, ok := params["use_pre_llm"]; ok {
|
||||
if usePreLlmVal, ok := usePreLlm.(bool); ok {
|
||||
req.UsePreLLM = usePreLlmVal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setImageToImageParams 设置图生图参数
|
||||
func (s *Service) setImageToImageParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageInput, ok := params["image_input"].(string); ok {
|
||||
req.ImageInput = imageInput
|
||||
}
|
||||
if gpen, ok := params["gpen"]; ok {
|
||||
if gpenVal, ok := gpen.(float64); ok {
|
||||
req.Gpen = gpenVal
|
||||
}
|
||||
}
|
||||
if skin, ok := params["skin"]; ok {
|
||||
if skinVal, ok := skin.(float64); ok {
|
||||
req.Skin = skinVal
|
||||
}
|
||||
}
|
||||
if skinUnifi, ok := params["skin_unifi"]; ok {
|
||||
if skinUnifiVal, ok := skinUnifi.(float64); ok {
|
||||
req.SkinUnifi = skinUnifiVal
|
||||
}
|
||||
}
|
||||
if genMode, ok := params["gen_mode"].(string); ok {
|
||||
req.GenMode = genMode
|
||||
}
|
||||
s.setCommonParams(req, params) // 复用通用参数
|
||||
}
|
||||
|
||||
// setImageEditParams 设置图像编辑参数
|
||||
func (s *Service) setImageEditParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageUrls, ok := params["image_urls"].([]any); ok {
|
||||
for _, url := range imageUrls {
|
||||
if urlStr, ok := url.(string); ok {
|
||||
req.ImageUrls = append(req.ImageUrls, urlStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if binaryData, ok := params["binary_data_base64"].([]any); ok {
|
||||
for _, data := range binaryData {
|
||||
if dataStr, ok := data.(string); ok {
|
||||
req.BinaryDataBase64 = append(req.BinaryDataBase64, dataStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if scale, ok := params["scale"]; ok {
|
||||
if scaleVal, ok := scale.(float64); ok {
|
||||
req.Scale = scaleVal
|
||||
}
|
||||
}
|
||||
s.setCommonParams(req, params)
|
||||
}
|
||||
|
||||
// setImageEffectsParams 设置图像特效参数
|
||||
func (s *Service) setImageEffectsParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if imageInput1, ok := params["image_input1"].(string); ok {
|
||||
req.ImageInput1 = imageInput1
|
||||
}
|
||||
if templateId, ok := params["template_id"].(string); ok {
|
||||
req.TemplateId = templateId
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setTextToVideoParams 设置文生视频参数
|
||||
func (s *Service) setTextToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
}
|
||||
s.setCommonParams(req, params)
|
||||
}
|
||||
|
||||
// setImageToVideoParams 设置图生视频参数
|
||||
func (s *Service) setImageToVideoParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
s.setImageEditParams(req, params) // 复用图像编辑的参数设置
|
||||
if aspectRatio, ok := params["aspect_ratio"].(string); ok {
|
||||
req.AspectRatio = aspectRatio
|
||||
}
|
||||
}
|
||||
|
||||
// setCommonParams 设置通用参数(seed, width, height等)
|
||||
func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any) {
|
||||
if seed, ok := params["seed"]; ok {
|
||||
if seedVal, err := strconv.ParseInt(fmt.Sprintf("%.0f", seed), 10, 64); err == nil {
|
||||
req.Seed = seedVal
|
||||
}
|
||||
}
|
||||
if width, ok := params["width"]; ok {
|
||||
if widthVal, ok := width.(float64); ok {
|
||||
req.Width = int(widthVal)
|
||||
}
|
||||
}
|
||||
if height, ok := params["height"]; ok {
|
||||
if heightVal, ok := height.(float64); ok {
|
||||
req.Height = int(heightVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pollTaskStatus 轮询任务状态
|
||||
func (s *Service) pollTaskStatus() {
|
||||
|
||||
for {
|
||||
var jobs []model.JimengJob
|
||||
s.db.Where("status IN (?)", []model.JMTaskStatus{model.JMTaskStatusGenerating, model.JMTaskStatusInQueue}).Find(&jobs)
|
||||
if len(jobs) == 0 {
|
||||
logger.Debugf("no jimeng task to poll, sleep 10s")
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, job := range jobs {
|
||||
// 任务超时处理
|
||||
if job.UpdatedAt.Before(time.Now().Add(-5 * time.Minute)) {
|
||||
s.handleTaskError(job.Id, "task timeout")
|
||||
continue
|
||||
}
|
||||
|
||||
// 查询任务状态
|
||||
resp, err := s.client.QueryTask(&QueryTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
TaskId: job.TaskId,
|
||||
ReqJson: `{"return_url":true}`,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("query jimeng task status failed: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
|
||||
|
||||
if resp.Code != 10000 {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
|
||||
continue
|
||||
}
|
||||
|
||||
switch resp.Data.Status {
|
||||
case model.JMTaskStatusDone:
|
||||
// 判断任务是否成功
|
||||
if resp.Message != "Success" {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
|
||||
continue
|
||||
}
|
||||
|
||||
// 任务完成,更新结果
|
||||
updates := map[string]any{
|
||||
"status": model.JMTaskStatusSuccess,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 设置结果URL
|
||||
if len(resp.Data.ImageUrls) > 0 {
|
||||
imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.ImageUrls[0], ".png", false)
|
||||
if err != nil {
|
||||
logger.Errorf("upload image failed: %v", err)
|
||||
imgUrl = resp.Data.ImageUrls[0]
|
||||
}
|
||||
updates["img_url"] = imgUrl
|
||||
}
|
||||
if resp.Data.VideoUrl != "" {
|
||||
videoUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.VideoUrl, ".mp4", false)
|
||||
if err != nil {
|
||||
logger.Errorf("upload video failed: %v", err)
|
||||
videoUrl = resp.Data.VideoUrl
|
||||
}
|
||||
updates["video_url"] = videoUrl
|
||||
}
|
||||
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||
case model.JMTaskStatusInQueue, model.JMTaskStatusGenerating:
|
||||
// 任务处理中
|
||||
s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, "")
|
||||
|
||||
case model.JMTaskStatusNotFound:
|
||||
// 任务未找到
|
||||
s.handleTaskError(job.Id, "task not found")
|
||||
|
||||
case model.JMTaskStatusExpired:
|
||||
// 任务过期
|
||||
s.handleTaskError(job.Id, "task expired")
|
||||
|
||||
default:
|
||||
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// UpdateJobStatus 更新任务状态
|
||||
func (s *Service) UpdateJobStatus(jobId uint, status model.JMTaskStatus, errMsg string) error {
|
||||
updates := map[string]any{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
if errMsg != "" {
|
||||
updates["err_msg"] = errMsg
|
||||
}
|
||||
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
|
||||
}
|
||||
|
||||
// handleTaskError 处理任务错误
|
||||
func (s *Service) handleTaskError(jobId uint, errMsg string) error {
|
||||
logger.Errorf("Jimeng task error (job_id: %d): %s", jobId, errMsg)
|
||||
return s.UpdateJobStatus(jobId, model.JMTaskStatusFailed, errMsg)
|
||||
}
|
||||
|
||||
// PushTaskToQueue 推送任务到队列(用于手动重试)
|
||||
func (s *Service) PushTaskToQueue(jobId uint) error {
|
||||
return s.taskQueue.RPush(jobId)
|
||||
}
|
||||
|
||||
// GetTaskStats 获取任务统计信息
|
||||
func (s *Service) GetTaskStats() (map[string]any, error) {
|
||||
type StatResult struct {
|
||||
Status string `json:"status"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
var stats []StatResult
|
||||
err := s.db.Model(&model.JimengJob{}).
|
||||
Select("status, COUNT(*) as count").
|
||||
Group("status").
|
||||
Find(&stats).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"total": int64(0),
|
||||
"completed": int64(0),
|
||||
"processing": int64(0),
|
||||
"failed": int64(0),
|
||||
"pending": int64(0),
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
result["total"] = result["total"].(int64) + stat.Count
|
||||
result[stat.Status] = stat.Count
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetJob 获取任务
|
||||
func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
||||
var job model.JimengJob
|
||||
if err := s.db.First(&job, jobId).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &job, nil
|
||||
}
|
||||
|
||||
// testConnection 测试即梦AI连接
|
||||
func (s *Service) testConnection(accessKey, secretKey string) error {
|
||||
testClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 使用一个简单的查询任务来测试连接
|
||||
testReq := &QueryTaskRequest{
|
||||
ReqKey: "test_connection",
|
||||
TaskId: "test_task_id_12345",
|
||||
}
|
||||
|
||||
_, err := testClient.QueryTask(testReq)
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||
}
|
||||
// 其他错误(如任务不存在)说明连接正常
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateClientConfig 更新客户端配置
|
||||
func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
|
||||
// 创建新的客户端
|
||||
newClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 测试新客户端是否可用
|
||||
err := s.testConnection(accessKey, secretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新客户端
|
||||
s.client = newClient
|
||||
return nil
|
||||
}
|
||||
|
||||
var defaultPower = types.JimengPower{
|
||||
TextToImage: 20,
|
||||
ImageToImage: 20,
|
||||
ImageEdit: 20,
|
||||
ImageEffects: 20,
|
||||
TextToVideo: 300,
|
||||
ImageToVideo: 300,
|
||||
}
|
||||
|
||||
// GetConfig 获取即梦AI配置
|
||||
func (s *Service) GetConfig() *types.JimengConfig {
|
||||
var config model.Config
|
||||
err := s.db.Where("name", "jimeng").First(&config).Error
|
||||
if err != nil {
|
||||
// 如果配置不存在,返回默认配置
|
||||
return &types.JimengConfig{
|
||||
AccessKey: "",
|
||||
SecretKey: "",
|
||||
Power: defaultPower,
|
||||
}
|
||||
}
|
||||
|
||||
var jimengConfig types.JimengConfig
|
||||
err = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
if err != nil {
|
||||
return &types.JimengConfig{
|
||||
AccessKey: "",
|
||||
SecretKey: "",
|
||||
Power: defaultPower,
|
||||
}
|
||||
}
|
||||
|
||||
return &jimengConfig
|
||||
}
|
||||
145
api/service/jimeng/types.go
Normal file
145
api/service/jimeng/types.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package jimeng
|
||||
|
||||
import "geekai/store/model"
|
||||
|
||||
// ReqKey 常量定义
|
||||
const (
|
||||
ReqKeyTextToImage = "high_aes_general_v30l_zt2i" // 文生图
|
||||
ReqKeyImageToImagePortrait = "i2i_portrait_photo" // 图生图人像写真
|
||||
ReqKeyImageEdit = "seededit_v3.0" // 图像编辑
|
||||
ReqKeyImageEffects = "i2i_multi_style_zx2x" // 图像特效
|
||||
ReqKeyTextToVideo = "jimeng_vgfm_t2v_l20" // 文生视频
|
||||
ReqKeyImageToVideo = "jimeng_vgfm_i2v_l20" // 图生视频
|
||||
)
|
||||
|
||||
// SubmitTaskRequest 提交任务请求
|
||||
type SubmitTaskRequest struct {
|
||||
ReqKey string `json:"req_key"`
|
||||
// 文生图参数
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Scale float64 `json:"scale,omitempty"`
|
||||
Width int `json:"width,omitempty"`
|
||||
Height int `json:"height,omitempty"`
|
||||
UsePreLLM bool `json:"use_pre_llm,omitempty"`
|
||||
// 图生图参数
|
||||
ImageInput string `json:"image_input,omitempty"`
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
|
||||
Gpen float64 `json:"gpen,omitempty"`
|
||||
Skin float64 `json:"skin,omitempty"`
|
||||
SkinUnifi float64 `json:"skin_unifi,omitempty"`
|
||||
GenMode string `json:"gen_mode,omitempty"`
|
||||
// 图像编辑参数
|
||||
// 图像特效参数
|
||||
ImageInput1 string `json:"image_input1,omitempty"`
|
||||
TemplateId string `json:"template_id,omitempty"`
|
||||
// 视频生成参数
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
}
|
||||
|
||||
// SubmitTaskResponse 提交任务响应
|
||||
type SubmitTaskResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
Status int `json:"status"`
|
||||
TimeElapsed string `json:"time_elapsed"`
|
||||
Data struct {
|
||||
TaskId string `json:"task_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// QueryTaskRequest 查询任务请求
|
||||
type QueryTaskRequest struct {
|
||||
ReqKey string `json:"req_key"`
|
||||
TaskId string `json:"task_id"`
|
||||
ReqJson string `json:"req_json,omitempty"`
|
||||
}
|
||||
|
||||
// QueryTaskResponse 查询任务响应
|
||||
type QueryTaskResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
Status int `json:"status"`
|
||||
TimeElapsed string `json:"time_elapsed"`
|
||||
Data struct {
|
||||
AlgorithmBaseResp struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
StatusMessage string `json:"status_message"`
|
||||
} `json:"algorithm_base_resp"`
|
||||
BinaryDataBase64 []string `json:"binary_data_base64"`
|
||||
ImageUrls []string `json:"image_urls"`
|
||||
VideoUrl string `json:"video_url"`
|
||||
RespData string `json:"resp_data"`
|
||||
Status model.JMTaskStatus `json:"status"`
|
||||
LlmResult string `json:"llm_result"`
|
||||
PeResult string `json:"pe_result"`
|
||||
PredictTagsResult string `json:"predict_tags_result"`
|
||||
RephraserResult string `json:"rephraser_result"`
|
||||
VlmResult string `json:"vlm_result"`
|
||||
InferCtx any `json:"infer_ctx"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// CreateTaskRequest 创建任务请求
|
||||
type CreateTaskRequest struct {
|
||||
Type model.JMTaskType `json:"type"`
|
||||
Prompt string `json:"prompt"`
|
||||
Params map[string]any `json:"params"`
|
||||
ReqKey string `json:"req_key"`
|
||||
ImageUrls []string `json:"image_urls,omitempty"`
|
||||
Power int `json:"power,omitempty"`
|
||||
}
|
||||
|
||||
// LogoInfo 水印信息
|
||||
type LogoInfo struct {
|
||||
AddLogo bool `json:"add_logo"`
|
||||
Position int `json:"position"`
|
||||
Language int `json:"language"`
|
||||
Opacity float64 `json:"opacity"`
|
||||
LogoTextContent string `json:"logo_text_content"`
|
||||
}
|
||||
|
||||
// ReqJsonConfig 查询配置
|
||||
type ReqJsonConfig struct {
|
||||
ReturnUrl bool `json:"return_url"`
|
||||
LogoInfo *LogoInfo `json:"logo_info,omitempty"`
|
||||
}
|
||||
|
||||
// ImageEffectTemplate 图像特效模板
|
||||
const (
|
||||
TemplateIdFelt3DPolaroid = "felt_3d_polaroid" // 毛毡3d拍立得风格
|
||||
TemplateIdMyWorld = "my_world" // 像素世界风
|
||||
TemplateIdMyWorldUniversal = "my_world_universal" // 像素世界-万物通用版
|
||||
TemplateIdPlasticBubbleFigure = "plastic_bubble_figure" // 盲盒玩偶风
|
||||
TemplateIdPlasticBubbleFigureCartoon = "plastic_bubble_figure_cartoon_text" // 塑料泡罩人偶-文字卡头版
|
||||
TemplateIdFurryDreamDoll = "furry_dream_doll" // 毛绒玩偶风
|
||||
TemplateIdMicroLandscapeMiniWorld = "micro_landscape_mini_world" // 迷你世界玩偶风
|
||||
TemplateIdMicroLandscapeProfessional = "micro_landscape_mini_world_professional" // 微型景观小世界-职业版
|
||||
TemplateIdAcrylicOrnaments = "acrylic_ornaments" // 亚克力挂饰
|
||||
TemplateIdFeltKeychain = "felt_keychain" // 毛毡钥匙扣
|
||||
TemplateIdLofiPixelCharacter = "lofi_pixel_character_mini_card" // Lofi像素人物小卡
|
||||
TemplateIdAngelFigurine = "angel_figurine" // 天使形象手办
|
||||
TemplateIdLyingInFluffyBelly = "lying_in_fluffy_belly" // 躺在毛茸茸肚皮里
|
||||
TemplateIdGlassBall = "glass_ball" // 玻璃球
|
||||
)
|
||||
|
||||
// AspectRatio 视频宽高比
|
||||
const (
|
||||
AspectRatio16_9 = "16:9" // 1280×720
|
||||
AspectRatio9_16 = "9:16" // 720×1280
|
||||
AspectRatio1_1 = "1:1" // 960×960
|
||||
AspectRatio4_3 = "4:3" // 960×720
|
||||
AspectRatio3_4 = "3:4" // 720×960
|
||||
AspectRatio21_9 = "21:9" // 1680×720
|
||||
AspectRatio9_21 = "9:21" // 720×1680
|
||||
)
|
||||
|
||||
// GenMode 生成模式
|
||||
const (
|
||||
GenModeCreative = "creative" // 提示词模式
|
||||
GenModeReference = "reference" // 全参考模式
|
||||
GenModeReferenceChar = "reference_char" // 人物参考模式
|
||||
)
|
||||
@@ -191,7 +191,7 @@ func (s *Service) DownloadImages() {
|
||||
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
|
||||
proxy = true
|
||||
}
|
||||
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
|
||||
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, ".png", proxy)
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
|
||||
@@ -212,7 +212,9 @@ func (s *Service) DownloadImages() {
|
||||
// PushTask push a new mj task in to task queue
|
||||
func (s *Service) PushTask(task types.MjTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push mj task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SyncTaskProgress 异步拉取任务
|
||||
|
||||
@@ -84,7 +84,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -99,8 +99,10 @@ func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
fileExt := utils.GetImgExt(parse.Path)
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
|
||||
if err != nil {
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LocalStorage struct {
|
||||
@@ -37,7 +38,7 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
return File{}, fmt.Errorf("error with get form: %v", err)
|
||||
}
|
||||
|
||||
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false)
|
||||
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, "")
|
||||
if err != nil {
|
||||
return File{}, fmt.Errorf("error with generate filename: %s", err.Error())
|
||||
}
|
||||
@@ -57,13 +58,13 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s LocalStorage) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
filename := filepath.Base(parse.Path)
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true)
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, ext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with generate image dir: %v", err)
|
||||
}
|
||||
@@ -85,7 +86,7 @@ func (s LocalStorage) PutBase64(base64Img string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
|
||||
filePath, _ := utils.GenUploadPath(s.config.BasePath, "", ".png")
|
||||
err = os.WriteFile(filePath, imageData, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error writing to file:%v", err)
|
||||
|
||||
@@ -44,7 +44,7 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
|
||||
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
|
||||
}
|
||||
|
||||
func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -59,8 +59,10 @@ func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
fileExt := filepath.Ext(parse.Path)
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
info, err := s.client.PutObject(
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
@@ -86,7 +88,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}
|
||||
defer fileReader.Close()
|
||||
|
||||
fileExt := utils.GetImgExt(file.Filename)
|
||||
fileExt := filepath.Ext(file.Filename)
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
||||
ContentType: file.Header.Get("Body-Type"),
|
||||
|
||||
@@ -93,7 +93,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -108,8 +108,10 @@ func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
fileExt := utils.GetImgExt(parse.Path)
|
||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
// 上传文件字节数据
|
||||
|
||||
@@ -23,7 +23,7 @@ type File struct {
|
||||
}
|
||||
type Uploader interface {
|
||||
PutFile(ctx *gin.Context, name string) (File, error)
|
||||
PutUrlFile(url string, useProxy bool) (string, error)
|
||||
PutUrlFile(url string, ext string, useProxy bool) (string, error)
|
||||
PutBase64(imageData string) (string, error)
|
||||
Delete(fileURL string) error
|
||||
}
|
||||
|
||||
@@ -253,7 +253,9 @@ func (s *Service) checkTaskProgress(apiKey model.ApiKey) (*TaskProgressResp, err
|
||||
|
||||
func (s *Service) PushTask(task types.SdTask) {
|
||||
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push sd task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
|
||||
|
||||
@@ -51,7 +51,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
|
||||
|
||||
func (s *Service) PushTask(task types.SunoTask) {
|
||||
logger.Infof("add a new Suno task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push suno task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
@@ -270,14 +272,14 @@ func (s *Service) DownloadFiles() {
|
||||
for _, v := range items {
|
||||
// 下载图片和音频
|
||||
logger.Infof("try download cover image: %s", v.CoverURL)
|
||||
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
|
||||
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, ".png", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download image with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("try download audio: %s", v.AudioURL)
|
||||
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
|
||||
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, ".mp3", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download audio with error: %v", err)
|
||||
continue
|
||||
|
||||
@@ -51,7 +51,9 @@ func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Clien
|
||||
|
||||
func (s *Service) PushTask(task types.VideoTask) {
|
||||
logger.Infof("add a new Video task to the task list: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
if err := s.taskQueue.RPush(task); err != nil {
|
||||
logger.Errorf("push video task to queue failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
@@ -162,7 +164,7 @@ func (s *Service) DownloadFiles() {
|
||||
}
|
||||
|
||||
logger.Infof("try download video: %s", v.WaterURL)
|
||||
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
|
||||
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, ".mp4", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download video with error: %v", err)
|
||||
continue
|
||||
@@ -172,7 +174,7 @@ func (s *Service) DownloadFiles() {
|
||||
|
||||
if v.VideoURL != "" {
|
||||
logger.Infof("try download no water video: %s", v.VideoURL)
|
||||
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
|
||||
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, ".mp4", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download video with error: %v", err)
|
||||
continue
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package model
|
||||
|
||||
type Config struct {
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Key string `gorm:"column:marker;type:varchar(20);uniqueIndex;not null;comment:标识" json:"marker"`
|
||||
Config string `gorm:"column:config_json;type:text;not null" json:"config_json"`
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
Name string `gorm:"column:name;type:varchar(20);uniqueIndex;not null;comment:配置名称"`
|
||||
Value string `gorm:"column:value;type:text;not null"`
|
||||
}
|
||||
|
||||
func (m *Config) TableName() string {
|
||||
|
||||
55
api/store/model/jimeng_job.go
Normal file
55
api/store/model/jimeng_job.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// JimengJob 即梦AI任务模型
|
||||
type JimengJob struct {
|
||||
Id uint `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserId uint `gorm:"column:user_id;type:int;not null;index;comment:用户ID" json:"user_id"`
|
||||
TaskId string `gorm:"column:task_id;type:varchar(100);not null;index;comment:任务ID" json:"task_id"`
|
||||
Type JMTaskType `gorm:"column:type;type:varchar(50);not null;comment:任务类型" json:"type"`
|
||||
ReqKey string `gorm:"column:req_key;type:varchar(100);comment:请求Key" json:"req_key"`
|
||||
Prompt string `gorm:"column:prompt;type:text;comment:提示词" json:"prompt"`
|
||||
TaskParams string `gorm:"column:task_params;type:text;comment:任务参数JSON" json:"task_params"`
|
||||
ImgURL string `gorm:"column:img_url;type:varchar(1024);comment:图片或封面URL" json:"img_url"`
|
||||
VideoURL string `gorm:"column:video_url;type:varchar(1024);comment:视频URL" json:"video_url"`
|
||||
RawData string `gorm:"column:raw_data;type:text;comment:原始API响应" json:"raw_data"`
|
||||
Progress int `gorm:"column:progress;type:int;default:0;comment:进度百分比" json:"progress"`
|
||||
Status JMTaskStatus `gorm:"column:status;type:varchar(20);default:'pending';comment:任务状态" json:"status"`
|
||||
ErrMsg string `gorm:"column:err_msg;type:varchar(1024);comment:错误信息" json:"err_msg"`
|
||||
Power int `gorm:"column:power;type:int(11);default:0;comment:消耗算力" json:"power"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:datetime;not null;comment:创建时间" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:datetime;not null;comment:更新时间" json:"updated_at"`
|
||||
}
|
||||
|
||||
// JMTaskStatus 任务状态
|
||||
type JMTaskStatus string
|
||||
|
||||
const (
|
||||
JMTaskStatusInQueue = JMTaskStatus("in_queue") // 任务已提交
|
||||
JMTaskStatusGenerating = JMTaskStatus("generating") // 任务处理中
|
||||
JMTaskStatusDone = JMTaskStatus("done") // 处理完成
|
||||
JMTaskStatusNotFound = JMTaskStatus("not_found") // 任务未找到
|
||||
JMTaskStatusSuccess = JMTaskStatus("success") // 任务成功
|
||||
JMTaskStatusFailed = JMTaskStatus("failed") // 任务失败
|
||||
JMTaskStatusExpired = JMTaskStatus("expired") // 任务过期
|
||||
)
|
||||
|
||||
// JMTaskType 任务类型
|
||||
type JMTaskType string
|
||||
|
||||
const (
|
||||
JMTaskTypeTextToImage = JMTaskType("text_to_image") // 文生图
|
||||
JMTaskTypeImageToImage = JMTaskType("image_to_image") // 图生图
|
||||
JMTaskTypeImageEdit = JMTaskType("image_edit") // 图像编辑
|
||||
JMTaskTypeImageEffects = JMTaskType("image_effects") // 图像特效
|
||||
JMTaskTypeTextToVideo = JMTaskType("text_to_video") // 文生视频
|
||||
JMTaskTypeImageToVideo = JMTaskType("image_to_video") // 图生视频
|
||||
)
|
||||
|
||||
// TableName 返回数据表名称
|
||||
func (JimengJob) TableName() string {
|
||||
return "chatgpt_jimeng_jobs"
|
||||
}
|
||||
@@ -10,6 +10,7 @@ package store
|
||||
import (
|
||||
"context"
|
||||
"geekai/utils"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
@@ -23,15 +24,15 @@ func NewRedisQueue(name string, client *redis.Client) *RedisQueue {
|
||||
return &RedisQueue{name: name, client: client, ctx: context.Background()}
|
||||
}
|
||||
|
||||
func (q *RedisQueue) RPush(value interface{}) {
|
||||
q.client.RPush(q.ctx, q.name, utils.JsonEncode(value))
|
||||
func (q *RedisQueue) RPush(value any) error {
|
||||
return q.client.RPush(q.ctx, q.name, utils.JsonEncode(value)).Err()
|
||||
}
|
||||
|
||||
func (q *RedisQueue) LPush(value interface{}) {
|
||||
q.client.LPush(q.ctx, q.name, utils.JsonEncode(value))
|
||||
func (q *RedisQueue) LPush(value any) error {
|
||||
return q.client.LPush(q.ctx, q.name, utils.JsonEncode(value)).Err()
|
||||
}
|
||||
|
||||
func (q *RedisQueue) LPop(value interface{}) error {
|
||||
func (q *RedisQueue) LPop(value any) error {
|
||||
result, err := q.client.BLPop(q.ctx, 0, q.name).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -39,10 +40,18 @@ func (q *RedisQueue) LPop(value interface{}) error {
|
||||
return utils.JsonDecode(result[1], value)
|
||||
}
|
||||
|
||||
func (q *RedisQueue) RPop(value interface{}) error {
|
||||
func (q *RedisQueue) RPop(value any) error {
|
||||
result, err := q.client.BRPop(q.ctx, 0, q.name).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return utils.JsonDecode(result[1], value)
|
||||
}
|
||||
|
||||
func (q *RedisQueue) Size() (int64, error) {
|
||||
return q.client.LLen(q.ctx, q.name).Result()
|
||||
}
|
||||
|
||||
func (q *RedisQueue) Clear() error {
|
||||
return q.client.Del(q.ctx, q.name).Err()
|
||||
}
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
package vo
|
||||
|
||||
import "geekai/core/types"
|
||||
|
||||
type Config struct {
|
||||
Id uint `json:"id"`
|
||||
Key string `json:"key"`
|
||||
SystemConfig types.SystemConfig `json:"system_config"`
|
||||
}
|
||||
23
api/store/vo/jimeng_job.go
Normal file
23
api/store/vo/jimeng_job.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package vo
|
||||
|
||||
import "geekai/store/model"
|
||||
|
||||
// JimengJob 即梦AI任务VO
|
||||
type JimengJob struct {
|
||||
Id uint `json:"id"`
|
||||
UserId uint `json:"user_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
Type model.JMTaskType `json:"type"`
|
||||
ReqKey string `json:"req_key"`
|
||||
Prompt string `json:"prompt"`
|
||||
TaskParams string `json:"task_params"`
|
||||
ImgURL string `json:"img_url"`
|
||||
VideoURL string `json:"video_url"`
|
||||
RawData string `json:"raw_data"`
|
||||
Progress int `json:"progress"`
|
||||
Status model.JMTaskStatus `json:"status"`
|
||||
ErrMsg string `json:"err_msg"`
|
||||
Power int `json:"power"`
|
||||
CreatedAt int64 `json:"created_at"` // 时间戳
|
||||
UpdatedAt int64 `json:"updated_at"` // 时间戳
|
||||
}
|
||||
16
api/test/app_test.go
Normal file
16
api/test/app_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"geekai/utils"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNewService 测试创建爬虫服务
|
||||
func TestNewService(t *testing.T) {
|
||||
videoURL := `https://p3-aiop-sign.byteimg.com/tos-cn-i-vuqhorh59i/2025072310444223AAB2C93CE2B9BB8573-6843-0~tplv-vuqhorh59i-image.image?rk3s=7f9e702d&x-expires=1753325083&x-signature=%2F5V3H%2FWPQlOej6VtVZyf%2BNJBWok%3D`
|
||||
filePath := "test_video.png"
|
||||
err := utils.DownloadFile(videoURL, filePath, "")
|
||||
if err != nil {
|
||||
t.Fatalf("下载视频失败: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"geekai/service/crawler"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewService 测试创建爬虫服务
|
||||
func TestNewService(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("测试过程中发生崩溃: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
service, err := crawler.NewService()
|
||||
if err != nil {
|
||||
t.Logf("注意: 创建爬虫服务失败,可能是因为Chrome浏览器未安装: %v", err)
|
||||
t.Skip("跳过测试 - 浏览器问题")
|
||||
return
|
||||
}
|
||||
defer service.Close()
|
||||
|
||||
// 创建服务成功则测试通过
|
||||
if service == nil {
|
||||
t.Fatal("创建的爬虫服务为空")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchWeb 测试网络搜索功能
|
||||
func TestSearchWeb(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("测试过程中发生崩溃: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// 设置测试超时时间
|
||||
timeout := time.After(600 * time.Second)
|
||||
done := make(chan bool)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("搜索过程中发生崩溃: %v", r)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
keyword := "Golang编程"
|
||||
maxPages := 1
|
||||
|
||||
// 执行搜索
|
||||
result, err := crawler.SearchWeb(keyword, maxPages)
|
||||
if err != nil {
|
||||
t.Logf("搜索失败,可能是网络问题或浏览器未安装: %v", err)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 验证结果不为空
|
||||
if result == "" {
|
||||
t.Log("搜索结果为空")
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 验证结果包含关键字或部分关键字
|
||||
if !strings.Contains(result, "Golang") && !strings.Contains(result, "golang") {
|
||||
t.Logf("搜索结果中未包含关键字或部分关键字,获取到的结果: %s", result)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 验证结果格式,至少应包含"链接:"
|
||||
if !strings.Contains(result, "链接:") {
|
||||
t.Log("搜索结果格式不正确,没有找到'链接:'部分")
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
done <- true
|
||||
t.Logf("搜索结果: %s", result)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Log("测试超时 - 这可能是正常的,特别是在网络较慢或资源有限的环境中")
|
||||
t.Skip("跳过测试 - 超时")
|
||||
case success := <-done:
|
||||
if !success {
|
||||
t.Skip("跳过测试 - 搜索失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 减少测试用例数量,只保留基本测试
|
||||
// 这样可以减少测试时间和资源消耗
|
||||
// 以下测试用例被注释掉,可以根据需要启用
|
||||
|
||||
/*
|
||||
// TestSearchWebNoResults 测试搜索无结果的情况
|
||||
func TestSearchWebNoResults(t *testing.T) {
|
||||
// 设置测试超时时间
|
||||
timeout := time.After(60 * time.Second)
|
||||
done := make(chan bool)
|
||||
|
||||
go func() {
|
||||
// 使用一个极不可能有搜索结果的随机字符串
|
||||
keyword := "askdjfhalskjdfhas98y234hlakjsdhflakjshdflakjshdfl"
|
||||
maxPages := 1
|
||||
|
||||
// 执行搜索
|
||||
result, err := crawler.SearchWeb(keyword, maxPages)
|
||||
if err != nil {
|
||||
t.Errorf("搜索失败: %v", err)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 验证结果为"未找到相关搜索结果"
|
||||
if !strings.Contains(result, "未找到") && !strings.Contains(result, "0 条搜索结果") {
|
||||
t.Errorf("对于无结果的搜索,预期返回包含'未找到'的信息,实际返回: %s", result)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("测试超时")
|
||||
case success := <-done:
|
||||
if !success {
|
||||
t.Fatal("测试失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchWebMultiplePages 测试多页搜索
|
||||
func TestSearchWebMultiplePages(t *testing.T) {
|
||||
// 设置测试超时时间
|
||||
timeout := time.After(120 * time.Second)
|
||||
done := make(chan bool)
|
||||
|
||||
go func() {
|
||||
keyword := "golang programming"
|
||||
maxPages := 2
|
||||
|
||||
// 执行搜索
|
||||
result, err := crawler.SearchWeb(keyword, maxPages)
|
||||
if err != nil {
|
||||
t.Errorf("搜索失败: %v", err)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 验证结果不为空
|
||||
if result == "" {
|
||||
t.Error("搜索结果为空")
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 计算结果中的条目数
|
||||
resultCount := strings.Count(result, "链接:")
|
||||
if resultCount < 10 {
|
||||
t.Errorf("多页搜索应返回至少10条结果,实际返回: %d", resultCount)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("测试超时")
|
||||
case success := <-done:
|
||||
if !success {
|
||||
t.Fatal("测试失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchWebWithMaxPageLimit 测试页数限制
|
||||
func TestSearchWebWithMaxPageLimit(t *testing.T) {
|
||||
service, err := crawler.NewService()
|
||||
if err != nil {
|
||||
t.Fatalf("创建爬虫服务失败: %v", err)
|
||||
}
|
||||
defer service.Close()
|
||||
|
||||
// 传入一个超过限制的页数
|
||||
results, err := service.WebSearch("golang", 15)
|
||||
if err != nil {
|
||||
t.Fatalf("搜索失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证结果不为空
|
||||
if len(results) == 0 {
|
||||
t.Fatal("搜索结果为空")
|
||||
}
|
||||
|
||||
// 因为最大页数限制为10,所以结果数量应该小于等于10*10=100
|
||||
if len(results) > 100 {
|
||||
t.Errorf("搜索结果超过最大限制,预期最多100条,实际: %d", len(results))
|
||||
}
|
||||
}
|
||||
*/
|
||||
@@ -1,41 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 显示执行的命令
|
||||
set -x
|
||||
|
||||
# 检查Chrome/Chromium浏览器是否已安装
|
||||
check_chrome() {
|
||||
echo "检查Chrome/Chromium浏览器是否安装..."
|
||||
which chromium-browser || which google-chrome || which chromium
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "警告: 未找到Chrome或Chromium浏览器,测试可能会失败"
|
||||
echo "尝试安装必要的依赖..."
|
||||
sudo apt-get update && sudo apt-get install -y libnss3 libgbm1 libasound2 libatk1.0-0 libatk-bridge2.0-0 libcups2 libxkbcommon0 libxdamage1 libxfixes3 libxrandr2 libxcomposite1 libxcursor1 libxi6 libxtst6 libnss3 libnspr4 libpango1.0-0
|
||||
echo "已安装依赖,但仍需安装Chrome/Chromium浏览器以完全支持测试"
|
||||
else
|
||||
echo "已找到Chrome/Chromium浏览器"
|
||||
fi
|
||||
}
|
||||
|
||||
# 切换到项目根目录
|
||||
cd ..
|
||||
|
||||
# 检查环境
|
||||
check_chrome
|
||||
|
||||
# 运行爬虫测试,使用超时限制
|
||||
echo "开始运行爬虫测试..."
|
||||
timeout 180s go test -v ./test/crawler_test.go -run "TestNewService|TestSearchWeb"
|
||||
TEST_RESULT=$?
|
||||
|
||||
if [ $TEST_RESULT -eq 124 ]; then
|
||||
echo "测试超时终止"
|
||||
exit 1
|
||||
elif [ $TEST_RESULT -ne 0 ]; then
|
||||
echo "测试失败,退出码: $TEST_RESULT"
|
||||
exit $TEST_RESULT
|
||||
else
|
||||
echo "测试成功完成"
|
||||
fi
|
||||
|
||||
echo "测试完成"
|
||||
@@ -1,55 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
codeLength = 32 // 兑换码长度
|
||||
)
|
||||
|
||||
var (
|
||||
codeMap = make(map[string]bool)
|
||||
mapMutex = &sync.Mutex{}
|
||||
)
|
||||
|
||||
// GenerateUniqueCode 生成唯一兑换码
|
||||
func GenerateUniqueCode() (string, error) {
|
||||
for {
|
||||
code, err := generateCode()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mapMutex.Lock()
|
||||
if !codeMap[code] {
|
||||
codeMap[code] = true
|
||||
mapMutex.Unlock()
|
||||
return code, nil
|
||||
}
|
||||
mapMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// generateCode 生成兑换码
|
||||
func generateCode() (string, error) {
|
||||
bytes := make([]byte, codeLength/2) // 因为 hex 编码会使长度翻倍
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
for i := 0; i < 10; i++ {
|
||||
code, err := GenerateUniqueCode()
|
||||
if err != nil {
|
||||
fmt.Println("Error generating code:", err)
|
||||
return
|
||||
}
|
||||
fmt.Println("Generated code:", code)
|
||||
}
|
||||
}
|
||||
@@ -11,9 +11,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
|
||||
"github.com/nfnt/resize"
|
||||
"github.com/skip2/go-qrcode"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/draw"
|
||||
@@ -22,11 +19,22 @@ import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
|
||||
"github.com/nfnt/resize"
|
||||
"github.com/skip2/go-qrcode"
|
||||
)
|
||||
|
||||
// CopyObject 拷贝对象
|
||||
func CopyObject(src interface{}, dst interface{}) error {
|
||||
|
||||
// 这里做异常处理
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Errorf("copy object failed: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
srcType := reflect.TypeOf(src)
|
||||
srcValue := reflect.ValueOf(src)
|
||||
dstValue := reflect.ValueOf(dst).Elem()
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
// GenUploadPath 生成上传文件路径
|
||||
func GenUploadPath(basePath, filename string, isImg bool) (string, error) {
|
||||
func GenUploadPath(basePath, filename string, ext string) (string, error) {
|
||||
now := time.Now()
|
||||
dir := fmt.Sprintf("%s/%d/%d", basePath, now.Year(), now.Month())
|
||||
_, err := os.Stat(dir)
|
||||
@@ -30,13 +30,11 @@ func GenUploadPath(basePath, filename string, isImg bool) (string, error) {
|
||||
return "", fmt.Errorf("error with create upload dir:%v", err)
|
||||
}
|
||||
}
|
||||
var fileExt string
|
||||
if isImg {
|
||||
fileExt = GetImgExt(filename)
|
||||
} else {
|
||||
fileExt = filepath.Ext(filename)
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(filename)
|
||||
}
|
||||
return fmt.Sprintf("%s/%d%s", dir, now.UnixMicro(), fileExt), nil
|
||||
|
||||
return fmt.Sprintf("%s/%d%s", dir, now.UnixMicro(), ext), nil
|
||||
}
|
||||
|
||||
// GenUploadUrl 生成上传文件 URL
|
||||
@@ -80,14 +78,6 @@ func DownloadFile(fileURL string, filepath string, proxy string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetImgExt(filename string) string {
|
||||
ext := filepath.Ext(filename)
|
||||
if ext == "" {
|
||||
return ".png"
|
||||
}
|
||||
return ext
|
||||
}
|
||||
|
||||
func ExtractImgURLs(text string) []string {
|
||||
re := regexp.MustCompile(`(http[s]?:\/\/.*?\.(?:png|jpg|jpeg|gif))`)
|
||||
matches := re.FindAllStringSubmatch(text, 10)
|
||||
|
||||
Reference in New Issue
Block a user