refactor midjourney service, use api key in database

This commit is contained in:
RockYang
2024-08-06 18:30:57 +08:00
parent cc551ba266
commit f9b809801d
29 changed files with 585 additions and 1203 deletions

View File

@@ -10,6 +10,7 @@ package sd
import (
"fmt"
"geekai/core/types"
"geekai/service"
"geekai/service/oss"
"geekai/store"
"geekai/store/model"
@@ -79,7 +80,7 @@ func (p *ServicePool) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var message NotifyMessage
var message service.NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue

View File

@@ -10,6 +10,7 @@ package sd
import (
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/store"
@@ -22,6 +23,8 @@ import (
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
// SD 绘画服务
type Service struct {
@@ -87,11 +90,11 @@ func (s *Service) Run() {
logger.Error("绘画任务执行失败:", err.Error())
// update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": 101,
"progress": service.FailTaskProgress,
"err_msg": err.Error(),
})
// 通知前端,任务失败
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed})
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
continue
}
}
@@ -206,7 +209,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished})
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
// 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId)
return nil
@@ -216,7 +219,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running})
s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
// 保存预览图片数据
if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)

View File

@@ -1,24 +0,0 @@
package sd
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import logger2 "geekai/logger"
var logger = logger2.GetLogger()
type NotifyMessage struct {
UserId int `json:"user_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
const (
Running = "RUNNING"
Finished = "FINISH"
Failed = "FAIL"
)