feat: add implements for stable diffusion service

This commit is contained in:
RockYang
2023-09-26 18:16:51 +08:00
parent db0a79da93
commit c86169022a
8 changed files with 569 additions and 62 deletions

View File

@@ -66,7 +66,7 @@ func NewMidJourneyHandler(
return &h
}
type notifyData struct {
type mjNotifyData struct {
MessageId string `json:"message_id"`
ReferenceId string `json:"reference_id"`
Image Image `json:"image"`
@@ -98,7 +98,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
resp.NotAuth(c)
return
}
var data notifyData
var data mjNotifyData
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs)
return
@@ -122,14 +122,14 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
}
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (error, bool) {
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data mjNotifyData) (error, bool) {
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
if err != nil { // 过期任务,丢弃
logger.Warn("任务已过期:", err)
return nil, true
}
var task service.MjTask
var task types.MjTask
err = utils.JsonDecode(taskString, &task)
if err != nil { // 非标准任务,丢弃
logger.Warn("任务解析失败:", err)
@@ -143,7 +143,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
return nil, false
}
if task.Src == service.TaskSrcImg { // 绘画任务
if task.Src == types.TaskSrcImg { // 绘画任务
var job model.MidJourneyJob
res := h.db.Where("id = ?", task.Id).First(&job)
if res.Error != nil {
@@ -191,7 +191,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
}
}
} else if task.Src == service.TaskSrcChat { // 聊天任务
} else if task.Src == types.TaskSrcChat { // 聊天任务
wsClient := h.App.MjTaskClients.Get(task.SessionId)
if data.Status == Finished {
if wsClient != nil && data.ReferenceId != "" {
@@ -342,7 +342,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
job := model.MidJourneyJob{
Type: service.Image.String(),
Type: types.TaskImage.String(),
UserId: userId,
Progress: 0,
Prompt: prompt,
@@ -353,11 +353,11 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return
}
h.mjService.PushTask(service.MjTask{
h.mjService.PushTask(types.MjTask{
Id: int(job.Id),
SessionId: data.SessionId,
Src: service.TaskSrcImg,
Type: service.Image,
Src: types.TaskSrcImg,
Type: types.TaskImage,
Prompt: prompt,
UserId: userId,
})
@@ -401,10 +401,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID)
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := service.TaskSrc(data.Src)
if src == service.TaskSrcImg {
src := types.TaskSrc(data.Src)
if src == types.TaskSrcImg {
job := model.MidJourneyJob{
Type: service.Upscale.String(),
Type: types.TaskUpscale.String(),
UserId: userId,
Hash: data.MessageHash,
Progress: 0,
@@ -428,11 +428,11 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
}
}
}
h.mjService.PushTask(service.MjTask{
h.mjService.PushTask(types.MjTask{
Id: jobId,
SessionId: data.SessionId,
Src: src,
Type: service.Upscale,
Type: types.TaskUpscale,
Prompt: data.Prompt,
UserId: userId,
RoleId: data.RoleId,
@@ -470,10 +470,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
idValue, _ := c.Get(types.LoginUserID)
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
src := service.TaskSrc(data.Src)
if src == service.TaskSrcImg {
src := types.TaskSrc(data.Src)
if src == types.TaskSrcImg {
job := model.MidJourneyJob{
Type: service.Variation.String(),
Type: types.TaskVariation.String(),
UserId: userId,
ImgURL: "",
Hash: data.MessageHash,
@@ -498,11 +498,11 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
}
}
}
h.mjService.PushTask(service.MjTask{
h.mjService.PushTask(types.MjTask{
Id: jobId,
SessionId: data.SessionId,
Src: src,
Type: service.Variation,
Type: types.TaskVariation,
Prompt: data.Prompt,
UserId: userId,
RoleId: data.RoleId,