mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-17 16:56:38 +08:00
feat: add implements for stable diffusion service
This commit is contained in:
parent
c1143d7a6d
commit
d51a724ade
70
api/core/types/task.go
Normal file
70
api/core/types/task.go
Normal file
@ -0,0 +1,70 @@
|
||||
package types
|
||||
|
||||
// TaskType 任务类别
|
||||
type TaskType string
|
||||
|
||||
func (t TaskType) String() string {
|
||||
return string(t)
|
||||
}
|
||||
|
||||
const (
|
||||
TaskImage = TaskType("image")
|
||||
TaskUpscale = TaskType("upscale")
|
||||
TaskVariation = TaskType("variation")
|
||||
TaskTxt2Img = TaskType("text2img")
|
||||
)
|
||||
|
||||
// TaskSrc 任务来源
|
||||
type TaskSrc string
|
||||
|
||||
const (
|
||||
TaskSrcChat = TaskSrc("chat") // 来自聊天页面
|
||||
TaskSrcImg = TaskSrc("img") // 专业绘画页面
|
||||
)
|
||||
|
||||
// MjTask MidJourney 任务
|
||||
type MjTask struct {
|
||||
Id int `json:"id"`
|
||||
SessionId string `json:"session_id"`
|
||||
Src TaskSrc `json:"src"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
ChatId string `json:"chat_id,omitempty"`
|
||||
RoleId int `json:"role_id,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Index int32 `json:"index,omitempty"`
|
||||
MessageId string `json:"message_id,omitempty"`
|
||||
MessageHash string `json:"message_hash,omitempty"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
}
|
||||
|
||||
// SdParams stable diffusion 绘画参数
|
||||
type SdParams struct {
|
||||
TaskId string `json:"task_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt"`
|
||||
Steps int `json:"steps"`
|
||||
Sampler string `json:"sampler"`
|
||||
FaceFix bool `json:"face_fix"`
|
||||
CfgScale float32 `json:"cfg_scale"`
|
||||
Seed int64 `json:"seed"`
|
||||
Height int `json:"height"`
|
||||
Width int `json:"width"`
|
||||
HdFix bool `json:"hd_fix"`
|
||||
HdRedrawRate float32 `json:"hd_redraw_rate"`
|
||||
HdScale int `json:"hd_scale"`
|
||||
HdScaleAlg string `json:"hd_scale_alg"`
|
||||
HdSampleNum int `json:"hd_sample_num"`
|
||||
}
|
||||
|
||||
type SdTask struct {
|
||||
Id int `json:"id"`
|
||||
SessionId string `json:"session_id"`
|
||||
Src types.TaskSrc `json:"src"`
|
||||
Type types.TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Params types.SdParams `json:"params"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
}
|
@ -66,7 +66,7 @@ func NewMidJourneyHandler(
|
||||
return &h
|
||||
}
|
||||
|
||||
type notifyData struct {
|
||||
type mjNotifyData struct {
|
||||
MessageId string `json:"message_id"`
|
||||
ReferenceId string `json:"reference_id"`
|
||||
Image Image `json:"image"`
|
||||
@ -98,7 +98,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
var data notifyData
|
||||
var data mjNotifyData
|
||||
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
@ -122,14 +122,14 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||
|
||||
}
|
||||
|
||||
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (error, bool) {
|
||||
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data mjNotifyData) (error, bool) {
|
||||
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
|
||||
if err != nil { // 过期任务,丢弃
|
||||
logger.Warn("任务已过期:", err)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
var task service.MjTask
|
||||
var task types.MjTask
|
||||
err = utils.JsonDecode(taskString, &task)
|
||||
if err != nil { // 非标准任务,丢弃
|
||||
logger.Warn("任务解析失败:", err)
|
||||
@ -143,7 +143,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if task.Src == service.TaskSrcImg { // 绘画任务
|
||||
if task.Src == types.TaskSrcImg { // 绘画任务
|
||||
var job model.MidJourneyJob
|
||||
res := h.db.Where("id = ?", task.Id).First(&job)
|
||||
if res.Error != nil {
|
||||
@ -191,7 +191,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
|
||||
}
|
||||
}
|
||||
|
||||
} else if task.Src == service.TaskSrcChat { // 聊天任务
|
||||
} else if task.Src == types.TaskSrcChat { // 聊天任务
|
||||
wsClient := h.App.MjTaskClients.Get(task.SessionId)
|
||||
if data.Status == Finished {
|
||||
if wsClient != nil && data.ReferenceId != "" {
|
||||
@ -342,7 +342,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
job := model.MidJourneyJob{
|
||||
Type: service.Image.String(),
|
||||
Type: types.TaskImage.String(),
|
||||
UserId: userId,
|
||||
Progress: 0,
|
||||
Prompt: prompt,
|
||||
@ -353,11 +353,11 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.mjService.PushTask(service.MjTask{
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: int(job.Id),
|
||||
SessionId: data.SessionId,
|
||||
Src: service.TaskSrcImg,
|
||||
Type: service.Image,
|
||||
Src: types.TaskSrcImg,
|
||||
Type: types.TaskImage,
|
||||
Prompt: prompt,
|
||||
UserId: userId,
|
||||
})
|
||||
@ -401,10 +401,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
jobId := 0
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
src := service.TaskSrc(data.Src)
|
||||
if src == service.TaskSrcImg {
|
||||
src := types.TaskSrc(data.Src)
|
||||
if src == types.TaskSrcImg {
|
||||
job := model.MidJourneyJob{
|
||||
Type: service.Upscale.String(),
|
||||
Type: types.TaskUpscale.String(),
|
||||
UserId: userId,
|
||||
Hash: data.MessageHash,
|
||||
Progress: 0,
|
||||
@ -428,11 +428,11 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mjService.PushTask(service.MjTask{
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: jobId,
|
||||
SessionId: data.SessionId,
|
||||
Src: src,
|
||||
Type: service.Upscale,
|
||||
Type: types.TaskUpscale,
|
||||
Prompt: data.Prompt,
|
||||
UserId: userId,
|
||||
RoleId: data.RoleId,
|
||||
@ -470,10 +470,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
jobId := 0
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
src := service.TaskSrc(data.Src)
|
||||
if src == service.TaskSrcImg {
|
||||
src := types.TaskSrc(data.Src)
|
||||
if src == types.TaskSrcImg {
|
||||
job := model.MidJourneyJob{
|
||||
Type: service.Variation.String(),
|
||||
Type: types.TaskVariation.String(),
|
||||
UserId: userId,
|
||||
ImgURL: "",
|
||||
Hash: data.MessageHash,
|
||||
@ -498,11 +498,11 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mjService.PushTask(service.MjTask{
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: jobId,
|
||||
SessionId: data.SessionId,
|
||||
Src: src,
|
||||
Type: service.Variation,
|
||||
Type: types.TaskVariation,
|
||||
Prompt: data.Prompt,
|
||||
UserId: userId,
|
||||
RoleId: data.RoleId,
|
||||
|
315
api/handler/sd_handler.go
Normal file
315
api/handler/sd_handler.go
Normal file
@ -0,0 +1,315 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SdJobHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
mjService *service.MjService
|
||||
uploaderManager *oss.UploaderManager
|
||||
lock sync.Mutex
|
||||
clients *types.LMap[string, *types.WsClient]
|
||||
}
|
||||
|
||||
func NewSdJobHandler(
|
||||
app *core.AppServer,
|
||||
client *redis.Client,
|
||||
db *gorm.DB,
|
||||
manager *oss.UploaderManager,
|
||||
mjService *service.MjService) *MidJourneyHandler {
|
||||
h := MidJourneyHandler{
|
||||
redis: client,
|
||||
db: db,
|
||||
uploaderManager: manager,
|
||||
lock: sync.Mutex{},
|
||||
mjService: mjService,
|
||||
clients: types.NewLMap[string, *types.WsClient](),
|
||||
}
|
||||
h.App = app
|
||||
return &h
|
||||
}
|
||||
|
||||
// Client WebSocket 客户端,用于通知任务状态变更
|
||||
func (h *SdJobHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
sessionId := c.Query("session_id")
|
||||
client := types.NewWsClient(ws)
|
||||
// 删除旧的连接
|
||||
h.clients.Delete(sessionId)
|
||||
h.clients.Put(sessionId, client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
||||
}
|
||||
|
||||
type sdNotifyData struct {
|
||||
TaskId string
|
||||
ImageName string
|
||||
ImageData string
|
||||
Progress int
|
||||
Seed string
|
||||
Success bool
|
||||
Message string
|
||||
}
|
||||
|
||||
func (h *SdJobHandler) Notify(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
if token != h.App.Config.ExtConfig.Token {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
var data sdNotifyData
|
||||
if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
logger.Debugf("收到 MidJourney 回调请求:%+v", data)
|
||||
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
|
||||
err, finished := h.notifyHandler(c, data)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 解除任务锁定
|
||||
if finished && (data.Progress == 100) {
|
||||
h.redis.Del(c, service.MjRunningJobKey)
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
|
||||
}
|
||||
|
||||
func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) {
|
||||
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
|
||||
if err != nil { // 过期任务,丢弃
|
||||
logger.Warn("任务已过期:", err)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
var task types.SdTask
|
||||
err = utils.JsonDecode(taskString, &task)
|
||||
if err != nil { // 非标准任务,丢弃
|
||||
logger.Warn("任务解析失败:", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var job model.SdJob
|
||||
res := h.db.Where("id = ?", task.Id).First(&job)
|
||||
if res.Error != nil {
|
||||
logger.Warn("非法任务:", res.Error)
|
||||
return nil, false
|
||||
}
|
||||
job.Params = utils.JsonEncode(task.Params)
|
||||
job.ReferenceId = data.ImageData
|
||||
job.Progress = data.Progress
|
||||
job.Prompt = data.Prompt
|
||||
job.Hash = data.Image.Hash
|
||||
|
||||
// 任务完成,将最终的图片下载下来
|
||||
if data.Progress == 100 {
|
||||
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return err, false
|
||||
}
|
||||
job.ImgURL = imgURL
|
||||
} else {
|
||||
// 临时图片直接保存,访问的时候使用代理进行转发
|
||||
job.ImgURL = data.Image.URL
|
||||
}
|
||||
res = h.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update job: ", res.Error)
|
||||
return res.Error, false
|
||||
}
|
||||
|
||||
var jobVo vo.MidJourneyJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err == nil {
|
||||
if data.Progress < 100 {
|
||||
image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
|
||||
if err == nil {
|
||||
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
|
||||
// 推送任务到前端
|
||||
client := h.clients.Get(task.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新用户剩余绘图次数
|
||||
if data.Progress == 100 {
|
||||
h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
}
|
||||
|
||||
return nil, true
|
||||
}
|
||||
|
||||
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
||||
user, err := utils.GetLoginUser(c, h.db)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
return false
|
||||
}
|
||||
|
||||
if user.ImgCalls <= 0 {
|
||||
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
// Image 创建一个绘画任务
|
||||
func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
var data struct {
|
||||
SessionId string `json:"session_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Rate string `json:"rate"`
|
||||
Model string `json:"model"`
|
||||
Chaos int `json:"chaos"`
|
||||
Raw bool `json:"raw"`
|
||||
Seed int64 `json:"seed"`
|
||||
Stylize int `json:"stylize"`
|
||||
Img string `json:"img"`
|
||||
Weight float32 `json:"weight"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
if !h.checkLimits(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var prompt = data.Prompt
|
||||
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
|
||||
prompt += " --ar " + data.Rate
|
||||
}
|
||||
if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
|
||||
prompt += fmt.Sprintf(" --seed %d", data.Seed)
|
||||
}
|
||||
if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
|
||||
prompt += fmt.Sprintf(" --s %d", data.Stylize)
|
||||
}
|
||||
if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
|
||||
prompt += fmt.Sprintf(" --c %d", data.Chaos)
|
||||
}
|
||||
if data.Img != "" {
|
||||
prompt = fmt.Sprintf("%s %s", data.Img, prompt)
|
||||
if data.Weight > 0 {
|
||||
prompt += fmt.Sprintf(" --iw %f", data.Weight)
|
||||
}
|
||||
}
|
||||
if data.Raw {
|
||||
prompt += " --style raw"
|
||||
}
|
||||
if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
|
||||
prompt += data.Model
|
||||
}
|
||||
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
job := model.MidJourneyJob{
|
||||
Type: service.Image.String(),
|
||||
UserId: userId,
|
||||
Progress: 0,
|
||||
Prompt: prompt,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if res := h.db.Create(&job); res.Error != nil {
|
||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.mjService.PushTask(service.MjTask{
|
||||
Id: int(job.Id),
|
||||
SessionId: data.SessionId,
|
||||
Src: service.TaskSrcImg,
|
||||
Type: service.Image,
|
||||
Prompt: prompt,
|
||||
UserId: userId,
|
||||
})
|
||||
|
||||
var jobVo vo.MidJourneyJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err == nil {
|
||||
// 推送任务到前端
|
||||
client := h.clients.Get(data.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
// JobList 获取 MJ 任务列表
|
||||
func (h *SdJobHandler) JobList(c *gin.Context) {
|
||||
status := h.GetInt(c, "status", 0)
|
||||
var items []model.MidJourneyJob
|
||||
var res *gorm.DB
|
||||
userId, _ := c.Get(types.LoginUserID)
|
||||
if status == 1 {
|
||||
res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items)
|
||||
} else {
|
||||
res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items)
|
||||
}
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, types.NoData)
|
||||
return
|
||||
}
|
||||
|
||||
var jobs = make([]vo.MidJourneyJob, 0)
|
||||
for _, item := range items {
|
||||
var job vo.MidJourneyJob
|
||||
err := utils.CopyObject(item, &job)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if item.Progress < 100 {
|
||||
// 30 分钟还没完成的任务直接删除
|
||||
if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
|
||||
h.db.Delete(&item)
|
||||
continue
|
||||
}
|
||||
if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
|
||||
image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
|
||||
if err == nil {
|
||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
}
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
resp.SUCCESS(c, jobs)
|
||||
}
|
@ -21,41 +21,6 @@ var logger = logger2.GetLogger()
|
||||
|
||||
const MjRunningJobKey = "MidJourney_Running_Job"
|
||||
|
||||
type TaskType string
|
||||
|
||||
func (t TaskType) String() string {
|
||||
return string(t)
|
||||
}
|
||||
|
||||
const (
|
||||
Image = TaskType("image")
|
||||
Upscale = TaskType("upscale")
|
||||
Variation = TaskType("variation")
|
||||
)
|
||||
|
||||
type TaskSrc string
|
||||
|
||||
const (
|
||||
TaskSrcChat = TaskSrc("chat")
|
||||
TaskSrcImg = TaskSrc("img")
|
||||
)
|
||||
|
||||
type MjTask struct {
|
||||
Id int `json:"id"`
|
||||
SessionId string `json:"session_id"`
|
||||
Src TaskSrc `json:"src"`
|
||||
Type TaskType `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
ChatId string `json:"chat_id,omitempty"`
|
||||
RoleId int `json:"role_id,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Index int32 `json:"index,omitempty"`
|
||||
MessageId string `json:"message_id,omitempty"`
|
||||
MessageHash string `json:"message_hash,omitempty"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
}
|
||||
|
||||
type MjService struct {
|
||||
config types.ChatPlusExtConfig
|
||||
client *req.Client
|
||||
@ -78,11 +43,11 @@ func (s *MjService) Run() {
|
||||
ctx := context.Background()
|
||||
for {
|
||||
_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
|
||||
if err == nil {
|
||||
if err == nil { // 队列串行执行
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
var task MjTask
|
||||
var task types.MjTask
|
||||
err = s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
@ -90,17 +55,17 @@ func (s *MjService) Run() {
|
||||
}
|
||||
logger.Infof("Consuming Task: %+v", task)
|
||||
switch task.Type {
|
||||
case Image:
|
||||
case types.TaskImage:
|
||||
err = s.image(task.Prompt)
|
||||
break
|
||||
case Upscale:
|
||||
case types.TaskUpscale:
|
||||
err = s.upscale(MjUpscaleReq{
|
||||
Index: task.Index,
|
||||
MessageId: task.MessageId,
|
||||
MessageHash: task.MessageHash,
|
||||
})
|
||||
break
|
||||
case Variation:
|
||||
case types.TaskVariation:
|
||||
err = s.variation(MjVariationReq{
|
||||
Index: task.Index,
|
||||
MessageId: task.MessageId,
|
||||
@ -124,7 +89,7 @@ func (s *MjService) Run() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MjService) PushTask(task MjTask) {
|
||||
func (s *MjService) PushTask(task types.MjTask) {
|
||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
95
api/service/sd_service.go
Normal file
95
api/service/sd_service.go
Normal file
@ -0,0 +1,95 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SD 绘画服务
|
||||
|
||||
const SdRunningJobKey = "StableDiffusion_Running_Job"
|
||||
|
||||
type SdService struct {
|
||||
config types.ChatPlusExtConfig
|
||||
client *req.Client
|
||||
taskQueue *store.RedisQueue
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewSdService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *SdService {
|
||||
return &SdService{
|
||||
config: appConfig.ExtConfig,
|
||||
redis: client,
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", client),
|
||||
client: req.C().SetTimeout(30 * time.Second)}
|
||||
}
|
||||
|
||||
func (s *SdService) Run() {
|
||||
logger.Info("Starting StableDiffusion job consumer.")
|
||||
ctx := context.Background()
|
||||
for {
|
||||
_, err := s.redis.Get(ctx, SdRunningJobKey).Result()
|
||||
if err == nil { // 队列串行执行
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
var task types.SdTask
|
||||
err = s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("Consuming Task: %+v", task)
|
||||
err = s.txt2img(task.Params)
|
||||
if err != nil {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
if task.RetryCount <= 5 {
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
task.RetryCount += 1
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务的执行状态
|
||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||||
// 锁定任务执行通道,直到任务超时(5分钟)
|
||||
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SdService) PushTask(task types.SdTask) {
|
||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *SdService) txt2img(params types.SdParams) error {
|
||||
logger.Infof("SD 绘画参数:%+v", params)
|
||||
url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("Authorization", s.config.Token).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(params).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return fmt.Errorf("%v%v", r.String(), err)
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
return errors.New(res.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
20
api/store/model/sd_job.go
Normal file
20
api/store/model/sd_job.go
Normal file
@ -0,0 +1,20 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
type SdJob struct {
|
||||
Id uint `gorm:"primarykey;column:id"`
|
||||
Type string
|
||||
UserId int
|
||||
TaskId string
|
||||
ImgURL string
|
||||
Progress int
|
||||
Prompt string
|
||||
Params string
|
||||
Started bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
func (SdJob) TableName() string {
|
||||
return "chatgpt_sd_jobs"
|
||||
}
|
19
api/store/vo/sd_job.go
Normal file
19
api/store/vo/sd_job.go
Normal file
@ -0,0 +1,19 @@
|
||||
package vo
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SdJob struct {
|
||||
Id uint `json:"id"`
|
||||
Type string `json:"type"`
|
||||
UserId int `json:"user_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
ImgURL string `json:"img_url"`
|
||||
Params types.SdParams `json:"params"`
|
||||
Progress int `json:"progress"`
|
||||
Prompt string `json:"prompt"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Started bool `json:"started"`
|
||||
}
|
@ -1,2 +1,25 @@
|
||||
ALTER TABLE `chatgpt_mj_jobs` ADD `started` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '任务是否开始' AFTER `progress`;
|
||||
UPDATE `chatgpt_mj_jobs` SET started = 1
|
||||
|
||||
-- 创建 SD 绘图任务表
|
||||
CREATE TABLE `chatgpt_sd_jobs` (
|
||||
`id` int NOT NULL,
|
||||
`user_id` int NOT NULL COMMENT '用户 ID',
|
||||
`type` varchar(20) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT 'txt2img' COMMENT '任务类别',
|
||||
`task_id` char(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '任务 ID',
|
||||
`prompt` varchar(2000) NOT NULL COMMENT '会话提示词',
|
||||
`img_url` varchar(255) DEFAULT NULL COMMENT '图片URL',
|
||||
`params` text CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci COMMENT '绘画参数json',
|
||||
`progress` smallint DEFAULT '0' COMMENT '任务进度',
|
||||
`started` tinyint(1) NOT NULL DEFAULT '0' COMMENT '任务是否开始',
|
||||
`created_at` datetime NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='StableDiffusion 任务表';
|
||||
--
|
||||
-- 表的索引 `chatgpt_sd_jobs`
|
||||
--
|
||||
ALTER TABLE `chatgpt_sd_jobs`
|
||||
ADD PRIMARY KEY (`id`),
|
||||
ADD UNIQUE KEY `task_id` (`task_id`);
|
||||
|
||||
ALTER TABLE `chatgpt_sd_jobs`
|
||||
MODIFY `id` int NOT NULL AUTO_INCREMENT;
|
Loading…
Reference in New Issue
Block a user