mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 18:53:43 +08:00
refactor midjourney service, use api key in database
This commit is contained in:
@@ -30,15 +30,15 @@ import (
|
||||
|
||||
type MidJourneyHandler struct {
|
||||
BaseHandler
|
||||
pool *mj.ServicePool
|
||||
service *mj.Service
|
||||
snowflake *service.Snowflake
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
|
||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager) *MidJourneyHandler {
|
||||
return &MidJourneyHandler{
|
||||
snowflake: snowflake,
|
||||
pool: pool,
|
||||
service: service,
|
||||
uploader: manager,
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
@@ -59,11 +59,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if !h.pool.HasAvailableService() {
|
||||
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
@@ -85,7 +80,7 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.pool.Clients.Put(uint(userId), client)
|
||||
h.service.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
}
|
||||
|
||||
@@ -201,7 +196,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.pool.PushTask(types.MjTask{
|
||||
h.service.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
TaskId: taskId,
|
||||
Type: types.TaskType(data.TaskType),
|
||||
@@ -210,9 +205,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
ImgArr: data.ImgArr,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
client := h.service.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
@@ -273,7 +269,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.pool.PushTask(types.MjTask{
|
||||
h.service.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
Type: types.TaskUpscale,
|
||||
UserId: userId,
|
||||
@@ -281,9 +277,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
Index: data.Index,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
client := h.service.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
@@ -337,7 +334,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.pool.PushTask(types.MjTask{
|
||||
h.service.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
Type: types.TaskVariation,
|
||||
UserId: userId,
|
||||
@@ -345,9 +342,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
ChannelId: data.ChannelId,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
client := h.service.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
@@ -500,7 +498,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
client := h.service.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user