refactor midjourney service, use api key in database

This commit is contained in:
RockYang 2024-08-06 18:30:57 +08:00
parent 72b1515b68
commit 6a8b4ee2f1
29 changed files with 585 additions and 1203 deletions

View File

@ -5,6 +5,7 @@
* 功能优化优化MJ,SD,DALL-E 任务列表页面,显示失败任务的错误信息,删除失败任务可以恢复扣减算力 * 功能优化优化MJ,SD,DALL-E 任务列表页面,显示失败任务的错误信息,删除失败任务可以恢复扣减算力
* Bug修复修复后台拖动排序组件 Bug * Bug修复修复后台拖动排序组件 Bug
* 功能优化:更新数据库失败时候显示具体的的报错信息 * 功能优化:更新数据库失败时候显示具体的的报错信息
* Bug修复修复管理后台对话详情页内容显示异常问题
## v4.1.1 ## v4.1.1
* Bug修复修复 GPT 模型 function call 调用后没有输出的问题 * Bug修复修复 GPT 模型 function call 调用后没有输出的问题

View File

@ -65,17 +65,6 @@ TikaHost = "http://tika:9998"
SubDir = "" SubDir = ""
Domain = "" Domain = ""
[[MjProxyConfigs]]
Enabled = true
ApiURL = "http://midjourney-proxy:8082"
ApiKey = "sk-geekmaster"
[[MjPlusConfigs]]
Enabled = false
ApiURL = "https://api.chat-plus.net"
Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
ApiKey = "sk-xxx"
[[SdConfigs]] [[SdConfigs]]
Enabled = false Enabled = false
ApiURL = "" ApiURL = ""

View File

@ -24,8 +24,6 @@ type AppConfig struct {
ApiConfig ApiConfig // ChatPlus API authorization configs ApiConfig ApiConfig // ChatPlus API authorization configs
SMS SMSConfig // send mobile message config SMS SMSConfig // send mobile message config
OSS OSSConfig // OSS config OSS OSSConfig // OSS config
MjProxyConfigs []MjProxyConfig // MJ proxy config
MjPlusConfigs []MjPlusConfig // MJ plus config
WeChatBot bool // 是否启用微信机器人 WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool SdConfigs []StableDiffusionConfig // sd AI draw service pool
@ -188,6 +186,7 @@ type SystemConfig struct {
ContextDeep int `json:"context_deep,omitempty"` ContextDeep int `json:"context_deep,omitempty"`
SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词 SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
MjMode string `json:"mj_mode"` // midjourney 默认的API模式relax, fast, turbo
IndexBgURL string `json:"index_bg_url"` // 前端首页背景图片 IndexBgURL string `json:"index_bg_url"` // 前端首页背景图片
IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单 IndexNavs []int `json:"index_navs"` // 首页显示的导航菜单

View File

@ -27,7 +27,6 @@ type MjTask struct {
Id uint `json:"id"` Id uint `json:"id"`
TaskId string `json:"task_id"` TaskId string `json:"task_id"`
ImgArr []string `json:"img_arr"` ImgArr []string `json:"img_arr"`
ChannelId string `json:"channel_id"`
Type TaskType `json:"type"` Type TaskType `json:"type"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
@ -37,6 +36,8 @@ type MjTask struct {
MessageId string `json:"message_id,omitempty"` MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"` MessageHash string `json:"message_hash,omitempty"`
RetryCount int `json:"retry_count"` RetryCount int `json:"retry_count"`
ChannelId string `json:"channel_id"` // 渠道ID用来区分是哪个渠道创建的任务一个任务的 create 和 action 操作必须要再同一个渠道
Mode string `json:"mode"` // 绘画模式relax, fast, turbo
} }
type SdTask struct { type SdTask struct {

View File

@ -8,6 +8,7 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import ( import (
"fmt"
"geekai/core" "geekai/core"
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
@ -45,6 +46,12 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
role.Id = data.Id role.Id = data.Id
if data.CreatedAt > 0 { if data.CreatedAt > 0 {
role.CreatedAt = time.Unix(data.CreatedAt, 0) role.CreatedAt = time.Unix(data.CreatedAt, 0)
} else {
err = h.DB.Where("marker", data.Key).First(&role).Error
if err == nil {
resp.ERROR(c, fmt.Sprintf("角色 %s 已存在", data.Key))
return
}
} }
err = h.DB.Save(&role).Error err = h.DB.Save(&role).Error
if err != nil { if err != nil {

View File

@ -12,7 +12,6 @@ import (
"geekai/core/types" "geekai/core/types"
"geekai/handler" "geekai/handler"
"geekai/service" "geekai/service"
"geekai/service/mj"
"geekai/service/sd" "geekai/service/sd"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
@ -28,15 +27,13 @@ type ConfigHandler struct {
handler.BaseHandler handler.BaseHandler
levelDB *store.LevelDB levelDB *store.LevelDB
licenseService *service.LicenseService licenseService *service.LicenseService
mjServicePool *mj.ServicePool
sdServicePool *sd.ServicePool sdServicePool *sd.ServicePool
} }
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler { func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, sdPool *sd.ServicePool) *ConfigHandler {
return &ConfigHandler{ return &ConfigHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db}, BaseHandler: handler.BaseHandler{App: app, DB: db},
levelDB: levelDB, levelDB: levelDB,
mjServicePool: mjPool,
sdServicePool: sdPool, sdServicePool: sdPool,
licenseService: licenseService, licenseService: licenseService,
} }
@ -146,58 +143,3 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
license := h.licenseService.GetLicense() license := h.licenseService.GetLicense()
resp.SUCCESS(c, license) resp.SUCCESS(c, license)
} }
// GetAppConfig 获取内置配置
func (h *ConfigHandler) GetAppConfig(c *gin.Context) {
resp.SUCCESS(c, gin.H{
"mj_plus": h.App.Config.MjPlusConfigs,
"mj_proxy": h.App.Config.MjProxyConfigs,
"sd": h.App.Config.SdConfigs,
})
}
// SaveDrawingConfig 保存AI绘画配置
func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) {
var data struct {
Sd []types.StableDiffusionConfig `json:"sd"`
MjPlus []types.MjPlusConfig `json:"mj_plus"`
MjProxy []types.MjProxyConfig `json:"mj_proxy"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
changed := false
if configChanged(data.Sd, h.App.Config.SdConfigs) {
logger.Debugf("SD 配置变动了")
h.App.Config.SdConfigs = data.Sd
h.sdServicePool.InitServices(data.Sd)
changed = true
}
if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) {
logger.Debugf("MidJourney 配置变动了")
h.App.Config.MjPlusConfigs = data.MjPlus
h.App.Config.MjProxyConfigs = data.MjProxy
h.mjServicePool.InitServices(data.MjPlus, data.MjProxy)
changed = true
}
if changed {
err := core.SaveConfig(h.App.Config)
if err != nil {
resp.ERROR(c, "更新配置文档失败!")
return
}
}
resp.SUCCESS(c)
}
func configChanged(c1 interface{}, c2 interface{}) bool {
encode1 := utils.JsonEncode(c1)
encode2 := utils.JsonEncode(c2)
return utils.Md5(encode1) != utils.Md5(encode2)
}

View File

@ -30,15 +30,15 @@ import (
type MidJourneyHandler struct { type MidJourneyHandler struct {
BaseHandler BaseHandler
pool *mj.ServicePool service *mj.Service
snowflake *service.Snowflake snowflake *service.Snowflake
uploader *oss.UploaderManager uploader *oss.UploaderManager
} }
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler { func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, service *mj.Service, manager *oss.UploaderManager) *MidJourneyHandler {
return &MidJourneyHandler{ return &MidJourneyHandler{
snowflake: snowflake, snowflake: snowflake,
pool: pool, service: service,
uploader: manager, uploader: manager,
BaseHandler: BaseHandler{ BaseHandler: BaseHandler{
App: app, App: app,
@ -59,11 +59,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
return false return false
} }
if !h.pool.HasAvailableService() {
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
return false
}
return true return true
} }
@ -85,7 +80,7 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
} }
client := types.NewWsClient(ws) client := types.NewWsClient(ws)
h.pool.Clients.Put(uint(userId), client) h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
} }
@ -201,7 +196,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
return return
} }
h.pool.PushTask(types.MjTask{ h.service.PushTask(types.MjTask{
Id: job.Id, Id: job.Id,
TaskId: taskId, TaskId: taskId,
Type: types.TaskType(data.TaskType), Type: types.TaskType(data.TaskType),
@ -210,9 +205,10 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
Params: params, Params: params,
UserId: userId, UserId: userId,
ImgArr: data.ImgArr, ImgArr: data.ImgArr,
Mode: h.App.SysConfig.MjMode,
}) })
client := h.pool.Clients.Get(uint(job.UserId)) client := h.service.Clients.Get(uint(job.UserId))
if client != nil { if client != nil {
_ = client.Send([]byte("Task Updated")) _ = client.Send([]byte("Task Updated"))
} }
@ -273,7 +269,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
return return
} }
h.pool.PushTask(types.MjTask{ h.service.PushTask(types.MjTask{
Id: job.Id, Id: job.Id,
Type: types.TaskUpscale, Type: types.TaskUpscale,
UserId: userId, UserId: userId,
@ -281,9 +277,10 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
Index: data.Index, Index: data.Index,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
}) })
client := h.pool.Clients.Get(uint(job.UserId)) client := h.service.Clients.Get(uint(job.UserId))
if client != nil { if client != nil {
_ = client.Send([]byte("Task Updated")) _ = client.Send([]byte("Task Updated"))
} }
@ -337,7 +334,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
return return
} }
h.pool.PushTask(types.MjTask{ h.service.PushTask(types.MjTask{
Id: job.Id, Id: job.Id,
Type: types.TaskVariation, Type: types.TaskVariation,
UserId: userId, UserId: userId,
@ -345,9 +342,10 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
ChannelId: data.ChannelId, ChannelId: data.ChannelId,
MessageId: data.MessageId, MessageId: data.MessageId,
MessageHash: data.MessageHash, MessageHash: data.MessageHash,
Mode: h.App.SysConfig.MjMode,
}) })
client := h.pool.Clients.Get(uint(job.UserId)) client := h.service.Clients.Get(uint(job.UserId))
if client != nil { if client != nil {
_ = client.Send([]byte("Task Updated")) _ = client.Send([]byte("Task Updated"))
} }
@ -500,7 +498,7 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
logger.Error("remove image failed: ", err) logger.Error("remove image failed: ", err)
} }
client := h.pool.Clients.Get(uint(job.UserId)) client := h.service.Clients.Get(uint(job.UserId))
if client != nil { if client != nil {
_ = client.Send([]byte("Task Updated")) _ = client.Send([]byte("Task Updated"))
} }

View File

@ -330,7 +330,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
client := h.pool.Clients.Get(uint(job.UserId)) client := h.pool.Clients.Get(uint(job.UserId))
if client != nil { if client != nil {
_ = client.Send([]byte(sd.Finished)) _ = client.Send([]byte(service.TaskStatusFinished))
} }
resp.SUCCESS(c) resp.SUCCESS(c)

View File

@ -161,13 +161,12 @@ func main() {
return service.NewCaptchaService(config.ApiConfig) return service.NewCaptchaService(config.ApiConfig)
}), }),
fx.Provide(oss.NewUploaderManager), fx.Provide(oss.NewUploaderManager),
fx.Provide(mj.NewService),
fx.Provide(dalle.NewService), fx.Provide(dalle.NewService),
fx.Invoke(func(service *dalle.Service) { fx.Invoke(func(s *dalle.Service) {
service.Run() s.Run()
service.CheckTaskNotify() s.CheckTaskNotify()
service.DownloadImages() s.DownloadImages()
service.CheckTaskStatus() s.CheckTaskStatus()
}), }),
// 邮件服务 // 邮件服务
@ -190,14 +189,13 @@ func main() {
}), }),
// MidJourney service pool // MidJourney service pool
fx.Provide(mj.NewServicePool), fx.Provide(mj.NewService),
fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) { fx.Provide(mj.NewClient),
pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs) fx.Invoke(func(s *mj.Service) {
if pool.HasAvailableService() { s.Run()
pool.DownloadImages() s.SyncTaskProgress()
pool.CheckTaskNotify() s.CheckTaskNotify()
pool.SyncTaskProgress() s.DownloadImages()
}
}), }),
// Stable Diffusion 机器人 // Stable Diffusion 机器人
@ -317,8 +315,6 @@ func main() {
group.GET("config/get", h.Get) group.GET("config/get", h.Get)
group.POST("active", h.Active) group.POST("active", h.Active)
group.GET("config/get/license", h.GetLicense) group.GET("config/get/license", h.GetLicense)
group.GET("config/get/app", h.GetAppConfig)
group.POST("config/update/draw", h.SaveDrawingConfig)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) { fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
group := s.Engine.Group("/api/admin/") group := s.Engine.Group("/api/admin/")

View File

@ -14,7 +14,6 @@ import (
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/service" "geekai/service"
"geekai/service/oss" "geekai/service/oss"
"geekai/service/sd"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
@ -70,10 +69,10 @@ func (s *Service) Run() {
if err != nil { if err != nil {
logger.Errorf("error with image task: %v", err) logger.Errorf("error with image task: %v", err)
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": 101, "progress": service.FailTaskProgress,
"err_msg": err.Error(), "err_msg": err.Error(),
}) })
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed}) s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
} }
} }
}() }()
@ -191,7 +190,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
return "", fmt.Errorf("err with update database: %v", tx.Error) return "", fmt.Errorf("err with update database: %v", tx.Error)
} }
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished}) s.notifyQueue.RPush(service.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: service.TaskStatusFailed})
var content string var content string
if sync { if sync {
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url) imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
@ -208,7 +207,7 @@ func (s *Service) CheckTaskNotify() {
go func() { go func() {
logger.Info("Running DALL-E task notify checking ...") logger.Info("Running DALL-E task notify checking ...")
for { for {
var message sd.NotifyMessage var message service.NotifyMessage
err := s.notifyQueue.LPop(&message) err := s.notifyQueue.LPop(&message)
if err != nil { if err != nil {
continue continue
@ -239,7 +238,7 @@ func (s *Service) CheckTaskStatus() {
for _, job := range jobs { for _, job := range jobs {
// 超时的任务标记为失败 // 超时的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*10 { if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
job.Progress = 101 job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时" job.ErrMsg = "任务超时"
s.db.Updates(&job) s.db.Updates(&job)
} }
@ -292,6 +291,6 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
if res.Error != nil { if res.Error != nil {
return "", err return "", err
} }
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished}) s.notifyQueue.RPush(service.NotifyMessage{UserId: userId, JobId: int(jobId), Message: service.TaskStatusFinished})
return imgURL, nil return imgURL, nil
} }

View File

@ -7,15 +7,28 @@ package mj
// * @Author yangjian102621@163.com // * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import "geekai/core/types" import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"io"
"time"
type Client interface { "github.com/gin-gonic/gin"
Imagine(task types.MjTask) (ImageRes, error) )
Blend(task types.MjTask) (ImageRes, error)
SwapFace(task types.MjTask) (ImageRes, error) // Client MidJourney client
Upscale(task types.MjTask) (ImageRes, error) type Client struct {
Variation(task types.MjTask) (ImageRes, error) client *req.Client
QueryTask(taskId string) (QueryRes, error) licenseService *service.LicenseService
db *gorm.DB
} }
type ImageReq struct { type ImageReq struct {
@ -34,6 +47,7 @@ type ImageRes struct {
Properties struct { Properties struct {
} `json:"properties"` } `json:"properties"`
Result string `json:"result"` Result string `json:"result"`
Channel string `json:"channel,omitempty"`
} }
type ErrRes struct { type ErrRes struct {
@ -66,3 +80,184 @@ type QueryRes struct {
Status string `json:"status"` Status string `json:"status"`
SubmitTime int `json:"submitTime"` SubmitTime int `json:"submitTime"`
} }
var logger = logger2.GetLogger()
func NewClient(licenseService *service.LicenseService, db *gorm.DB) *Client {
return &Client{
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
licenseService: licenseService,
db: db,
}
}
func (c *Client) Imagine(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/submit/imagine", task.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// Blend 融图
func (c *Client) Blend(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/submit/blend", task.Mode)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// SwapFace 换脸
func (c *Client) SwapFace(task types.MjTask) (ImageRes, error) {
apiPath := fmt.Sprintf("mj-%s/mj/insight-face/swap", task.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"state": "",
}
return c.doRequest(body, apiPath, task.ChannelId)
}
// Upscale 放大指定的图片
func (c *Client) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
return c.doRequest(body, apiPath, task.ChannelId)
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiPath := fmt.Sprintf("mj-%s/mj/submit/action", task.Mode)
return c.doRequest(body, apiPath, task.ChannelId)
}
func (c *Client) doRequest(body interface{}, apiPath string, channel string) (ImageRes, error) {
var res ImageRes
var errRes ErrRes
session := c.db.Session(&gorm.Session{}).Where("type", "mj").Where("enabled", true)
if channel != "" {
session = session.Where("api_url", channel)
}
var apiKey model.ApiKey
err := session.Order("last_used_at ASC").First(&apiKey).Error
if err != nil {
return ImageRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
}
if err = c.licenseService.IsValidApiURL(apiKey.ApiURL); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/%s", apiKey.ApiURL, apiPath)
logger.Info("API URL: ", apiURL)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errMsg := err.Error()
if r != nil {
errStr, _ := io.ReadAll(r.Body)
logger.Error("请求 API 出错:", string(errStr))
errMsg = errMsg + " " + string(errStr)
}
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", errMsg)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
// update the api key last used time
if err = c.db.Model(&apiKey).Update("last_used_at", time.Now().Unix()).Error; err != nil {
logger.Error("update api key last used time error: ", err)
}
res.Channel = apiKey.ApiURL
return res, nil
}
func (c *Client) QueryTask(taskId string, channel string) (QueryRes, error) {
var apiKey model.ApiKey
err := c.db.Where("type", "mj").Where("enabled", true).Where("api_url", channel).First(&apiKey).Error
if err != nil {
return QueryRes{}, fmt.Errorf("no available MidJourney api key: %v", err)
}
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", apiKey.ApiURL, taskId)
var res QueryRes
r, err := c.client.R().SetHeader("Authorization", "Bearer "+apiKey.Value).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}

View File

@ -1,211 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/service"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
"time"
"github.com/gin-gonic/gin"
)
// PlusClient MidJourney Plus ProxyClient
type PlusClient struct {
Config types.MjPlusConfig
apiURL string
client *req.Client
licenseService *service.LicenseService
}
func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient {
return &PlusClient{
Config: config,
apiURL: config.ApiURL,
client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
licenseService: licenseService,
}
}
func (c *PlusClient) preCheck() error {
return c.licenseService.IsValidApiURL(c.Config.ApiURL)
}
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
return c.doRequest(body, apiURL)
}
// Blend 融图
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
logger.Info("API URL: ", apiURL)
body := ImageReq{
BotType: "MID_JOURNEY",
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
return c.doRequest(body, apiURL)
}
// SwapFace 换脸
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
// 生成图片 Base64 编码
if len(task.ImgArr) != 2 {
return ImageRes{}, errors.New("参数错误必须上传2张图片")
}
var sourceBase64 string
var targetBase64 string
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
imageData, err = utils.DownloadImage(task.ImgArr[1], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
}
body := gin.H{
"sourceBase64": sourceBase64,
"targetBase64": targetBase64,
"accountFilter": gin.H{
"instanceId": "",
},
"state": "",
}
return c.doRequest(body, apiURL)
}
// Upscale 放大指定的图片
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
return c.doRequest(body, apiURL)
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
if err := c.preCheck(); err != nil {
return ImageRes{}, err
}
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
return c.doRequest(body, apiURL)
}
func (c *PlusClient) doRequest(body interface{}, apiURL string) (ImageRes, error) {
var res ImageRes
var errRes ErrRes
logger.Info("API URL: ", apiURL)
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
errMsg := err.Error()
if r != nil {
errStr, _ := io.ReadAll(r.Body)
logger.Error("请求 API 出错:", string(errStr))
errMsg = errMsg + " " + string(errStr)
}
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", errMsg)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &PlusClient{}

View File

@ -1,215 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss"
"geekai/service/sd"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"github.com/go-redis/redis/v8"
"strings"
"time"
"gorm.io/gorm"
)
// ServicePool Mj service pool
type ServicePool struct {
services []*Service
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
uploaderManager *oss.UploaderManager
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
licenseService *service.LicenseService
}
var logger = logger2.GetLogger()
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
services: services,
uploaderManager: manager,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
licenseService: licenseService,
}
}
func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) {
// stop old service
for _, s := range p.services {
s.Stop()
}
p.services = make([]*Service, 0)
for _, config := range plusConfigs {
if config.Enabled == false {
continue
}
cli := NewPlusClient(config, p.licenseService)
name := utils.Md5(config.ApiURL)
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() {
plusService.Run()
}()
p.services = append(p.services, plusService)
}
// for mid-journey proxy
for _, config := range proxyConfigs {
if config.Enabled == false {
continue
}
cli := NewProxyClient(config)
name := utils.Md5(config.ApiURL)
proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() {
proxyService.Run()
}()
p.services = append(p.services, proxyService)
}
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
var message sd.NotifyMessage
err := p.notifyQueue.LPop(&message)
if err != nil {
continue
}
cli := p.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (p *ServicePool) DownloadImages() {
go func() {
var items []model.MidJourneyJob
for {
res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
mjService := p.getService(v.ChannelId)
if mjService == nil {
logger.Errorf("Invalid task: %+v", v)
continue
}
task, _ := mjService.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = GetImageHash(task.Buttons[0].CustomId)
}
// 如果是返回的是 discord 图片地址,则使用代理下载
proxy := false
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := p.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
p.db.Updates(&v)
cli := p.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(sd.Finished))
if err != nil {
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
// PushTask push a new mj task in to task queue
func (p *ServicePool) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
p.taskQueue.RPush(task)
}
// HasAvailableService check if it has available mj service in pool
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}
// SyncTaskProgress 异步拉取任务
func (p *ServicePool) SyncTaskProgress() {
go func() {
var jobs []model.MidJourneyJob
for {
res := p.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
// 5 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
job.Progress = 101
job.ErrMsg = "任务超时"
p.db.Updates(&job)
continue
}
if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
_ = servicePlus.Notify(job)
}
}
time.Sleep(time.Second * 10)
}
}()
}
func (p *ServicePool) getService(name string) *Service {
for _, s := range p.services {
if s.Name == name {
return s
}
}
return nil
}

View File

@ -1,185 +0,0 @@
package mj
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/base64"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"github.com/imroc/req/v3"
"io"
)
// ProxyClient MidJourney Proxy Client
type ProxyClient struct {
Config types.MjProxyConfig
apiURL string
}
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
return &ProxyClient{Config: config, apiURL: config.ApiURL}
}
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
if task.NegPrompt != "" {
prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
}
body := ImageReq{
Prompt: prompt,
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
imageData, err := utils.DownloadImage(task.ImgArr[0], "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
logger.Info("API URL: ", apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
errStr, _ := io.ReadAll(r.Body)
return ImageRes{}, fmt.Errorf("API 返回错误:%s%v", errRes.Error.Message, string(errStr))
}
return res, nil
}
// Blend 融图
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
body := ImageReq{
Dimensions: "SQUARE",
Base64Array: make([]string, 0),
}
// 生成图片 Base64 编码
if len(task.ImgArr) > 0 {
for _, imgURL := range task.ImgArr {
imageData, err := utils.DownloadImage(imgURL, "")
if err != nil {
logger.Error("error with download image: ", err)
} else {
body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
}
}
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// SwapFace 换脸
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能请使用 MidJourney-Plus")
}
// Upscale 放大指定的图片
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "UPSCALE",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
body := map[string]interface{}{
"action": "VARIATION",
"index": task.Index,
"taskId": task.MessageId,
}
apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("mj-api-secret", c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}
var _ Client = &ProxyClient{}

View File

@ -11,10 +11,11 @@ import (
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
"geekai/service" "geekai/service"
"geekai/service/sd" "geekai/service/oss"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
"github.com/go-redis/redis/v8"
"strings" "strings"
"time" "time"
@ -23,32 +24,29 @@ import (
// Service MJ 绘画服务 // Service MJ 绘画服务
type Service struct { type Service struct {
Name string // service Name client *Client // MJ Client
Client Client // MJ Client
taskQueue *store.RedisQueue taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue notifyQueue *store.RedisQueue
db *gorm.DB db *gorm.DB
running bool Clients *types.LMap[uint, *types.WsClient] // UserId => Client
retryCount map[uint]int uploaderManager *oss.UploaderManager
} }
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service { func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager) *Service {
return &Service{ return &Service{
Name: name,
db: db, db: db,
taskQueue: taskQueue, taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
notifyQueue: notifyQueue, notifyQueue: store.NewRedisQueue("MidJourney_Notify_Queue", redisCli),
Client: cli, client: client,
running: true, Clients: types.NewLMap[uint, *types.WsClient](),
retryCount: make(map[uint]int), uploaderManager: manager,
} }
} }
const failedProgress = 101
func (s *Service) Run() { func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.Name) logger.Info("Starting MidJourney job consumer for service")
for s.running { go func() {
for {
var task types.MjTask var task types.MjTask
err := s.taskQueue.LPop(&task) err := s.taskQueue.LPop(&task)
if err != nil { if err != nil {
@ -56,23 +54,9 @@ func (s *Service) Run() {
continue continue
} }
// 如果配置了多个中转平台的 API KEY
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
if task.ChannelId != "" && task.ChannelId != s.Name {
if s.retryCount[task.Id] > 5 {
s.db.Model(model.MidJourneyJob{Id: task.Id}).Delete(&model.MidJourneyJob{})
continue
}
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
s.retryCount[task.Id]++
time.Sleep(time.Second)
continue
}
// translate prompt // translate prompt
if utils.HasChinese(task.Prompt) { if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), "gpt-4o-mini")
if err == nil { if err == nil {
task.Prompt = content task.Prompt = content
} else { } else {
@ -81,7 +65,7 @@ func (s *Service) Run() {
} }
// translate negative prompt // translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) { if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt), "gpt-4o-mini") content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), "gpt-4o-mini")
if err == nil { if err == nil {
task.NegPrompt = content task.NegPrompt = content
} else { } else {
@ -89,6 +73,11 @@ func (s *Service) Run() {
} }
} }
// use fast mode as default
if task.Mode == "" {
task.Mode = "fast"
}
var job model.MidJourneyJob var job model.MidJourneyJob
tx := s.db.Where("id = ?", task.Id).First(&job) tx := s.db.Where("id = ?", task.Id).First(&job)
if tx.Error != nil { if tx.Error != nil {
@ -96,23 +85,23 @@ func (s *Service) Run() {
continue continue
} }
logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task) logger.Infof("handle a new MidJourney task: %+v", task)
var res ImageRes var res ImageRes
switch task.Type { switch task.Type {
case types.TaskImage: case types.TaskImage:
res, err = s.Client.Imagine(task) res, err = s.client.Imagine(task)
break break
case types.TaskUpscale: case types.TaskUpscale:
res, err = s.Client.Upscale(task) res, err = s.client.Upscale(task)
break break
case types.TaskVariation: case types.TaskVariation:
res, err = s.Client.Variation(task) res, err = s.client.Variation(task)
break break
case types.TaskBlend: case types.TaskBlend:
res, err = s.Client.Blend(task) res, err = s.client.Blend(task)
break break
case types.TaskSwapFace: case types.TaskSwapFace:
res, err = s.Client.SwapFace(task) res, err = s.client.SwapFace(task)
break break
} }
@ -125,25 +114,22 @@ func (s *Service) Run() {
} }
logger.Error("绘画任务执行失败:", errMsg) logger.Error("绘画任务执行失败:", errMsg)
job.Progress = failedProgress job.Progress = service.FailTaskProgress
job.ErrMsg = errMsg job.ErrMsg = errMsg
// update the task progress // update the task progress
s.db.Updates(&job) s.db.Updates(&job)
// 任务失败,通知前端 // 任务失败,通知前端
s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed}) s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
continue continue
} }
logger.Infof("任务提交成功:%+v", res) logger.Infof("任务提交成功:%+v", res)
// 更新任务 ID/频道 // 更新任务 ID/频道
job.TaskId = res.Result job.TaskId = res.Result
job.MessageId = res.Result job.MessageId = res.Result
job.ChannelId = s.Name job.ChannelId = res.Channel
s.db.Updates(&job) s.db.Updates(&job)
} }
} }()
func (s *Service) Stop() {
s.running = false
} }
type CBReq struct { type CBReq struct {
@ -164,20 +150,122 @@ type CBReq struct {
} `json:"properties"` } `json:"properties"`
} }
func (s *Service) Notify(job model.MidJourneyJob) error { func GetImageHash(action string) string {
task, err := s.Client.QueryTask(job.TaskId) split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
}
return split[len(split)-1]
}
func (s *Service) CheckTaskNotify() {
go func() {
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil { if err != nil {
return err continue
}
cli := s.Clients.Get(uint(message.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (s *Service) DownloadImages() {
go func() {
var items []model.MidJourneyJob
for {
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
// 如果是返回的是 discord 图片地址,则使用代理下载
proxy := false
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := s.uploaderManager.GetUploadHandler().PutUrlFile(v.OrgURL, proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue
} else {
logger.Infof("download image %s successfully.", v.OrgURL)
}
v.ImgURL = imgURL
s.db.Updates(&v)
cli := s.Clients.Get(uint(v.UserId))
if cli == nil {
continue
}
err = cli.Send([]byte(service.TaskStatusFinished))
if err != nil {
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
// PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.MjTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
// SyncTaskProgress 异步拉取任务
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.MidJourneyJob
for {
res := s.db.Where("progress < ?", 100).Where("channel_id <> ?", "").Find(&jobs)
if res.Error != nil {
continue
}
for _, job := range jobs {
// 10 分钟还没完成的任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
continue
}
task, err := s.client.QueryTask(job.TaskId, job.ChannelId)
if err != nil {
logger.Errorf("error with query task: %v", err)
continue
} }
// 任务执行失败了 // 任务执行失败了
if task.FailReason != "" { if task.FailReason != "" {
s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
"progress": failedProgress, "progress": service.FailTaskProgress,
"err_msg": task.FailReason, "err_msg": task.FailReason,
}) })
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed}) logger.Errorf("task failed: %v", task.FailReason)
return fmt.Errorf("task failed: %v", task.FailReason) s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
continue
} }
if len(task.Buttons) > 0 { if len(task.Buttons) > 0 {
@ -189,25 +277,23 @@ func (s *Service) Notify(job model.MidJourneyJob) error {
if task.ImageUrl != "" { if task.ImageUrl != "" {
job.OrgURL = task.ImageUrl job.OrgURL = task.ImageUrl
} }
tx := s.db.Updates(&job) err = s.db.Updates(&job).Error
if tx.Error != nil { if err != nil {
return fmt.Errorf("error with update database: %v", tx.Error) logger.Errorf("error with update database: %v", err)
continue
} }
// 通知前端更新任务进度 // 通知前端更新任务进度
if oldProgress != job.Progress { if oldProgress != job.Progress {
message := sd.Running message := service.TaskStatusRunning
if job.Progress == 100 { if job.Progress == 100 {
message = sd.Finished message = service.TaskStatusFinished
}
s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
} }
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
} }
return nil
}
func GetImageHash(action string) string { time.Sleep(time.Second * 5)
split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
} }
return split[len(split)-1] }()
} }

View File

@ -10,6 +10,7 @@ package sd
import ( import (
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
"geekai/service"
"geekai/service/oss" "geekai/service/oss"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
@ -79,7 +80,7 @@ func (p *ServicePool) CheckTaskNotify() {
go func() { go func() {
logger.Info("Running Stable-Diffusion task notify checking ...") logger.Info("Running Stable-Diffusion task notify checking ...")
for { for {
var message NotifyMessage var message service.NotifyMessage
err := p.notifyQueue.LPop(&message) err := p.notifyQueue.LPop(&message)
if err != nil { if err != nil {
continue continue

View File

@ -10,6 +10,7 @@ package sd
import ( import (
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger"
"geekai/service" "geekai/service"
"geekai/service/oss" "geekai/service/oss"
"geekai/store" "geekai/store"
@ -22,6 +23,8 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
var logger = logger2.GetLogger()
// SD 绘画服务 // SD 绘画服务
type Service struct { type Service struct {
@ -87,11 +90,11 @@ func (s *Service) Run() {
logger.Error("绘画任务执行失败:", err.Error()) logger.Error("绘画任务执行失败:", err.Error())
// update the task progress // update the task progress
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{ s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{
"progress": 101, "progress": service.FailTaskProgress,
"err_msg": err.Error(), "err_msg": err.Error(),
}) })
// 通知前端,任务失败 // 通知前端,任务失败
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed}) s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFailed})
continue continue
} }
} }
@ -206,7 +209,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
// task finished // task finished
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100)
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished}) s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusFinished})
// 从 leveldb 中删除预览图片数据 // 从 leveldb 中删除预览图片数据
_ = s.leveldb.Delete(task.Params.TaskId) _ = s.leveldb.Delete(task.Params.TaskId)
return nil return nil
@ -216,7 +219,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
if err == nil && resp.Progress > 0 { if err == nil && resp.Progress > 0 {
s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100))
// 发送更新状态信号 // 发送更新状态信号
s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running}) s.notifyQueue.RPush(service.NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: service.TaskStatusRunning})
// 保存预览图片数据 // 保存预览图片数据
if resp.CurrentImage != "" { if resp.CurrentImage != "" {
_ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage) _ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage)

View File

@ -1,24 +0,0 @@
package sd
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import logger2 "geekai/logger"
var logger = logger2.GetLogger()
type NotifyMessage struct {
UserId int `json:"user_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
const (
Running = "RUNNING"
Finished = "FINISH"
Failed = "FAIL"
)

View File

@ -13,8 +13,8 @@ import (
"fmt" "fmt"
"geekai/core/types" "geekai/core/types"
logger2 "geekai/logger" logger2 "geekai/logger"
"geekai/service"
"geekai/service/oss" "geekai/service/oss"
"geekai/service/sd"
"geekai/store" "geekai/store"
"geekai/store/model" "geekai/store/model"
"geekai/utils" "geekai/utils"
@ -88,7 +88,7 @@ func (s *Service) Run() {
logger.Errorf("create task with error: %v", err) logger.Errorf("create task with error: %v", err)
s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{ s.db.Model(&model.SunoJob{Id: task.Id}).UpdateColumns(map[string]interface{}{
"err_msg": err.Error(), "err_msg": err.Error(),
"progress": 101, "progress": service.FailTaskProgress,
}) })
continue continue
} }
@ -157,6 +157,9 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
if res.Code != "success" { if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message) return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
} }
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
session.Updates(&apiKey)
res.Channel = apiKey.ApiURL res.Channel = apiKey.ApiURL
return res, nil return res, nil
} }
@ -165,7 +168,7 @@ func (s *Service) CheckTaskNotify() {
go func() { go func() {
logger.Info("Running Suno task notify checking ...") logger.Info("Running Suno task notify checking ...")
for { for {
var message sd.NotifyMessage var message service.NotifyMessage
err := s.notifyQueue.LPop(&message) err := s.notifyQueue.LPop(&message)
if err != nil { if err != nil {
continue continue
@ -210,7 +213,7 @@ func (s *Service) DownloadImages() {
v.AudioURL = audioURL v.AudioURL = audioURL
v.Progress = 100 v.Progress = 100
s.db.Updates(&v) s.db.Updates(&v)
s.notifyQueue.RPush(sd.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: sd.Finished}) s.notifyQueue.RPush(service.NotifyMessage{UserId: v.UserId, JobId: int(v.Id), Message: service.TaskStatusFinished})
} }
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)
@ -278,10 +281,10 @@ func (s *Service) SyncTaskProgress() {
tx.Commit() tx.Commit()
} else if task.Data.FailReason != "" { } else if task.Data.FailReason != "" {
job.Progress = 101 job.Progress = service.FailTaskProgress
job.ErrMsg = task.Data.FailReason job.ErrMsg = task.Data.FailReason
s.db.Updates(&job) s.db.Updates(&job)
s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed}) s.notifyQueue.RPush(service.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: service.TaskStatusFailed})
} }
} }

View File

@ -1,4 +1,17 @@
package service package service
const FailTaskProgress = 101
const (
TaskStatusRunning = "RUNNING"
TaskStatusFinished = "FINISH"
TaskStatusFailed = "FAIL"
)
type NotifyMessage struct {
UserId int `json:"user_id"`
JobId int `json:"job_id"`
Message string `json:"message"`
}
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]" const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"

View File

@ -1 +1,2 @@
ALTER TABLE `chatgpt_suno_jobs` MODIFY `id` INT AUTO_INCREMENT; ALTER TABLE `chatgpt_suno_jobs` MODIFY `id` INT AUTO_INCREMENT;
ALTER TABLE `chatgpt_mj_jobs` CHANGE `channel_id` `channel_id` VARCHAR(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL COMMENT '频道ID';

View File

@ -69,18 +69,6 @@ TikaHost = "http://tika:9998"
SubDir = "" SubDir = ""
Domain = "" Domain = ""
[[MjProxyConfigs]]
Enabled = false
ApiURL = "http://geekai-midjourney-proxy:8082"
Mode = ""
ApiKey = "sk-geekmaster"
[[MjPlusConfigs]]
Enabled = false
ApiURL = "https://api.chat-plus.net"
Mode = "fast"
ApiKey = "sk-xxx"
[[SdConfigs]] [[SdConfigs]]
Enabled = false Enabled = false
Model = "" Model = ""

View File

@ -6,7 +6,7 @@
</div> </div>
<div class="chat-item"> <div class="chat-item">
<div class="content" v-html="data.content"></div> <div class="content" v-html="md.render(processContent(data.content))"></div>
<div class="bar" v-if="data.created_at"> <div class="bar" v-if="data.created_at">
<span class="bar-item"><el-icon><Clock/></el-icon> {{ dateFormat(data.created_at) }}</span> <span class="bar-item"><el-icon><Clock/></el-icon> {{ dateFormat(data.created_at) }}</span>
<span class="bar-item">tokens: {{ data.tokens }}</span> <span class="bar-item">tokens: {{ data.tokens }}</span>
@ -17,7 +17,7 @@
content="复制回答" content="复制回答"
placement="bottom" placement="bottom"
> >
<el-icon class="copy-reply" :data-clipboard-text="data.orgContent"> <el-icon class="copy-reply" :data-clipboard-text="data.content">
<DocumentCopy/> <DocumentCopy/>
</el-icon> </el-icon>
</el-tooltip> </el-tooltip>
@ -34,7 +34,7 @@
</el-tooltip> </el-tooltip>
</span> </span>
<span class="bar-item" @click="synthesis(data.orgContent)"> <span class="bar-item" @click="synthesis(data.content)">
<el-tooltip <el-tooltip
class="box-item" class="box-item"
effect="dark" effect="dark"
@ -69,7 +69,7 @@
</div> </div>
<div class="chat-item"> <div class="chat-item">
<div class="content-wrapper"> <div class="content-wrapper">
<div class="content" v-html="data.content"></div> <div class="content" v-html="md.render(processContent(data.content))"></div>
</div> </div>
<div class="bar" v-if="data.created_at"> <div class="bar" v-if="data.created_at">
<span class="bar-item"><el-icon><Clock/></el-icon> {{ dateFormat(data.created_at) }}</span> <span class="bar-item"><el-icon><Clock/></el-icon> {{ dateFormat(data.created_at) }}</span>
@ -81,7 +81,7 @@
content="复制回答" content="复制回答"
placement="bottom" placement="bottom"
> >
<el-icon class="copy-reply" :data-clipboard-text="data.orgContent"> <el-icon class="copy-reply" :data-clipboard-text="data.content">
<DocumentCopy/> <DocumentCopy/>
</el-icon> </el-icon>
</el-tooltip> </el-tooltip>
@ -98,7 +98,7 @@
</el-tooltip> </el-tooltip>
</span> </span>
<span class="bar-item bg" @click="synthesis(data.orgContent)"> <span class="bar-item bg" @click="synthesis(data.content)">
<el-tooltip <el-tooltip
class="box-item" class="box-item"
effect="dark" effect="dark"
@ -118,7 +118,8 @@
<script setup> <script setup>
import {Clock, DocumentCopy, Refresh} from "@element-plus/icons-vue"; import {Clock, DocumentCopy, Refresh} from "@element-plus/icons-vue";
import {ElMessage} from "element-plus"; import {ElMessage} from "element-plus";
import {dateFormat} from "@/utils/libs"; import {dateFormat, processContent} from "@/utils/libs";
import hl from "highlight.js";
// eslint-disable-next-line no-undef,no-unused-vars // eslint-disable-next-line no-undef,no-unused-vars
const props = defineProps({ const props = defineProps({
data: { data: {
@ -128,7 +129,6 @@ const props = defineProps({
content: "", content: "",
created_at: "", created_at: "",
tokens: 0, tokens: 0,
orgContent: ""
}, },
}, },
readOnly: { readOnly: {
@ -141,6 +141,33 @@ const props = defineProps({
}, },
}) })
const mathjaxPlugin = require('markdown-it-mathjax3')
const md = require('markdown-it')({
breaks: true,
html: true,
linkify: true,
typographer: true,
highlight: function (str, lang) {
const codeIndex = parseInt(Date.now()) + Math.floor(Math.random() * 10000000)
//
const copyBtn = `<span class="copy-code-btn" data-clipboard-action="copy" data-clipboard-target="#copy-target-${codeIndex}">复制</span>
<textarea style="position: absolute;top: -9999px;left: -9999px;z-index: -9999;" id="copy-target-${codeIndex}">${str.replace(/<\/textarea>/g, '&lt;/textarea>')}</textarea>`
if (lang && hl.getLanguage(lang)) {
const langHtml = `<span class="lang-name">${lang}</span>`
//
const preCode = hl.highlight(lang, str, true).value
// pre
return `<pre class="code-container"><code class="language-${lang} hljs">${preCode}</code>${copyBtn} ${langHtml}</pre>`
}
//
const preCode = md.utils.escapeHtml(str)
// pre
return `<pre class="code-container"><code class="language-${lang} hljs">${preCode}</code>${copyBtn}</pre>`
}
});
md.use(mathjaxPlugin)
const emits = defineEmits(['regen']); const emits = defineEmits(['regen']);
if (!props.data.icon) { if (!props.data.icon) {

View File

@ -6,8 +6,6 @@
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import {createRouter, createWebHistory} from "vue-router"; import {createRouter, createWebHistory} from "vue-router";
import {ref} from "vue";
import {httpGet} from "@/utils/http";
const routes = [ const routes = [
{ {

View File

@ -204,13 +204,11 @@ import {Delete, Edit, More, Plus, Promotion, Search, Share, VideoPause} from '@e
import 'highlight.js/styles/a11y-dark.css' import 'highlight.js/styles/a11y-dark.css'
import { import {
isMobile, isMobile,
processContent,
randString, randString,
removeArrayItem, removeArrayItem,
UUID UUID
} from "@/utils/libs"; } from "@/utils/libs";
import {ElMessage, ElMessageBox} from "element-plus"; import {ElMessage, ElMessageBox} from "element-plus";
import hl from "highlight.js";
import {getSessionId, getUserToken, removeUserToken} from "@/store/session"; import {getSessionId, getUserToken, removeUserToken} from "@/store/session";
import {httpGet, httpPost} from "@/utils/http"; import {httpGet, httpPost} from "@/utils/http";
import {useRouter} from "vue-router"; import {useRouter} from "vue-router";
@ -221,7 +219,6 @@ import {useSharedStore} from "@/store/sharedata";
import FileSelect from "@/components/FileSelect.vue"; import FileSelect from "@/components/FileSelect.vue";
import FileList from "@/components/FileList.vue"; import FileList from "@/components/FileList.vue";
import ChatSetting from "@/components/ChatSetting.vue"; import ChatSetting from "@/components/ChatSetting.vue";
import axios from "axios";
import BackTop from "@/components/BackTop.vue"; import BackTop from "@/components/BackTop.vue";
import {showMessageError} from "@/utils/dialog"; import {showMessageError} from "@/utils/dialog";
@ -547,33 +544,6 @@ const removeChat = function (chat) {
} }
const mathjaxPlugin = require('markdown-it-mathjax3')
const md = require('markdown-it')({
breaks: true,
html: true,
linkify: true,
typographer: true,
highlight: function (str, lang) {
const codeIndex = parseInt(Date.now()) + Math.floor(Math.random() * 10000000)
//
const copyBtn = `<span class="copy-code-btn" data-clipboard-action="copy" data-clipboard-target="#copy-target-${codeIndex}">复制</span>
<textarea style="position: absolute;top: -9999px;left: -9999px;z-index: -9999;" id="copy-target-${codeIndex}">${str.replace(/<\/textarea>/g, '&lt;/textarea>')}</textarea>`
if (lang && hl.getLanguage(lang)) {
const langHtml = `<span class="lang-name">${lang}</span>`
//
const preCode = hl.highlight(lang, str, true).value
// pre
return `<pre class="code-container"><code class="language-${lang} hljs">${preCode}</code>${copyBtn} ${langHtml}</pre>`
}
//
const preCode = md.utils.escapeHtml(str)
// pre
return `<pre class="code-container"><code class="language-${lang} hljs">${preCode}</code>${copyBtn}</pre>`
}
});
md.use(mathjaxPlugin)
// socket // socket
const prompt = ref(''); const prompt = ref('');
const showStopGenerate = ref(false); // const showStopGenerate = ref(false); //
@ -622,7 +592,6 @@ const connect = function (chat_id, role_id) {
id: randString(32), id: randString(32),
icon: _role['icon'], icon: _role['icon'],
content: _role['hello_msg'], content: _role['hello_msg'],
orgContent: _role['hello_msg'],
}) })
ElMessage.success({message: "对话连接成功!", duration: 1000}) ElMessage.success({message: "对话连接成功!", duration: 1000})
} else { // } else { //
@ -645,7 +614,6 @@ const connect = function (chat_id, role_id) {
icon: _role['icon'], icon: _role['icon'],
prompt:prePrompt, prompt:prePrompt,
content: "", content: "",
orgContent: "",
}); });
} else if (data.type === 'end') { // } else if (data.type === 'end') { //
// //
@ -680,8 +648,7 @@ const connect = function (chat_id, role_id) {
lineBuffer.value += data.content; lineBuffer.value += data.content;
const reply = chatData.value[chatData.value.length - 1] const reply = chatData.value[chatData.value.length - 1]
if (reply) { if (reply) {
reply['orgContent'] = lineBuffer.value; reply['content'] = lineBuffer.value;
reply['content'] = md.render(processContent(lineBuffer.value));
} }
} }
// //
@ -845,12 +812,8 @@ const loadChatHistory = function (chatId) {
} }
showHello.value = false showHello.value = false
for (let i = 0; i < data.length; i++) { for (let i = 0; i < data.length; i++) {
data[i].orgContent = data[i].content; if (data[i].type === 'reply' && i > 0) {
if (data[i].type === 'reply') { data[i].prompt = data[i - 1].content
data[i].content = md.render(processContent(data[i].content))
if (i > 0) {
data[i].prompt = data[i - 1].orgContent
}
} }
chatData.value.push(data[i]); chatData.value.push(data[i]);
} }

View File

@ -1,173 +0,0 @@
<template>
<el-form label-width="150px" label-position="right" class="draw-config">
<el-tabs type="border-card">
<el-tab-pane label="MJ-PLUS">
<div v-if="mjPlusConfigs">
<div class="config-item" v-for="(v,k) in mjPlusConfigs">
<el-form-item label="是否启用">
<el-switch v-model="v['Enabled']"/>
</el-form-item>
<el-form-item label="API 地址">
<el-input v-model="v['ApiURL']" placeholder="API 地址"/>
</el-form-item>
<el-form-item label="API 令牌">
<el-input v-model="v['ApiKey']" placeholder="API KEY"/>
</el-form-item>
<el-form-item label="绘画模式">
<el-select v-model="v['Mode']" placeholder="请选择模式">
<el-option v-for="item in mjModels" :value="item.value" :label="item.name" :key="item.value">{{
item.name
}}
</el-option>
</el-select>
</el-form-item>
<el-button class="remove" type="danger" :icon="Delete" circle @click="removeItem(mjPlusConfigs,k)"/>
</div>
</div>
<el-empty v-else></el-empty>
<el-row style="justify-content: center; padding: 10px">
<el-button round @click="addConfig(mjPlusConfigs)">
<el-icon><Plus /></el-icon>
<span>新增配置</span>
</el-button>
</el-row>
</el-tab-pane>
<el-tab-pane label="MJ-PROXY">
<div v-if="mjProxyConfigs">
<div class="config-item" v-for="(v,k) in mjProxyConfigs">
<el-form-item label="是否启用">
<el-switch v-model="v['Enabled']"/>
</el-form-item>
<el-form-item label="API 地址">
<el-input v-model="v['ApiURL']" placeholder="API 地址"/>
</el-form-item>
<el-form-item label="API 令牌">
<el-input v-model="v['ApiKey']" placeholder="API KEY"/>
</el-form-item>
<el-button class="remove" type="danger" :icon="Delete" circle @click="removeItem(mjProxyConfigs,k)"/>
</div>
</div>
<el-empty v-else />
<el-row style="justify-content: center; padding: 10px">
<el-button round @click="addConfig(mjProxyConfigs)">
<el-icon>
<Plus/>
</el-icon>
<span>新增配置</span>
</el-button>
</el-row>
</el-tab-pane>
<el-tab-pane label="Stable-Diffusion">
<div v-if="sdConfigs">
<div class="config-item" v-for="(v,k) in sdConfigs">
<el-form-item label="是否启用">
<el-switch v-model="v['Enabled']"/>
</el-form-item>
<el-form-item label="API 地址">
<el-input v-model="v['ApiURL']" placeholder="API 地址"/>
</el-form-item>
<el-form-item label="API 令牌">
<el-input v-model="v['ApiKey']" placeholder="API KEY"/>
</el-form-item>
<el-form-item label="模型">
<el-input v-model="v['Model']" placeholder="绘画模型"/>
</el-form-item>
<el-button class="remove" type="danger" :icon="Delete" circle @click="removeItem(sdConfigs,k)"/>
</div>
</div>
<el-empty v-else/>
<el-row style="justify-content: center; padding: 10px">
<el-button round @click="addConfig(sdConfigs)">
<el-icon>
<Plus/>
</el-icon>
<span>新增配置</span>
</el-button>
</el-row>
</el-tab-pane>
</el-tabs>
<div style="padding: 10px;">
<el-form-item>
<el-button type="primary" @click="saveConfig()">保存</el-button>
</el-form-item>
</div>
</el-form>
</template>
<script setup>
import {ref} from "vue";
import {httpGet, httpPost} from "@/utils/http";
import {ElMessage} from "element-plus";
import {Delete, Plus} from "@element-plus/icons-vue";
//
const sdConfigs = ref([])
const mjPlusConfigs = ref([])
const mjProxyConfigs = ref([])
const mjModels = ref([
{name: "慢速Relax", value: "relax"},
{name: "快速Fast", value: "fast"},
{name: "急速Turbo", value: "turbo"},
])
httpGet("/api/admin/config/get/app").then(res => {
sdConfigs.value = res.data.sd
mjPlusConfigs.value = res.data.mj_plus
mjProxyConfigs.value = res.data.mj_proxy
}).catch(e =>{
ElMessage.error("获取配置失败:"+e.message)
})
const addConfig = (configs) => {
configs.push({
Enabled: true,
ApiKey: '',
ApiURL: '',
Mode: 'fast'
})
}
const saveConfig = () => {
httpPost('/api/admin/config/update/draw', {
'sd': sdConfigs.value,
'mj_plus': mjPlusConfigs.value,
'mj_proxy': mjProxyConfigs.value
}).then(() => {
ElMessage.success("配置更新成功")
}).catch(e => {
ElMessage.error("操作失败:" + e.message)
})
}
const removeItem = (arr, k) => {
arr.splice(k, 1)
}
</script>
<style lang="stylus" scoped>
.draw-config {
.config-item {
position relative
padding 15px 10px 10px 10px
border 1px solid var(--el-border-color)
border-radius 10px
margin-bottom 10px
.remove {
position absolute
right 15px
top 15px
}
}
}
</style>

View File

@ -139,6 +139,7 @@ const title = ref("")
const types = ref([ const types = ref([
{label: "对话", value:"chat"}, {label: "对话", value:"chat"},
{label: "Midjourney", value:"mj"}, {label: "Midjourney", value:"mj"},
{label: "Stable-Diffusion", value:"sd"},
{label: "DALL-E", value:"dalle"}, {label: "DALL-E", value:"dalle"},
{label: "Suno文生歌", value:"suno"}, {label: "Suno文生歌", value:"suno"},
{label: "Luma视频", value:"luma"}, {label: "Luma视频", value:"luma"},

View File

@ -157,11 +157,7 @@
<div v-for="item in messages" :key="item.id"> <div v-for="item in messages" :key="item.id">
<chat-prompt <chat-prompt
v-if="item.type==='prompt'" v-if="item.type==='prompt'"
:icon="item.icon" :data="item"/>
:created-at="dateFormat(item['created_at'])"
:tokens="item['tokens']"
:model="item.model"
:content="item.content"/>
<chat-reply v-else-if="item.type==='reply'" <chat-reply v-else-if="item.type==='reply'"
:read-only="true" :read-only="true"
:data="item"/> :data="item"/>
@ -288,33 +284,11 @@ const removeMessage = function (row) {
}) })
} }
const mathjaxPlugin = require('markdown-it-mathjax3')
const md = require('markdown-it')({
breaks: true,
html: true,
linkify: true,
typographer: true,
highlight: function (str, lang) {
if (lang && hl.getLanguage(lang)) {
//
const preCode = hl.highlight(lang, str, true).value
// pre
return `<pre class="code-container"><code class="language-${lang} hljs">${preCode}</code></pre>`
}
//
const preCode = md.utils.escapeHtml(str)
// pre
return `<pre class="code-container"><code class="language-${lang} hljs">${preCode}</code></pre>`
}
});
md.use(mathjaxPlugin)
const showContentDialog = ref(false) const showContentDialog = ref(false)
const dialogContent = ref("") const dialogContent = ref("")
const showContent = (content) => { const showContent = (content) => {
showContentDialog.value = true showContentDialog.value = true
dialogContent.value = md.render(processContent(content)) dialogContent.value = processContent(content)
} }
const showChatItemDialog = ref(false) const showChatItemDialog = ref(false)
@ -325,8 +299,6 @@ const showMessages = (row) => {
httpGet('/api/admin/chat/history?chat_id=' + row.chat_id).then(res => { httpGet('/api/admin/chat/history?chat_id=' + row.chat_id).then(res => {
const data = res.data const data = res.data
for (let i = 0; i < data.length; i++) { for (let i = 0; i < data.length; i++) {
data[i].orgContent = data[i].content;
data[i].content = md.render(processContent(data[i].content))
messages.value.push(data[i]); messages.value.push(data[i]);
} }
}).catch(e => { }).catch(e => {

View File

@ -194,6 +194,15 @@
</div> </div>
</div> </div>
</el-form-item> </el-form-item>
<el-form-item label="MJ默认API模式" prop="mj_mode">
<el-select v-model="system['mj_mode']" placeholder="请选择模式">
<el-option v-for="item in mjModels" :value="item.value" :label="item.name" :key="item.value">{{
item.name
}}
</el-option>
</el-select>
</el-form-item>
</el-tab-pane> </el-tab-pane>
<el-tab-pane label="算力配置"> <el-tab-pane label="算力配置">
@ -359,10 +368,6 @@
<Menu/> <Menu/>
</el-tab-pane> </el-tab-pane>
<el-tab-pane label="AI绘图配置" name="AIDrawing">
<AIDrawing/>
</el-tab-pane>
<el-tab-pane label="授权激活" name="license"> <el-tab-pane label="授权激活" name="license">
<div class="container"> <div class="container">
<el-descriptions <el-descriptions
@ -431,7 +436,6 @@ import MdEditor from "md-editor-v3";
import 'md-editor-v3/lib/style.css'; import 'md-editor-v3/lib/style.css';
import Menu from "@/views/admin/Menu.vue"; import Menu from "@/views/admin/Menu.vue";
import {copyObj, dateFormat} from "@/utils/libs"; import {copyObj, dateFormat} from "@/utils/libs";
import AIDrawing from "@/views/admin/AIDrawing.vue";
const activeName = ref('basic') const activeName = ref('basic')
const system = ref({models: []}) const system = ref({models: []})
@ -439,10 +443,14 @@ const configBak = ref({})
const loading = ref(true) const loading = ref(true)
const systemFormRef = ref(null) const systemFormRef = ref(null)
const models = ref([]) const models = ref([])
const openAIModels = ref([])
const notice = ref("") const notice = ref("")
const license = ref({is_active: false}) const license = ref({is_active: false})
const menus = ref([]) const menus = ref([])
const mjModels = ref([
{name: "慢速Relax", value: "relax"},
{name: "快速Fast", value: "fast"},
{name: "急速Turbo", value: "turbo"},
])
onMounted(() => { onMounted(() => {
// //
@ -461,7 +469,6 @@ onMounted(() => {
httpGet('/api/admin/model/list').then(res => { httpGet('/api/admin/model/list').then(res => {
models.value = res.data models.value = res.data
openAIModels.value = models.value.filter(v => v.platform === "OpenAI")
loading.value = false loading.value = false
}).catch(e => { }).catch(e => {
ElMessage.error("获取模型失败:" + e.message) ElMessage.error("获取模型失败:" + e.message)