mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	feat: midjourney image variation function is ready
This commit is contained in:
		@@ -83,15 +83,15 @@ var InnerFunctions = []Function{
 | 
			
		||||
			Properties: map[string]Property{
 | 
			
		||||
				"prompt": {
 | 
			
		||||
					Type:        "string",
 | 
			
		||||
					Description: "绘画内容描述,提示词,此参数需要翻译成英文",
 | 
			
		||||
					Description: "绘画内容描述,提示词,如果该参数中有中文的话,则需要翻译成英文",
 | 
			
		||||
				},
 | 
			
		||||
				"ar": {
 | 
			
		||||
					Type:        "string",
 | 
			
		||||
					Description: "图片长宽比,如 16:9, --ar 3:2",
 | 
			
		||||
					Description: "图片长宽比,如 --ar 4:3",
 | 
			
		||||
				},
 | 
			
		||||
				"niji": {
 | 
			
		||||
					Type:        "string",
 | 
			
		||||
					Description: "动漫模型版本,如 --niji 5",
 | 
			
		||||
					Description: "动漫模型版本,例如 --niji 5",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			Required: []string{},
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ import (
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -83,7 +84,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
 | 
			
		||||
		err := h.leveldb.Get(types.TaskStorePrefix+data.Key, &task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with get MidJourney task: ", err)
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			resp.SUCCESS(c)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@@ -138,16 +139,18 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c, "SUCCESS")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type reqVo struct {
 | 
			
		||||
	Index       int32  `json:"index"`
 | 
			
		||||
	MessageId   string `json:"message_id"`
 | 
			
		||||
	MessageHash string `json:"message_hash"`
 | 
			
		||||
	SessionId   string `json:"session_id"`
 | 
			
		||||
	Key         string `json:"key"`
 | 
			
		||||
	Prompt      string `json:"prompt"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale send upscale command to MidJourney Bot
 | 
			
		||||
func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Index       int32  `json:"index"`
 | 
			
		||||
		MessageId   string `json:"message_id"`
 | 
			
		||||
		MessageHash string `json:"message_hash"`
 | 
			
		||||
		SessionId   string `json:"session_id"`
 | 
			
		||||
		Key         string `json:"key"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var data reqVo
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil ||
 | 
			
		||||
		data.SessionId == "" ||
 | 
			
		||||
		data.Key == "" {
 | 
			
		||||
@@ -170,7 +173,39 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	utils.ReplyMessage(wsClient, "已推送放大图片任务到 MidJourney 机器人,请耐心等待任务执行...")
 | 
			
		||||
	content := fmt.Sprintf("**%s** 已推送 Upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
			
		||||
	utils.ReplyMessage(wsClient, content)
 | 
			
		||||
	if h.App.MjTaskClients.Get(data.Key) == nil {
 | 
			
		||||
		h.App.MjTaskClients.Put(data.Key, wsClient)
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
			
		||||
	var data reqVo
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil ||
 | 
			
		||||
		data.SessionId == "" ||
 | 
			
		||||
		data.Key == "" {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	wsClient := h.App.ChatClients.Get(data.SessionId)
 | 
			
		||||
	if wsClient == nil {
 | 
			
		||||
		resp.ERROR(c, "No Websocket client online")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := h.mjFunc.Variation(function.MjVariationReq{
 | 
			
		||||
		Index:       data.Index,
 | 
			
		||||
		MessageId:   data.MessageId,
 | 
			
		||||
		MessageHash: data.MessageHash,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	content := fmt.Sprintf("**%s** 已推送 Variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
			
		||||
	utils.ReplyMessage(wsClient, content)
 | 
			
		||||
	if h.App.MjTaskClients.Get(data.Key) == nil {
 | 
			
		||||
		h.App.MjTaskClients.Put(data.Key, wsClient)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -179,6 +179,7 @@ func main() {
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
 | 
			
		||||
			s.Engine.POST("/api/mj/notify", h.Notify)
 | 
			
		||||
			s.Engine.POST("/api/mj/upscale", h.Upscale)
 | 
			
		||||
			s.Engine.POST("/api/mj/variation", h.Variation)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 管理后台控制器
 | 
			
		||||
 
 | 
			
		||||
@@ -85,6 +85,31 @@ func (f FuncMidJourney) Upscale(upReq MjUpscaleReq) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MjVariationReq struct {
 | 
			
		||||
	Index       int32  `json:"index"`
 | 
			
		||||
	MessageId   string `json:"message_id"`
 | 
			
		||||
	MessageHash string `json:"message_hash"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncMidJourney) Variation(upReq MjVariationReq) error {
 | 
			
		||||
	url := fmt.Sprintf("%s/api/mj/variation", f.config.ApiURL)
 | 
			
		||||
	var res types.BizVo
 | 
			
		||||
	r, err := f.client.R().
 | 
			
		||||
		SetHeader("Authorization", f.config.Token).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(upReq).
 | 
			
		||||
		SetSuccessResult(&res).Post(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("%v%v", r.String(), err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return errors.New(res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncMidJourney) Name() string {
 | 
			
		||||
	return f.name
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -92,14 +92,23 @@ watch(() => props.content, (newVal) => {
 | 
			
		||||
});
 | 
			
		||||
const emits = defineEmits(['disable-input', 'disable-input']);
 | 
			
		||||
const upscale = (index) => {
 | 
			
		||||
  send('/api/mj/upscale', index)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const variation = (index) => {
 | 
			
		||||
  send('/api/mj/variation', index)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const send = (url, index) => {
 | 
			
		||||
  loading.value = true
 | 
			
		||||
  emits('disable-input')
 | 
			
		||||
  httpPost("/api/mj/upscale", {
 | 
			
		||||
  httpPost(url, {
 | 
			
		||||
    index: index,
 | 
			
		||||
    message_id: data.value?.["message_id"],
 | 
			
		||||
    message_hash: data.value?.["image"]?.hash,
 | 
			
		||||
    session_id: getSessionId(),
 | 
			
		||||
    key: data.value?.["key"]
 | 
			
		||||
    key: data.value?.["key"],
 | 
			
		||||
    prompt: data.value?.["prompt"],
 | 
			
		||||
  }).then(() => {
 | 
			
		||||
    ElMessage.success("任务推送成功,请耐心等待任务执行...")
 | 
			
		||||
    loading.value = false
 | 
			
		||||
@@ -108,10 +117,6 @@ const upscale = (index) => {
 | 
			
		||||
    emits('disable-input')
 | 
			
		||||
  })
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const variation = (index) => {
 | 
			
		||||
  ElMessage.warning("当前版本暂未实现 Variation 功能!")
 | 
			
		||||
}
 | 
			
		||||
</script>
 | 
			
		||||
 | 
			
		||||
<style lang="stylus">
 | 
			
		||||
 
 | 
			
		||||
@@ -554,6 +554,12 @@ const connect = function (chat_id, role_id) {
 | 
			
		||||
          const content = data.content;
 | 
			
		||||
          const md = require('markdown-it')({breaks: true});
 | 
			
		||||
          content.content = md.render(content.content)
 | 
			
		||||
          let key = content.key
 | 
			
		||||
          // fixed bug: 执行 Upscale 和 Variation 操作的时候覆盖之前的绘画
 | 
			
		||||
          if (content.status === "Finished") {
 | 
			
		||||
            key = randString(32)
 | 
			
		||||
            enableInput()
 | 
			
		||||
          }
 | 
			
		||||
          // console.log(content)
 | 
			
		||||
          // check if the message is in chatData
 | 
			
		||||
          let flag = false
 | 
			
		||||
@@ -562,21 +568,19 @@ const connect = function (chat_id, role_id) {
 | 
			
		||||
              console.log(chatData.value[i])
 | 
			
		||||
              flag = true
 | 
			
		||||
              chatData.value[i].content = content
 | 
			
		||||
              chatData.value[i].id = key
 | 
			
		||||
              break
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
          if (flag === false) {
 | 
			
		||||
            chatData.value.push({
 | 
			
		||||
              type: "mj",
 | 
			
		||||
              id: content["message_id"],
 | 
			
		||||
              id: key,
 | 
			
		||||
              icon: "/images/avatar/mid_journey.png",
 | 
			
		||||
              content: content
 | 
			
		||||
            });
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          if (content.status === "Finished") {
 | 
			
		||||
            enableInput()
 | 
			
		||||
          }
 | 
			
		||||
        } else if (data.type === 'end') { // 消息接收完毕
 | 
			
		||||
          // 追加当前会话到会话列表
 | 
			
		||||
          if (isNewChat && newChatItem.value !== null) {
 | 
			
		||||
 
 | 
			
		||||
@@ -44,7 +44,6 @@ import {onMounted, ref} from "vue";
 | 
			
		||||
import {Lock, UserFilled} from "@element-plus/icons-vue";
 | 
			
		||||
import {httpPost} from "@/utils/http";
 | 
			
		||||
import {ElMessage} from "element-plus";
 | 
			
		||||
import {setSession} from "@/store/session";
 | 
			
		||||
import {useRouter} from "vue-router";
 | 
			
		||||
import FooterBar from "@/components/FooterBar.vue";
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user