mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 02:03:42 +08:00
feat: add implements for stable diffusion service
This commit is contained in:
@@ -66,7 +66,7 @@ func NewMidJourneyHandler(
|
||||
return &h
|
||||
}
|
||||
|
||||
type notifyData struct {
|
||||
type mjNotifyData struct {
|
||||
MessageId string `json:"message_id"`
|
||||
ReferenceId string `json:"reference_id"`
|
||||
Image Image `json:"image"`
|
||||
@@ -98,7 +98,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||
resp.NotAuth(c)
|
||||
return
|
||||
}
|
||||
var data notifyData
|
||||
var data mjNotifyData
|
||||
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
@@ -122,14 +122,14 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
||||
|
||||
}
|
||||
|
||||
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (error, bool) {
|
||||
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data mjNotifyData) (error, bool) {
|
||||
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
|
||||
if err != nil { // 过期任务,丢弃
|
||||
logger.Warn("任务已过期:", err)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
var task service.MjTask
|
||||
var task types.MjTask
|
||||
err = utils.JsonDecode(taskString, &task)
|
||||
if err != nil { // 非标准任务,丢弃
|
||||
logger.Warn("任务解析失败:", err)
|
||||
@@ -143,7 +143,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if task.Src == service.TaskSrcImg { // 绘画任务
|
||||
if task.Src == types.TaskSrcImg { // 绘画任务
|
||||
var job model.MidJourneyJob
|
||||
res := h.db.Where("id = ?", task.Id).First(&job)
|
||||
if res.Error != nil {
|
||||
@@ -191,7 +191,7 @@ func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data notifyData) (erro
|
||||
}
|
||||
}
|
||||
|
||||
} else if task.Src == service.TaskSrcChat { // 聊天任务
|
||||
} else if task.Src == types.TaskSrcChat { // 聊天任务
|
||||
wsClient := h.App.MjTaskClients.Get(task.SessionId)
|
||||
if data.Status == Finished {
|
||||
if wsClient != nil && data.ReferenceId != "" {
|
||||
@@ -342,7 +342,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
job := model.MidJourneyJob{
|
||||
Type: service.Image.String(),
|
||||
Type: types.TaskImage.String(),
|
||||
UserId: userId,
|
||||
Progress: 0,
|
||||
Prompt: prompt,
|
||||
@@ -353,11 +353,11 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.mjService.PushTask(service.MjTask{
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: int(job.Id),
|
||||
SessionId: data.SessionId,
|
||||
Src: service.TaskSrcImg,
|
||||
Type: service.Image,
|
||||
Src: types.TaskSrcImg,
|
||||
Type: types.TaskImage,
|
||||
Prompt: prompt,
|
||||
UserId: userId,
|
||||
})
|
||||
@@ -401,10 +401,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
jobId := 0
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
src := service.TaskSrc(data.Src)
|
||||
if src == service.TaskSrcImg {
|
||||
src := types.TaskSrc(data.Src)
|
||||
if src == types.TaskSrcImg {
|
||||
job := model.MidJourneyJob{
|
||||
Type: service.Upscale.String(),
|
||||
Type: types.TaskUpscale.String(),
|
||||
UserId: userId,
|
||||
Hash: data.MessageHash,
|
||||
Progress: 0,
|
||||
@@ -428,11 +428,11 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mjService.PushTask(service.MjTask{
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: jobId,
|
||||
SessionId: data.SessionId,
|
||||
Src: src,
|
||||
Type: service.Upscale,
|
||||
Type: types.TaskUpscale,
|
||||
Prompt: data.Prompt,
|
||||
UserId: userId,
|
||||
RoleId: data.RoleId,
|
||||
@@ -470,10 +470,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
jobId := 0
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
src := service.TaskSrc(data.Src)
|
||||
if src == service.TaskSrcImg {
|
||||
src := types.TaskSrc(data.Src)
|
||||
if src == types.TaskSrcImg {
|
||||
job := model.MidJourneyJob{
|
||||
Type: service.Variation.String(),
|
||||
Type: types.TaskVariation.String(),
|
||||
UserId: userId,
|
||||
ImgURL: "",
|
||||
Hash: data.MessageHash,
|
||||
@@ -498,11 +498,11 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mjService.PushTask(service.MjTask{
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
Id: jobId,
|
||||
SessionId: data.SessionId,
|
||||
Src: src,
|
||||
Type: service.Variation,
|
||||
Type: types.TaskVariation,
|
||||
Prompt: data.Prompt,
|
||||
UserId: userId,
|
||||
RoleId: data.RoleId,
|
||||
|
||||
Reference in New Issue
Block a user