refactor midjourney service, use api key in database

This commit is contained in:
RockYang
2024-08-06 18:30:57 +08:00
parent 72b1515b68
commit 6a8b4ee2f1
29 changed files with 585 additions and 1203 deletions

View File

@@ -30,15 +30,15 @@ import (
type MidJourneyHandler struct {
BaseHandler
pool *mj.ServicePool
service *mj.Service
snowflake *service.Snowflake
uploader *oss.UploaderManager
}
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager) *MidJourneyHandler {
return &MidJourneyHandler{
snowflake: snowflake,
pool: pool,
service: service,
uploader: manager,
BaseHandler: BaseHandler{
App: app,
@@ -59,11 +59,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
return false
}
if !h.pool.HasAvailableService() {
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
return false
}
return true
}
@@ -85,7 +80,7 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
}
client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client)
h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
@@ -201,7 +196,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return
}
h.pool.PushTask(types.MjTask{
h.service.PushTask(types.MjTask{
Id: job.Id,
TaskId: taskId,
Type: types.TaskType(data.TaskType),
@@ -210,9 +205,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
Params: params,
UserId: userId,
ImgArr: data.ImgArr,
Mode: h.App.SysConfig.MjMode,
})
client := h.pool.Clients.Get(uint(job.UserId))
client := h.service.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
@@ -273,7 +269,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
return
}
h.pool.PushTask(types.MjTask{
h.service.PushTask(types.MjTask{
Id: job.Id,
Type: types.TaskUpscale,
UserId: userId,
@@ -281,9 +277,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
Index: data.Index,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
})
client := h.pool.Clients.Get(uint(job.UserId))
client := h.service.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
@@ -337,7 +334,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
return
}
h.pool.PushTask(types.MjTask{
h.service.PushTask(types.MjTask{
Id: job.Id,
Type: types.TaskVariation,
UserId: userId,
@@ -345,9 +342,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
ChannelId: data.ChannelId,
MessageId: data.MessageId,
MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
})
client := h.pool.Clients.Get(uint(job.UserId))
client := h.service.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
@@ -500,7 +498,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
logger.Error("remove image failed: ", err)
}
client := h.pool.Clients.Get(uint(job.UserId))
client := h.service.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}