mj websocket refactor is ready

This commit is contained in:
RockYang
2024-09-29 07:51:08 +08:00
parent 8fffa60569
commit 00a8bc6784
7 changed files with 63 additions and 149 deletions

View File

@@ -8,7 +8,6 @@ package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"fmt"
"geekai/core"
"geekai/core/types"
@@ -67,6 +66,7 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
func (h *MidJourneyHandler) Image(c *gin.Context) {
var data struct {
TaskType string `json:"task_type"`
ClientId string `json:"client_id"`
Prompt string `json:"prompt"`
NegPrompt string `json:"neg_prompt"`
Rate string `json:"rate"`
@@ -177,6 +177,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
h.mjService.PushTask(types.MjTask{
Id: job.Id,
ClientId: data.ClientId,
TaskId: taskId,
Type: types.TaskType(data.TaskType),
Prompt: data.Prompt,
@@ -187,11 +188,6 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
Mode: h.App.SysConfig.MjMode,
})
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
err = h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
@@ -208,6 +204,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
type reqVo struct {
Index int `json:"index"`
ClientId string `json:"client_id"`
ChannelId string `json:"channel_id"`
MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"`
@@ -244,6 +241,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
h.mjService.PushTask(types.MjTask{
Id: job.Id,
ClientId: data.ClientId,
Type: types.TaskUpscale,
UserId: userId,
ChannelId: data.ChannelId,
@@ -253,11 +251,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
Mode: h.App.SysConfig.MjMode,
})
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
// update user's power
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
@@ -305,6 +298,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
h.mjService.PushTask(types.MjTask{
Id: job.Id,
Type: types.TaskVariation,
ClientId: data.ClientId,
UserId: userId,
Index: data.Index,
ChannelId: data.ChannelId,
@@ -313,11 +307,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
Mode: h.App.SysConfig.MjMode,
})
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
err := h.userService.DecreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerConsume,
Model: "mid-journey",
@@ -397,14 +386,6 @@ func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize
if err != nil {
continue
}
if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
jobs = append(jobs, job)
}
return nil, vo.NewPage(total, page, pageSize, jobs)
@@ -449,11 +430,6 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
logger.Error("remove image failed: ", err)
}
client := h.mjService.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}