mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 10:13:44 +08:00
refactor: add midjourney pool implementation, add translate prompt for mj drawing
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/service/mj"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
@@ -32,16 +31,14 @@ var logger = logger2.GetLogger()
|
||||
|
||||
type ChatHandler struct {
|
||||
handler.BaseHandler
|
||||
db *gorm.DB
|
||||
redis *redis.Client
|
||||
mjService *mj.Service
|
||||
db *gorm.DB
|
||||
redis *redis.Client
|
||||
}
|
||||
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, service *mj.Service) *ChatHandler {
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client) *ChatHandler {
|
||||
h := ChatHandler{
|
||||
db: db,
|
||||
redis: redis,
|
||||
mjService: service,
|
||||
db: db,
|
||||
redis: redis,
|
||||
}
|
||||
h.App = app
|
||||
return &h
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service"
|
||||
"chatplus/service/mj"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
@@ -11,7 +12,6 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -19,26 +19,22 @@ import (
|
||||
|
||||
type MidJourneyHandler struct {
|
||||
BaseHandler
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
mjService *mj.Service
|
||||
pool *mj.ServicePool
|
||||
snowflake *service.Snowflake
|
||||
}
|
||||
|
||||
func NewMidJourneyHandler(
|
||||
app *core.AppServer,
|
||||
client *redis.Client,
|
||||
db *gorm.DB,
|
||||
mjService *mj.Service) *MidJourneyHandler {
|
||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool) *MidJourneyHandler {
|
||||
h := MidJourneyHandler{
|
||||
redis: client,
|
||||
db: db,
|
||||
mjService: mjService,
|
||||
snowflake: snowflake,
|
||||
pool: pool,
|
||||
}
|
||||
h.App = app
|
||||
return &h
|
||||
}
|
||||
|
||||
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
||||
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
|
||||
user, err := utils.GetLoginUser(c, h.db)
|
||||
if err != nil {
|
||||
resp.NotAuth(c)
|
||||
@@ -50,17 +46,17 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if !h.pool.HasAvailableService() {
|
||||
resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
// Image 创建一个绘画任务
|
||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
if !h.App.Config.MjConfigs[0].Enabled {
|
||||
resp.ERROR(c, "MidJourney service is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
SessionId string `json:"session_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
@@ -80,7 +76,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
if !h.checkLimits(c) {
|
||||
if !h.preCheck(c) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -121,9 +117,16 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
// generate task id
|
||||
taskId, err := h.snowflake.Next(true)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate task id: "+err.Error())
|
||||
return
|
||||
}
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskImage.String(),
|
||||
UserId: userId,
|
||||
TaskId: taskId,
|
||||
Progress: 0,
|
||||
Prompt: prompt,
|
||||
CreatedAt: time.Now(),
|
||||
@@ -133,24 +136,13 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
h.pool.PushTask(types.MjTask{
|
||||
Id: int(job.Id),
|
||||
SessionId: data.SessionId,
|
||||
Src: types.TaskSrcImg,
|
||||
Type: types.TaskImage,
|
||||
Prompt: prompt,
|
||||
Prompt: fmt.Sprintf("%s %s", taskId, prompt),
|
||||
UserId: userId,
|
||||
})
|
||||
|
||||
var jobVo vo.MidJourneyJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err == nil {
|
||||
// 推送任务到前端
|
||||
client := h.mjService.Clients.Get(data.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -174,65 +166,23 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !h.checkLimits(c) {
|
||||
if !h.preCheck(c) {
|
||||
return
|
||||
}
|
||||
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
jobId := 0
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
src := types.TaskSrc(data.Src)
|
||||
if src == types.TaskSrcImg {
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskUpscale.String(),
|
||||
UserId: userId,
|
||||
Hash: data.MessageHash,
|
||||
Progress: 0,
|
||||
Prompt: data.Prompt,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if res := h.db.Create(&job); res.Error == nil {
|
||||
jobId = int(job.Id)
|
||||
} else {
|
||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var jobVo vo.MidJourneyJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err == nil {
|
||||
// 推送任务到前端
|
||||
client := h.mjService.Clients.Get(data.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
h.pool.PushTask(types.MjTask{
|
||||
Id: jobId,
|
||||
SessionId: data.SessionId,
|
||||
Src: src,
|
||||
Type: types.TaskUpscale,
|
||||
Prompt: data.Prompt,
|
||||
UserId: userId,
|
||||
RoleId: data.RoleId,
|
||||
Icon: data.Icon,
|
||||
ChatId: data.ChatId,
|
||||
Index: data.Index,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
})
|
||||
|
||||
if src == types.TaskSrcChat {
|
||||
wsClient := h.App.ChatClients.Get(data.SessionId)
|
||||
if wsClient != nil {
|
||||
content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
|
||||
utils.ReplyMessage(wsClient, content)
|
||||
if h.mjService.ChatClients.Get(data.SessionId) == nil {
|
||||
h.mjService.ChatClients.Put(data.SessionId, wsClient)
|
||||
}
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -244,67 +194,23 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !h.checkLimits(c) {
|
||||
if !h.preCheck(c) {
|
||||
return
|
||||
}
|
||||
|
||||
idValue, _ := c.Get(types.LoginUserID)
|
||||
jobId := 0
|
||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||
src := types.TaskSrc(data.Src)
|
||||
if src == types.TaskSrcImg {
|
||||
job := model.MidJourneyJob{
|
||||
Type: types.TaskVariation.String(),
|
||||
UserId: userId,
|
||||
ImgURL: "",
|
||||
Hash: data.MessageHash,
|
||||
Progress: 0,
|
||||
Prompt: data.Prompt,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if res := h.db.Create(&job); res.Error == nil {
|
||||
jobId = int(job.Id)
|
||||
} else {
|
||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var jobVo vo.MidJourneyJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err == nil {
|
||||
// 推送任务到前端
|
||||
client := h.mjService.Clients.Get(data.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mjService.PushTask(types.MjTask{
|
||||
h.pool.PushTask(types.MjTask{
|
||||
Id: jobId,
|
||||
SessionId: data.SessionId,
|
||||
Src: src,
|
||||
Type: types.TaskVariation,
|
||||
Prompt: data.Prompt,
|
||||
UserId: userId,
|
||||
RoleId: data.RoleId,
|
||||
Icon: data.Icon,
|
||||
ChatId: data.ChatId,
|
||||
Index: data.Index,
|
||||
MessageId: data.MessageId,
|
||||
MessageHash: data.MessageHash,
|
||||
})
|
||||
|
||||
if src == types.TaskSrcChat {
|
||||
// 从聊天窗口发送的请求,记录客户端信息
|
||||
wsClient := h.mjService.ChatClients.Get(data.SessionId)
|
||||
if wsClient != nil {
|
||||
content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
|
||||
utils.ReplyMessage(wsClient, content)
|
||||
if h.mjService.Clients.Get(data.SessionId) == nil {
|
||||
h.mjService.Clients.Put(data.SessionId, wsClient)
|
||||
}
|
||||
}
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -343,19 +249,27 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if job.Progress == -1 {
|
||||
h.db.Delete(&model.MidJourneyJob{Id: job.Id})
|
||||
}
|
||||
|
||||
if item.Progress < 100 {
|
||||
// 10 分钟还没完成的任务直接删除
|
||||
if time.Now().Sub(item.CreatedAt) > time.Minute*10 {
|
||||
h.db.Delete(&item)
|
||||
continue
|
||||
}
|
||||
if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
|
||||
image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
|
||||
|
||||
// 正在运行中任务使用代理访问图片
|
||||
if item.ImgURL == "" && item.OrgURL != "" {
|
||||
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
|
||||
if err == nil {
|
||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
resp.SUCCESS(c, jobs)
|
||||
|
||||
@@ -176,7 +176,7 @@ func (h *PaymentHandler) PayQrcode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
orderNo, err := h.snowflake.Next()
|
||||
orderNo, err := h.snowflake.Next(false)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "error with generate trade no: "+err.Error())
|
||||
return
|
||||
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const translatePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
|
||||
const rewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
|
||||
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||
|
||||
type PromptHandler struct {
|
||||
BaseHandler
|
||||
@@ -47,6 +48,25 @@ type apiErrRes struct {
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// Rewrite translate and rewrite prompt with ChatGPT
|
||||
func (h *PromptHandler) Rewrite(c *gin.Context) {
|
||||
var data struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
content, err := h.request(data.Prompt, rewritePromptTemplate)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
||||
func (h *PromptHandler) Translate(c *gin.Context) {
|
||||
var data struct {
|
||||
Prompt string `json:"prompt"`
|
||||
@@ -55,18 +75,28 @@ func (h *PromptHandler) Translate(c *gin.Context) {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
|
||||
content, err := h.request(data.Prompt, translatePromptTemplate)
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
||||
func (h *PromptHandler) request(prompt string, promptTemplate string) (string, error) {
|
||||
// 获取 OpenAI 的 API KEY
|
||||
var apiKey model.ApiKey
|
||||
res := h.db.Where("platform = ?", types.OpenAI).First(&apiKey)
|
||||
if res.Error != nil {
|
||||
resp.ERROR(c, "找不到可用 OpenAI API KEY")
|
||||
return
|
||||
return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error)
|
||||
}
|
||||
|
||||
messages := make([]interface{}, 1)
|
||||
messages[0] = types.Message{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf(translatePromptTemplate, data.Prompt),
|
||||
Content: fmt.Sprintf(promptTemplate, prompt),
|
||||
}
|
||||
|
||||
var response apiRes
|
||||
@@ -83,9 +113,8 @@ func (h *PromptHandler) Translate(c *gin.Context) {
|
||||
SetErrorResult(&errRes).
|
||||
SetSuccessResult(&response).Post(h.App.ChatConfig.OpenAI.ApiURL)
|
||||
if err != nil || r.IsErrorState() {
|
||||
resp.ERROR(c, fmt.Sprintf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message))
|
||||
return
|
||||
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, response.Choices[0].Message.Content)
|
||||
return response.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
@@ -141,7 +141,6 @@ func (h *SdJobHandler) Image(c *gin.Context) {
|
||||
h.service.PushTask(types.SdTask{
|
||||
Id: int(job.Id),
|
||||
SessionId: data.SessionId,
|
||||
Src: types.TaskSrcImg,
|
||||
Type: types.TaskImage,
|
||||
Prompt: data.Prompt,
|
||||
Params: params,
|
||||
|
||||
Reference in New Issue
Block a user