mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 02:03:42 +08:00
refactor midjourney service, use api key in database
This commit is contained in:
@@ -8,6 +8,7 @@ package admin
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
@@ -45,6 +46,12 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
|
||||
role.Id = data.Id
|
||||
if data.CreatedAt > 0 {
|
||||
role.CreatedAt = time.Unix(data.CreatedAt, 0)
|
||||
} else {
|
||||
err = h.DB.Where("marker", data.Key).First(&role).Error
|
||||
if err == nil {
|
||||
resp.ERROR(c, fmt.Sprintf("角色 %s 已存在", data.Key))
|
||||
return
|
||||
}
|
||||
}
|
||||
err = h.DB.Save(&role).Error
|
||||
if err != nil {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"geekai/core/types"
|
||||
"geekai/handler"
|
||||
"geekai/service"
|
||||
"geekai/service/mj"
|
||||
"geekai/service/sd"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
@@ -28,15 +27,13 @@ type ConfigHandler struct {
|
||||
handler.BaseHandler
|
||||
levelDB *store.LevelDB
|
||||
licenseService *service.LicenseService
|
||||
mjServicePool *mj.ServicePool
|
||||
sdServicePool *sd.ServicePool
|
||||
}
|
||||
|
||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler {
|
||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, sdPool *sd.ServicePool) *ConfigHandler {
|
||||
return &ConfigHandler{
|
||||
BaseHandler: handler.BaseHandler{App: app, DB: db},
|
||||
levelDB: levelDB,
|
||||
mjServicePool: mjPool,
|
||||
sdServicePool: sdPool,
|
||||
licenseService: licenseService,
|
||||
}
|
||||
@@ -146,58 +143,3 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
|
||||
license := h.licenseService.GetLicense()
|
||||
resp.SUCCESS(c, license)
|
||||
}
|
||||
|
||||
// GetAppConfig 获取内置配置
|
||||
func (h *ConfigHandler) GetAppConfig(c *gin.Context) {
|
||||
resp.SUCCESS(c, gin.H{
|
||||
"mj_plus": h.App.Config.MjPlusConfigs,
|
||||
"mj_proxy": h.App.Config.MjProxyConfigs,
|
||||
"sd": h.App.Config.SdConfigs,
|
||||
})
|
||||
}
|
||||
|
||||
// SaveDrawingConfig 保存AI绘画配置
|
||||
func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) {
|
||||
var data struct {
|
||||
Sd []types.StableDiffusionConfig `json:"sd"`
|
||||
MjPlus []types.MjPlusConfig `json:"mj_plus"`
|
||||
MjProxy []types.MjProxyConfig `json:"mj_proxy"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
changed := false
|
||||
if configChanged(data.Sd, h.App.Config.SdConfigs) {
|
||||
logger.Debugf("SD 配置变动了")
|
||||
h.App.Config.SdConfigs = data.Sd
|
||||
h.sdServicePool.InitServices(data.Sd)
|
||||
changed = true
|
||||
}
|
||||
|
||||
if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) {
|
||||
logger.Debugf("MidJourney 配置变动了")
|
||||
h.App.Config.MjPlusConfigs = data.MjPlus
|
||||
h.App.Config.MjProxyConfigs = data.MjProxy
|
||||
h.mjServicePool.InitServices(data.MjPlus, data.MjProxy)
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
err := core.SaveConfig(h.App.Config)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "更新配置文档失败!")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
|
||||
}
|
||||
|
||||
func configChanged(c1 interface{}, c2 interface{}) bool {
|
||||
encode1 := utils.JsonEncode(c1)
|
||||
encode2 := utils.JsonEncode(c2)
|
||||
return utils.Md5(encode1) != utils.Md5(encode2)
|
||||
}
|
||||
|
||||
@@ -30,15 +30,15 @@ import (
|
||||
|
||||
type MidJourneyHandler struct {
|
||||
BaseHandler
|
||||
pool *mj.ServicePool
|
||||
service *mj.Service
|
||||
snowflake *service.Snowflake
|
||||
uploader *oss.UploaderManager
|
||||
}
|
||||
|
||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
|
||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager) *MidJourneyHandler {
|
||||
return &MidJourneyHandler{
|
||||
snowflake: snowflake,
|
||||
pool: pool,
|
||||
service: service,
|
||||
uploader: manager,
|
||||
BaseHandler: BaseHandler{
|
||||
App: app,
|
||||
@@ -59,11 +59,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if !h.pool.HasAvailableService() {
|
||||
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
@@ -85,7 +80,7 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
|
||||
}
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
h.pool.Clients.Put(uint(userId), client)
|
||||
h.service.Clients.Put(uint(userId), client)
|
||||
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||
}
|
||||
|
||||
@@ -201,7 +196,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.pool.PushTask(types.MjTask{
|
||||
h.service.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
TaskId: taskId,
|
||||
Type: types.TaskType(data.TaskType),
|
||||
@@ -210,9 +205,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
Params: params,
|
||||
UserId: userId,
|
||||
ImgArr: data.ImgArr,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
client := h.service.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
@@ -273,7 +269,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.pool.PushTask(types.MjTask{
|
||||
h.service.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
Type: types.TaskUpscale,
|
||||
UserId: userId,
|
||||
@@ -281,9 +277,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
Index: data.Index,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
client := h.service.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
@@ -337,7 +334,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.pool.PushTask(types.MjTask{
|
||||
h.service.PushTask(types.MjTask{
|
||||
Id: job.Id,
|
||||
Type: types.TaskVariation,
|
||||
UserId: userId,
|
||||
@@ -345,9 +342,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
ChannelId: data.ChannelId,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
Mode: h.App.SysConfig.MjMode,
|
||||
})
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
client := h.service.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
@@ -500,7 +498,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
|
||||
logger.Error("remove image failed: ", err)
|
||||
}
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
client := h.service.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte("Task Updated"))
|
||||
}
|
||||
|
||||
@@ -330,7 +330,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
|
||||
|
||||
client := h.pool.Clients.Get(uint(job.UserId))
|
||||
if client != nil {
|
||||
_ = client.Send([]byte(sd.Finished))
|
||||
_ = client.Send([]byte(service.TaskStatusFinished))
|
||||
}
|
||||
|
||||
resp.SUCCESS(c)
|
||||
|
||||
Reference in New Issue
Block a user