feat: refactor MidJourney service for conpatible drawing in chat and draw in app

This commit is contained in:
RockYang
2023-09-12 18:01:24 +08:00
parent 036a6e3e41
commit fa341bab30
12 changed files with 467 additions and 262 deletions

View File

@@ -0,0 +1,64 @@
package function
import (
"chatplus/service"
"chatplus/utils"
"fmt"
)
// AI 绘画函数
type FuncMidJourney struct {
name string
service *service.MjService
}
func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney {
return FuncMidJourney{
name: "MidJourney AI 绘画",
service: mjService}
}
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
logger.Infof("MJ 绘画参数:%+v", params)
prompt := utils.InterfaceToString(params["prompt"])
if !utils.IsEmptyValue(params["--ar"]) {
prompt = fmt.Sprintf("%s --ar %s", prompt, params["--ar"])
delete(params, "--ar")
}
if !utils.IsEmptyValue(params["--s"]) {
prompt = fmt.Sprintf("%s --s %s", prompt, params["--s"])
delete(params, "--s")
}
if !utils.IsEmptyValue(params["--seed"]) {
prompt = fmt.Sprintf("%s --seed %s", prompt, params["--seed"])
delete(params, "--seed")
}
if !utils.IsEmptyValue(params["--no"]) {
prompt = fmt.Sprintf("%s --no %s", prompt, params["--no"])
delete(params, "--no")
}
if !utils.IsEmptyValue(params["--niji"]) {
prompt = fmt.Sprintf("%s --niji %s", prompt, params["--niji"])
delete(params, "--niji")
} else {
prompt = prompt + " --v 5.2"
}
f.service.PushTask(service.MjTask{
Id: utils.InterfaceToString(params["session_id"]),
Src: service.TaskSrcChat,
Prompt: prompt,
UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0),
RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0),
Icon: utils.InterfaceToString(params["icon"]),
ChatId: utils.InterfaceToString(params["chat_id"]),
})
return prompt, nil
}
func (f FuncMidJourney) Name() string {
return f.name
}
var _ Function = &FuncMidJourney{}

View File

@@ -1,129 +0,0 @@
package function
import (
"chatplus/core/types"
"chatplus/utils"
"errors"
"fmt"
"github.com/imroc/req/v3"
"time"
)
// AI 绘画函数
type FuncMidJourney struct {
name string
config types.ChatPlusExtConfig
client *req.Client
}
func NewMidJourneyFunc(config types.ChatPlusExtConfig) FuncMidJourney {
return FuncMidJourney{
name: "MidJourney AI 绘画",
config: config,
client: req.C().SetTimeout(30 * time.Second)}
}
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
if f.config.Token == "" {
return "", errors.New("无效的 API Token")
}
logger.Infof("MJ 绘画参数:%+v", params)
prompt := utils.InterfaceToString(params["prompt"])
if !utils.IsEmptyValue(params["--ar"]) {
prompt = fmt.Sprintf("%s --ar %s", prompt, params["--ar"])
delete(params, "--ar")
}
if !utils.IsEmptyValue(params["--s"]) {
prompt = fmt.Sprintf("%s --s %s", prompt, params["--s"])
delete(params, "--s")
}
if !utils.IsEmptyValue(params["--seed"]) {
prompt = fmt.Sprintf("%s --seed %s", prompt, params["--seed"])
delete(params, "--seed")
}
if !utils.IsEmptyValue(params["--no"]) {
prompt = fmt.Sprintf("%s --no %s", prompt, params["--no"])
delete(params, "--no")
}
if !utils.IsEmptyValue(params["--niji"]) {
prompt = fmt.Sprintf("%s --niji %s", prompt, params["--niji"])
delete(params, "--niji")
} else {
prompt = prompt + " --v 5.2"
}
params["prompt"] = prompt
url := fmt.Sprintf("%s/api/mj/image", f.config.ApiURL)
var res types.BizVo
r, err := f.client.R().
SetHeader("Authorization", f.config.Token).
SetHeader("Content-Type", "application/json").
SetBody(params).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("%v%v", r.String(), err)
}
if res.Code != types.Success {
return "", errors.New(res.Message)
}
return prompt, nil
}
type MjUpscaleReq struct {
Index int32 `json:"index"`
MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"`
}
func (f FuncMidJourney) Upscale(upReq MjUpscaleReq) error {
url := fmt.Sprintf("%s/api/mj/upscale", f.config.ApiURL)
var res types.BizVo
r, err := f.client.R().
SetHeader("Authorization", f.config.Token).
SetHeader("Content-Type", "application/json").
SetBody(upReq).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("%v%v", r.String(), err)
}
if res.Code != types.Success {
return errors.New(res.Message)
}
return nil
}
type MjVariationReq struct {
Index int32 `json:"index"`
MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"`
}
func (f FuncMidJourney) Variation(upReq MjVariationReq) error {
url := fmt.Sprintf("%s/api/mj/variation", f.config.ApiURL)
var res types.BizVo
r, err := f.client.R().
SetHeader("Authorization", f.config.Token).
SetHeader("Content-Type", "application/json").
SetBody(upReq).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("%v%v", r.String(), err)
}
if res.Code != types.Success {
return errors.New(res.Message)
}
return nil
}
func (f FuncMidJourney) Name() string {
return f.name
}
var _ Function = &FuncMidJourney{}

189
api/service/mj_service.go Normal file
View File

@@ -0,0 +1,189 @@
package service
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/store"
"chatplus/utils"
"context"
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3"
"time"
)
var logger = logger2.GetLogger()
// MJ 绘画服务
const MjRunningJobKey = "MidJourney_Running_Job"
type TaskType string
const (
Image = TaskType("image")
Upscale = TaskType("upscale")
Variation = TaskType("variation")
)
type TaskSrc string
const (
TaskSrcChat = TaskSrc("chat")
TaskSrcImg = TaskSrc("img")
)
type MjTask struct {
Id string `json:"id"`
Src TaskSrc `json:"src"`
Type TaskType `json:"type"`
UserId int `json:"user_id"`
Prompt string `json:"prompt,omitempty"`
ChatId string `json:"chat_id,omitempty"`
RoleId int `json:"role_id,omitempty"`
Icon string `json:"icon,omitempty"`
Index int32 `json:"index,omitempty"`
MessageId string `json:"message_id,omitempty"`
MessageHash string `json:"message_hash,omitempty"`
RetryCount int `json:"retry_count"`
}
type MjService struct {
config types.ChatPlusExtConfig
client *req.Client
taskQueue *store.RedisQueue
redis *redis.Client
}
func NewMjService(config types.ChatPlusExtConfig, client *redis.Client) *MjService {
return &MjService{
config: config,
redis: client,
taskQueue: store.NewRedisQueue("midjourney_task_queue", client),
client: req.C().SetTimeout(30 * time.Second)}
}
func (s *MjService) Run() {
ctx := context.Background()
for {
_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
if err == nil { // a task is running, waiting for finish
time.Sleep(time.Second * 3)
continue
}
var task MjTask
err = s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
switch task.Type {
case Image:
err = s.image(task.Prompt)
break
case Upscale:
err = s.upscale(MjUpscaleReq{
Index: task.Index,
MessageId: task.MessageId,
MessageHash: task.MessageHash,
})
break
case Variation:
err = s.variation(MjVariationReq{
Index: task.Index,
MessageId: task.MessageId,
MessageHash: task.MessageHash,
})
}
if err != nil {
if task.RetryCount > 5 {
continue
}
task.RetryCount += 1
time.Sleep(time.Second)
s.taskQueue.RPush(task)
// TODO: 执行失败通知聊天客户端
continue
}
// 锁定任务执行通道直到任务超时10分钟
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Second*600)
}
}
func (s *MjService) PushTask(task MjTask) {
s.taskQueue.RPush(task)
}
func (s *MjService) image(prompt string) error {
logger.Infof("MJ 绘画参数:%+v", prompt)
body := map[string]string{"prompt": prompt}
url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("Authorization", s.config.Token).
SetHeader("Content-Type", "application/json").
SetBody(body).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("%v%v", r.String(), err)
}
if res.Code != types.Success {
return errors.New(res.Message)
}
return nil
}
type MjUpscaleReq struct {
Index int32 `json:"index"`
MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"`
}
func (s *MjService) upscale(upReq MjUpscaleReq) error {
url := fmt.Sprintf("%s/api/mj/upscale", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("Authorization", s.config.Token).
SetHeader("Content-Type", "application/json").
SetBody(upReq).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("%v%v", r.String(), err)
}
if res.Code != types.Success {
return errors.New(res.Message)
}
return nil
}
type MjVariationReq struct {
Index int32 `json:"index"`
MessageId string `json:"message_id"`
MessageHash string `json:"message_hash"`
}
func (s *MjService) variation(upReq MjVariationReq) error {
url := fmt.Sprintf("%s/api/mj/variation", s.config.ApiURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("Authorization", s.config.Token).
SetHeader("Content-Type", "application/json").
SetBody(upReq).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return fmt.Errorf("%v%v", r.String(), err)
}
if res.Code != types.Success {
return errors.New(res.Message)
}
return nil
}