mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-24 03:54:26 +08:00
文生视频和图生视频功能完成
This commit is contained in:
@@ -290,7 +290,7 @@ func (h *JimengHandler) Jobs(c *gin.Context) {
|
||||
|
||||
// 分页查询
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(req.PageSize).Find(&jobs).Error; err != nil {
|
||||
if err := query.Order("updated_at DESC").Offset(offset).Limit(req.PageSize).Find(&jobs).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -338,22 +338,32 @@ func (h *JimengHandler) Remove(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.jimengService.DeleteJob(uint(jobId), user.Id); err != nil {
|
||||
tx := h.DB.Begin()
|
||||
if err := tx.Where("id = ? AND user_id = ?", jobId, user.Id).Delete(&model.JimengJob{}).Error; err != nil {
|
||||
logger.Errorf("delete jimeng job failed: %v", err)
|
||||
resp.ERROR(c, "删除任务失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 退回算力
|
||||
err = h.userService.IncreasePower(user.Id, job.Power, model.PowerLog{
|
||||
Type: types.PowerRefund,
|
||||
Model: "jimeng",
|
||||
Remark: fmt.Sprintf("删除任务,退回%d算力", job.Power),
|
||||
})
|
||||
if err != nil {
|
||||
resp.ERROR(c, "退回算力失败")
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
resp.SUCCESS(c, gin.H{})
|
||||
}
|
||||
|
||||
// Retry 重试任务
|
||||
func (h *JimengHandler) Retry(c *gin.Context) {
|
||||
user, err := h.GetLoginUser(c)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
userId := h.GetLoginUserId(c)
|
||||
|
||||
jobId := h.GetInt(c, "id", 0)
|
||||
if jobId == 0 {
|
||||
@@ -368,7 +378,7 @@ func (h *JimengHandler) Retry(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if job.UserId != user.Id {
|
||||
if job.UserId != userId {
|
||||
resp.ERROR(c, "无权限操作")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -144,7 +144,15 @@ func (h *NetHandler) Download(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 使用http.Get下载文件
|
||||
r, err := http.Get(fileUrl)
|
||||
req, err := http.NewRequest("GET", fileUrl, nil)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
// 模拟浏览器 UA
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36")
|
||||
client := &http.Client{}
|
||||
r, err := client.Do(req)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
@@ -157,6 +165,5 @@ func (h *NetHandler) Download(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
// 将下载的文件内容写入响应
|
||||
_, _ = io.Copy(c.Writer, r.Body)
|
||||
}
|
||||
|
||||
@@ -293,7 +293,7 @@ func (s *Service) DownloadImages() {
|
||||
|
||||
func (s *Service) downloadImage(jobId uint, orgURL string) (string, error) {
|
||||
// sava image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, false)
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(orgURL, ".png", false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
logger2 "geekai/logger"
|
||||
"geekai/service/oss"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
@@ -31,10 +32,11 @@ type Service struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
running bool
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
// NewService 创建即梦服务
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client) *Service {
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
|
||||
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
||||
// 从数据库加载配置
|
||||
var config model.Config
|
||||
@@ -54,6 +56,7 @@ func NewService(db *gorm.DB, redisCli *redis.Client) *Service {
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
running: false,
|
||||
uploader: uploader,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +68,7 @@ func (s *Service) Start() {
|
||||
logger.Info("Starting Jimeng service and task consumer...")
|
||||
s.running = true
|
||||
go s.consumeTasks()
|
||||
go s.pollTaskStatus()
|
||||
}
|
||||
|
||||
// Stop 停止服务
|
||||
@@ -166,6 +170,8 @@ func (s *Service) ProcessTask(jobId uint) error {
|
||||
return s.handleTaskError(job.Id, fmt.Sprintf("build task request failed: %v", err))
|
||||
}
|
||||
|
||||
logger.Infof("提交即梦任务: %+v", req)
|
||||
|
||||
// 提交异步任务
|
||||
resp, err := s.client.SubmitTask(req)
|
||||
if err != nil {
|
||||
@@ -186,8 +192,7 @@ func (s *Service) ProcessTask(jobId uint) error {
|
||||
logger.Errorf("update jimeng job task_id failed: %v", err)
|
||||
}
|
||||
|
||||
// 开始轮询任务状态
|
||||
return s.pollTaskStatus(job.Id, resp.Data.TaskId, job.ReqKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildTaskRequest 构建任务请求(统一的参数解析)
|
||||
@@ -360,78 +365,100 @@ func (s *Service) setCommonParams(req *SubmitTaskRequest, params map[string]any)
|
||||
}
|
||||
|
||||
// pollTaskStatus 轮询任务状态
|
||||
func (s *Service) pollTaskStatus(jobId uint, taskId, reqKey string) error {
|
||||
maxRetries := 60 // 最大重试次数,60次 * 5秒 = 5分钟
|
||||
retryCount := 0
|
||||
func (s *Service) pollTaskStatus() {
|
||||
|
||||
for retryCount < maxRetries {
|
||||
time.Sleep(5 * time.Second) // 等待5秒
|
||||
|
||||
// 查询任务状态
|
||||
resp, err := s.client.QueryTask(&QueryTaskRequest{
|
||||
ReqKey: reqKey,
|
||||
TaskId: taskId,
|
||||
ReqJson: `{"return_url":true}`,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("query jimeng task status failed: %v", err)
|
||||
retryCount++
|
||||
for {
|
||||
var jobs []model.JimengJob
|
||||
s.db.Where("status IN (?)", []model.JMTaskStatus{model.JMTaskStatusGenerating, model.JMTaskStatusInQueue}).Find(&jobs)
|
||||
if len(jobs) == 0 {
|
||||
logger.Debugf("no jimeng task to poll, sleep 10s")
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Update("raw_data", string(rawData))
|
||||
for _, job := range jobs {
|
||||
// 任务超时处理
|
||||
if job.UpdatedAt.Before(time.Now().Add(-5 * time.Minute)) {
|
||||
s.handleTaskError(job.Id, "task timeout")
|
||||
continue
|
||||
}
|
||||
|
||||
// 查询任务状态
|
||||
resp, err := s.client.QueryTask(&QueryTaskRequest{
|
||||
ReqKey: job.ReqKey,
|
||||
TaskId: job.TaskId,
|
||||
ReqJson: `{"return_url":true}`,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("query jimeng task status failed: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新原始数据
|
||||
rawData, _ := json.Marshal(resp)
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Update("raw_data", string(rawData))
|
||||
|
||||
if resp.Code != 10000 {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", resp.Message))
|
||||
continue
|
||||
}
|
||||
|
||||
switch resp.Data.Status {
|
||||
case model.JMTaskStatusDone:
|
||||
// 判断任务是否成功
|
||||
if resp.Message != "Success" {
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
|
||||
continue
|
||||
}
|
||||
|
||||
// 任务完成,更新结果
|
||||
updates := map[string]any{
|
||||
"status": model.JMTaskStatusSuccess,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 设置结果URL
|
||||
if len(resp.Data.ImageUrls) > 0 {
|
||||
imgUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.ImageUrls[0], ".png", false)
|
||||
if err != nil {
|
||||
logger.Errorf("upload image failed: %v", err)
|
||||
imgUrl = resp.Data.ImageUrls[0]
|
||||
}
|
||||
updates["img_url"] = imgUrl
|
||||
}
|
||||
if resp.Data.VideoUrl != "" {
|
||||
videoUrl, err := s.uploader.GetUploadHandler().PutUrlFile(resp.Data.VideoUrl, ".mp4", false)
|
||||
if err != nil {
|
||||
logger.Errorf("upload video failed: %v", err)
|
||||
videoUrl = resp.Data.VideoUrl
|
||||
}
|
||||
updates["video_url"] = videoUrl
|
||||
}
|
||||
|
||||
s.db.Model(&model.JimengJob{}).Where("id = ?", job.Id).Updates(updates)
|
||||
case model.JMTaskStatusInQueue, model.JMTaskStatusGenerating:
|
||||
// 任务处理中
|
||||
s.UpdateJobStatus(job.Id, model.JMTaskStatusGenerating, "")
|
||||
|
||||
case model.JMTaskStatusNotFound:
|
||||
// 任务未找到
|
||||
s.handleTaskError(job.Id, "task not found")
|
||||
|
||||
case model.JMTaskStatusExpired:
|
||||
// 任务过期
|
||||
s.handleTaskError(job.Id, "task expired")
|
||||
|
||||
default:
|
||||
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
}
|
||||
|
||||
if resp.Code != 10000 {
|
||||
return s.handleTaskError(jobId, fmt.Sprintf("query task failed: %s", resp.Message))
|
||||
}
|
||||
|
||||
switch resp.Data.Status {
|
||||
case model.JMTaskStatusDone:
|
||||
// 判断任务是否成功
|
||||
if resp.Message != "Success" {
|
||||
return s.handleTaskError(jobId, fmt.Sprintf("task failed: %s", resp.Data.AlgorithmBaseResp.StatusMessage))
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// 任务完成,更新结果
|
||||
updates := map[string]any{
|
||||
"status": model.JMTaskStatusSuccess,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
// 设置结果URL
|
||||
if len(resp.Data.ImageUrls) > 0 {
|
||||
updates["img_url"] = resp.Data.ImageUrls[0]
|
||||
}
|
||||
if resp.Data.VideoUrl != "" {
|
||||
updates["video_url"] = resp.Data.VideoUrl
|
||||
}
|
||||
|
||||
return s.db.Model(&model.JimengJob{}).Where("id = ?", jobId).Updates(updates).Error
|
||||
|
||||
case model.JMTaskStatusInQueue:
|
||||
// 任务在队列中
|
||||
s.UpdateJobStatus(jobId, model.JMTaskStatusGenerating, "")
|
||||
|
||||
case model.JMTaskStatusGenerating:
|
||||
// 任务处理中
|
||||
s.UpdateJobStatus(jobId, model.JMTaskStatusGenerating, "")
|
||||
|
||||
case model.JMTaskStatusNotFound:
|
||||
// 任务未找到或已过期
|
||||
return s.handleTaskError(jobId, resp.Message)
|
||||
|
||||
default:
|
||||
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
}
|
||||
|
||||
retryCount++
|
||||
}
|
||||
|
||||
// 超时处理
|
||||
return s.handleTaskError(jobId, "task timeout")
|
||||
}
|
||||
|
||||
// UpdateJobStatus 更新任务状态
|
||||
@@ -498,11 +525,6 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
||||
return &job, nil
|
||||
}
|
||||
|
||||
// DeleteJob 删除任务
|
||||
func (s *Service) DeleteJob(jobId uint, userId uint) error {
|
||||
return s.db.Where("id = ? AND user_id = ?", jobId, userId).Delete(&model.JimengJob{}).Error
|
||||
}
|
||||
|
||||
// testConnection 测试即梦AI连接
|
||||
func (s *Service) testConnection(accessKey, secretKey string) error {
|
||||
testClient := NewClient(accessKey, secretKey)
|
||||
|
||||
@@ -191,7 +191,7 @@ func (s *Service) DownloadImages() {
|
||||
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
|
||||
proxy = true
|
||||
}
|
||||
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
|
||||
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, ".png", proxy)
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
|
||||
|
||||
@@ -84,7 +84,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -99,8 +99,10 @@ func (s AliYunOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
fileExt := utils.GetImgExt(parse.Path)
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
|
||||
if err != nil {
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LocalStorage struct {
|
||||
@@ -37,7 +38,7 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
return File{}, fmt.Errorf("error with get form: %v", err)
|
||||
}
|
||||
|
||||
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false)
|
||||
path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, "")
|
||||
if err != nil {
|
||||
return File{}, fmt.Errorf("error with generate filename: %s", err.Error())
|
||||
}
|
||||
@@ -57,13 +58,13 @@ func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s LocalStorage) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s LocalStorage) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
parse, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
filename := filepath.Base(parse.Path)
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true)
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, filename, ext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with generate image dir: %v", err)
|
||||
}
|
||||
@@ -85,7 +86,7 @@ func (s LocalStorage) PutBase64(base64Img string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
|
||||
filePath, _ := utils.GenUploadPath(s.config.BasePath, "", ".png")
|
||||
err = os.WriteFile(filePath, imageData, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error writing to file:%v", err)
|
||||
|
||||
@@ -44,7 +44,7 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
|
||||
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
|
||||
}
|
||||
|
||||
func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -59,8 +59,10 @@ func (s MiniOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
fileExt := filepath.Ext(parse.Path)
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
info, err := s.client.PutObject(
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
@@ -86,7 +88,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
}
|
||||
defer fileReader.Close()
|
||||
|
||||
fileExt := utils.GetImgExt(file.Filename)
|
||||
fileExt := filepath.Ext(file.Filename)
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
||||
ContentType: file.Header.Get("Body-Type"),
|
||||
|
||||
@@ -93,7 +93,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -108,8 +108,10 @@ func (s QinNiuOss) PutUrlFile(fileURL string, useProxy bool) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with parse image URL: %v", err)
|
||||
}
|
||||
fileExt := utils.GetImgExt(parse.Path)
|
||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
// 上传文件字节数据
|
||||
|
||||
@@ -23,7 +23,7 @@ type File struct {
|
||||
}
|
||||
type Uploader interface {
|
||||
PutFile(ctx *gin.Context, name string) (File, error)
|
||||
PutUrlFile(url string, useProxy bool) (string, error)
|
||||
PutUrlFile(url string, ext string, useProxy bool) (string, error)
|
||||
PutBase64(imageData string) (string, error)
|
||||
Delete(fileURL string) error
|
||||
}
|
||||
|
||||
@@ -272,14 +272,14 @@ func (s *Service) DownloadFiles() {
|
||||
for _, v := range items {
|
||||
// 下载图片和音频
|
||||
logger.Infof("try download cover image: %s", v.CoverURL)
|
||||
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, true)
|
||||
coverURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.CoverURL, ".png", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download image with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("try download audio: %s", v.AudioURL)
|
||||
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, true)
|
||||
audioURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.AudioURL, ".mp3", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download audio with error: %v", err)
|
||||
continue
|
||||
|
||||
@@ -164,7 +164,7 @@ func (s *Service) DownloadFiles() {
|
||||
}
|
||||
|
||||
logger.Infof("try download video: %s", v.WaterURL)
|
||||
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, true)
|
||||
videoURL, err := s.uploadManager.GetUploadHandler().PutUrlFile(v.WaterURL, ".mp4", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download video with error: %v", err)
|
||||
continue
|
||||
@@ -174,7 +174,7 @@ func (s *Service) DownloadFiles() {
|
||||
|
||||
if v.VideoURL != "" {
|
||||
logger.Infof("try download no water video: %s", v.VideoURL)
|
||||
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, true)
|
||||
videoURL, err = s.uploadManager.GetUploadHandler().PutUrlFile(v.VideoURL, ".mp4", true)
|
||||
if err != nil {
|
||||
logger.Errorf("download video with error: %v", err)
|
||||
continue
|
||||
|
||||
@@ -34,6 +34,7 @@ const (
|
||||
JMTaskStatusNotFound = JMTaskStatus("not_found") // 任务未找到
|
||||
JMTaskStatusSuccess = JMTaskStatus("success") // 任务成功
|
||||
JMTaskStatusFailed = JMTaskStatus("failed") // 任务失败
|
||||
JMTaskStatusExpired = JMTaskStatus("expired") // 任务过期
|
||||
)
|
||||
|
||||
// JMTaskType 任务类型
|
||||
|
||||
16
api/test/app_test.go
Normal file
16
api/test/app_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"geekai/utils"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNewService 测试创建爬虫服务
|
||||
func TestNewService(t *testing.T) {
|
||||
videoURL := `https://p3-aiop-sign.byteimg.com/tos-cn-i-vuqhorh59i/2025072310444223AAB2C93CE2B9BB8573-6843-0~tplv-vuqhorh59i-image.image?rk3s=7f9e702d&x-expires=1753325083&x-signature=%2F5V3H%2FWPQlOej6VtVZyf%2BNJBWok%3D`
|
||||
filePath := "test_video.png"
|
||||
err := utils.DownloadFile(videoURL, filePath, "")
|
||||
if err != nil {
|
||||
t.Fatalf("下载视频失败: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"geekai/service/crawler"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewService 测试创建爬虫服务
|
||||
func TestNewService(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("测试过程中发生崩溃: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
service, err := crawler.NewService()
|
||||
if err != nil {
|
||||
t.Logf("注意: 创建爬虫服务失败,可能是因为Chrome浏览器未安装: %v", err)
|
||||
t.Skip("跳过测试 - 浏览器问题")
|
||||
return
|
||||
}
|
||||
defer service.Close()
|
||||
|
||||
// 创建服务成功则测试通过
|
||||
if service == nil {
|
||||
t.Fatal("创建的爬虫服务为空")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchWeb 测试网络搜索功能
|
||||
func TestSearchWeb(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("测试过程中发生崩溃: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// 设置测试超时时间
|
||||
timeout := time.After(600 * time.Second)
|
||||
done := make(chan bool)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("搜索过程中发生崩溃: %v", r)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
keyword := "Golang编程"
|
||||
maxPages := 1
|
||||
|
||||
// 执行搜索
|
||||
result, err := crawler.SearchWeb(keyword, maxPages)
|
||||
if err != nil {
|
||||
t.Logf("搜索失败,可能是网络问题或浏览器未安装: %v", err)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 验证结果不为空
|
||||
if result == "" {
|
||||
t.Log("搜索结果为空")
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 验证结果包含关键字或部分关键字
|
||||
if !strings.Contains(result, "Golang") && !strings.Contains(result, "golang") {
|
||||
t.Logf("搜索结果中未包含关键字或部分关键字,获取到的结果: %s", result)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 验证结果格式,至少应包含"链接:"
|
||||
if !strings.Contains(result, "链接:") {
|
||||
t.Log("搜索结果格式不正确,没有找到'链接:'部分")
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
done <- true
|
||||
t.Logf("搜索结果: %s", result)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Log("测试超时 - 这可能是正常的,特别是在网络较慢或资源有限的环境中")
|
||||
t.Skip("跳过测试 - 超时")
|
||||
case success := <-done:
|
||||
if !success {
|
||||
t.Skip("跳过测试 - 搜索失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 减少测试用例数量,只保留基本测试
|
||||
// 这样可以减少测试时间和资源消耗
|
||||
// 以下测试用例被注释掉,可以根据需要启用
|
||||
|
||||
/*
|
||||
// TestSearchWebNoResults 测试搜索无结果的情况
|
||||
func TestSearchWebNoResults(t *testing.T) {
|
||||
// 设置测试超时时间
|
||||
timeout := time.After(60 * time.Second)
|
||||
done := make(chan bool)
|
||||
|
||||
go func() {
|
||||
// 使用一个极不可能有搜索结果的随机字符串
|
||||
keyword := "askdjfhalskjdfhas98y234hlakjsdhflakjshdflakjshdfl"
|
||||
maxPages := 1
|
||||
|
||||
// 执行搜索
|
||||
result, err := crawler.SearchWeb(keyword, maxPages)
|
||||
if err != nil {
|
||||
t.Errorf("搜索失败: %v", err)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 验证结果为"未找到相关搜索结果"
|
||||
if !strings.Contains(result, "未找到") && !strings.Contains(result, "0 条搜索结果") {
|
||||
t.Errorf("对于无结果的搜索,预期返回包含'未找到'的信息,实际返回: %s", result)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("测试超时")
|
||||
case success := <-done:
|
||||
if !success {
|
||||
t.Fatal("测试失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchWebMultiplePages 测试多页搜索
|
||||
func TestSearchWebMultiplePages(t *testing.T) {
|
||||
// 设置测试超时时间
|
||||
timeout := time.After(120 * time.Second)
|
||||
done := make(chan bool)
|
||||
|
||||
go func() {
|
||||
keyword := "golang programming"
|
||||
maxPages := 2
|
||||
|
||||
// 执行搜索
|
||||
result, err := crawler.SearchWeb(keyword, maxPages)
|
||||
if err != nil {
|
||||
t.Errorf("搜索失败: %v", err)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 验证结果不为空
|
||||
if result == "" {
|
||||
t.Error("搜索结果为空")
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
// 计算结果中的条目数
|
||||
resultCount := strings.Count(result, "链接:")
|
||||
if resultCount < 10 {
|
||||
t.Errorf("多页搜索应返回至少10条结果,实际返回: %d", resultCount)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("测试超时")
|
||||
case success := <-done:
|
||||
if !success {
|
||||
t.Fatal("测试失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchWebWithMaxPageLimit 测试页数限制
|
||||
func TestSearchWebWithMaxPageLimit(t *testing.T) {
|
||||
service, err := crawler.NewService()
|
||||
if err != nil {
|
||||
t.Fatalf("创建爬虫服务失败: %v", err)
|
||||
}
|
||||
defer service.Close()
|
||||
|
||||
// 传入一个超过限制的页数
|
||||
results, err := service.WebSearch("golang", 15)
|
||||
if err != nil {
|
||||
t.Fatalf("搜索失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证结果不为空
|
||||
if len(results) == 0 {
|
||||
t.Fatal("搜索结果为空")
|
||||
}
|
||||
|
||||
// 因为最大页数限制为10,所以结果数量应该小于等于10*10=100
|
||||
if len(results) > 100 {
|
||||
t.Errorf("搜索结果超过最大限制,预期最多100条,实际: %d", len(results))
|
||||
}
|
||||
}
|
||||
*/
|
||||
@@ -1,41 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 显示执行的命令
|
||||
set -x
|
||||
|
||||
# 检查Chrome/Chromium浏览器是否已安装
|
||||
check_chrome() {
|
||||
echo "检查Chrome/Chromium浏览器是否安装..."
|
||||
which chromium-browser || which google-chrome || which chromium
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "警告: 未找到Chrome或Chromium浏览器,测试可能会失败"
|
||||
echo "尝试安装必要的依赖..."
|
||||
sudo apt-get update && sudo apt-get install -y libnss3 libgbm1 libasound2 libatk1.0-0 libatk-bridge2.0-0 libcups2 libxkbcommon0 libxdamage1 libxfixes3 libxrandr2 libxcomposite1 libxcursor1 libxi6 libxtst6 libnss3 libnspr4 libpango1.0-0
|
||||
echo "已安装依赖,但仍需安装Chrome/Chromium浏览器以完全支持测试"
|
||||
else
|
||||
echo "已找到Chrome/Chromium浏览器"
|
||||
fi
|
||||
}
|
||||
|
||||
# 切换到项目根目录
|
||||
cd ..
|
||||
|
||||
# 检查环境
|
||||
check_chrome
|
||||
|
||||
# 运行爬虫测试,使用超时限制
|
||||
echo "开始运行爬虫测试..."
|
||||
timeout 180s go test -v ./test/crawler_test.go -run "TestNewService|TestSearchWeb"
|
||||
TEST_RESULT=$?
|
||||
|
||||
if [ $TEST_RESULT -eq 124 ]; then
|
||||
echo "测试超时终止"
|
||||
exit 1
|
||||
elif [ $TEST_RESULT -ne 0 ]; then
|
||||
echo "测试失败,退出码: $TEST_RESULT"
|
||||
exit $TEST_RESULT
|
||||
else
|
||||
echo "测试成功完成"
|
||||
fi
|
||||
|
||||
echo "测试完成"
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
// GenUploadPath 生成上传文件路径
|
||||
func GenUploadPath(basePath, filename string, isImg bool) (string, error) {
|
||||
func GenUploadPath(basePath, filename string, ext string) (string, error) {
|
||||
now := time.Now()
|
||||
dir := fmt.Sprintf("%s/%d/%d", basePath, now.Year(), now.Month())
|
||||
_, err := os.Stat(dir)
|
||||
@@ -30,13 +30,11 @@ func GenUploadPath(basePath, filename string, isImg bool) (string, error) {
|
||||
return "", fmt.Errorf("error with create upload dir:%v", err)
|
||||
}
|
||||
}
|
||||
var fileExt string
|
||||
if isImg {
|
||||
fileExt = GetImgExt(filename)
|
||||
} else {
|
||||
fileExt = filepath.Ext(filename)
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(filename)
|
||||
}
|
||||
return fmt.Sprintf("%s/%d%s", dir, now.UnixMicro(), fileExt), nil
|
||||
|
||||
return fmt.Sprintf("%s/%d%s", dir, now.UnixMicro(), ext), nil
|
||||
}
|
||||
|
||||
// GenUploadUrl 生成上传文件 URL
|
||||
@@ -80,14 +78,6 @@ func DownloadFile(fileURL string, filepath string, proxy string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetImgExt(filename string) string {
|
||||
ext := filepath.Ext(filename)
|
||||
if ext == "" {
|
||||
return ".png"
|
||||
}
|
||||
return ext
|
||||
}
|
||||
|
||||
func ExtractImgURLs(text string) []string {
|
||||
re := regexp.MustCompile(`(http[s]?:\/\/.*?\.(?:png|jpg|jpeg|gif))`)
|
||||
matches := re.FindAllStringSubmatch(text, 10)
|
||||
|
||||
Reference in New Issue
Block a user