mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 10:43:44 +08:00
feat: migrate the chatgpt-plus-ext project code to this project
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
package function
|
||||
package fun
|
||||
|
||||
import (
|
||||
"chatplus/service"
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/mj"
|
||||
"chatplus/utils"
|
||||
)
|
||||
|
||||
@@ -9,10 +10,10 @@ import (
|
||||
|
||||
type FuncMidJourney struct {
|
||||
name string
|
||||
service *service.MjService
|
||||
service *mj.Service
|
||||
}
|
||||
|
||||
func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney {
|
||||
func NewMidJourneyFunc(mjService *mj.Service) FuncMidJourney {
|
||||
return FuncMidJourney{
|
||||
name: "MidJourney AI 绘画",
|
||||
service: mjService}
|
||||
@@ -21,10 +22,10 @@ func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney {
|
||||
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
|
||||
logger.Infof("MJ 绘画参数:%+v", params)
|
||||
prompt := utils.InterfaceToString(params["prompt"])
|
||||
f.service.PushTask(service.MjTask{
|
||||
f.service.PushTask(types.MjTask{
|
||||
SessionId: utils.InterfaceToString(params["session_id"]),
|
||||
Src: service.TaskSrcChat,
|
||||
Type: service.Image,
|
||||
Src: types.TaskSrcChat,
|
||||
Type: types.TaskImage,
|
||||
Prompt: prompt,
|
||||
UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0),
|
||||
RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0),
|
||||
@@ -1,9 +1,9 @@
|
||||
package function
|
||||
package fun
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/service"
|
||||
"chatplus/service/mj"
|
||||
)
|
||||
|
||||
type Function interface {
|
||||
@@ -29,7 +29,7 @@ type dataItem struct {
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
|
||||
func NewFunctions(config *types.AppConfig, mjService *service.MjService) map[string]Function {
|
||||
func NewFunctions(config *types.AppConfig, mjService *mj.Service) map[string]Function {
|
||||
return map[string]Function{
|
||||
types.FuncZaoBao: NewZaoBao(config.ApiConfig),
|
||||
types.FuncWeibo: NewWeiboHot(config.ApiConfig),
|
||||
@@ -1,4 +1,4 @@
|
||||
package function
|
||||
package fun
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
@@ -1,4 +1,4 @@
|
||||
package function
|
||||
package fun
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
@@ -1,4 +1,4 @@
|
||||
package function
|
||||
package fun
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
213
api/service/mj/bot.go
Normal file
213
api/service/mj/bot.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package mj
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/utils"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MidJourney 机器人
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type Bot struct {
|
||||
config *types.MidJourneyConfig
|
||||
bot *discordgo.Session
|
||||
service *Service
|
||||
}
|
||||
|
||||
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
|
||||
discord, err := discordgo.New("Bot " + config.MjConfig.BotToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.ProxyURL != "" {
|
||||
proxy, _ := url.Parse(config.ProxyURL)
|
||||
discord.Client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxy),
|
||||
},
|
||||
}
|
||||
discord.Dialer = &websocket.Dialer{
|
||||
Proxy: http.ProxyURL(proxy),
|
||||
}
|
||||
}
|
||||
|
||||
return &Bot{
|
||||
config: &config.MjConfig,
|
||||
bot: discord,
|
||||
service: service,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *Bot) Run() error {
|
||||
b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
|
||||
b.bot.AddHandler(b.messageCreate)
|
||||
b.bot.AddHandler(b.messageUpdate)
|
||||
|
||||
logger.Info("Starting MidJourney Bot...")
|
||||
err := b.bot.Open()
|
||||
if err != nil {
|
||||
logger.Error("Error opening Discord connection:", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Starting MidJourney Bot successfully!")
|
||||
return nil
|
||||
}
|
||||
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
Start = TaskStatus("Started")
|
||||
Running = TaskStatus("Running")
|
||||
Stopped = TaskStatus("Stopped")
|
||||
Finished = TaskStatus("Finished")
|
||||
)
|
||||
|
||||
type Image struct {
|
||||
URL string `json:"url"`
|
||||
ProxyURL string `json:"proxy_url"`
|
||||
Filename string `json:"filename"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
Size int `json:"size"`
|
||||
Hash string `json:"hash"`
|
||||
}
|
||||
|
||||
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
// ignore messages for other channels
|
||||
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
|
||||
return
|
||||
}
|
||||
// ignore messages for self
|
||||
if m.Author.ID == s.State.User.ID {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debugf("CREATE: %s", utils.JsonEncode(m))
|
||||
var referenceId = ""
|
||||
if m.ReferencedMessage != nil {
|
||||
referenceId = m.ReferencedMessage.ID
|
||||
}
|
||||
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
|
||||
// parse content
|
||||
req := CBReq{
|
||||
MessageId: m.ID,
|
||||
ReferenceId: referenceId,
|
||||
Prompt: extractPrompt(m.Content),
|
||||
Content: m.Content,
|
||||
Progress: 0,
|
||||
Status: Start}
|
||||
b.service.Notify(req)
|
||||
return
|
||||
}
|
||||
|
||||
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
|
||||
}
|
||||
|
||||
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
||||
// ignore messages for other channels
|
||||
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
|
||||
return
|
||||
}
|
||||
// ignore messages for self
|
||||
if m.Author.ID == s.State.User.ID {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
|
||||
|
||||
var referenceId = ""
|
||||
if m.ReferencedMessage != nil {
|
||||
referenceId = m.ReferencedMessage.ID
|
||||
}
|
||||
if strings.Contains(m.Content, "(Stopped)") {
|
||||
req := CBReq{
|
||||
MessageId: m.ID,
|
||||
ReferenceId: referenceId,
|
||||
Prompt: extractPrompt(m.Content),
|
||||
Content: m.Content,
|
||||
Progress: extractProgress(m.Content),
|
||||
Status: Stopped}
|
||||
b.service.Notify(req)
|
||||
return
|
||||
}
|
||||
|
||||
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
|
||||
|
||||
}
|
||||
|
||||
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
|
||||
progress := extractProgress(content)
|
||||
var status TaskStatus
|
||||
if progress == 100 {
|
||||
status = Finished
|
||||
} else {
|
||||
status = Running
|
||||
}
|
||||
for _, attachment := range attachments {
|
||||
if attachment.Width == 0 || attachment.Height == 0 {
|
||||
continue
|
||||
}
|
||||
image := Image{
|
||||
URL: attachment.URL,
|
||||
Height: attachment.Height,
|
||||
ProxyURL: attachment.ProxyURL,
|
||||
Width: attachment.Width,
|
||||
Size: attachment.Size,
|
||||
Filename: attachment.Filename,
|
||||
Hash: extractHashFromFilename(attachment.Filename),
|
||||
}
|
||||
req := CBReq{
|
||||
MessageId: messageId,
|
||||
ReferenceId: referenceId,
|
||||
Image: image,
|
||||
Prompt: extractPrompt(content),
|
||||
Content: content,
|
||||
Progress: progress,
|
||||
Status: status,
|
||||
}
|
||||
b.service.Notify(req)
|
||||
break // only get one image
|
||||
}
|
||||
}
|
||||
|
||||
// extract prompt from string
|
||||
func extractPrompt(input string) string {
|
||||
pattern := `\*\*(.*?)\*\*`
|
||||
re := regexp.MustCompile(pattern)
|
||||
matches := re.FindStringSubmatch(input)
|
||||
if len(matches) > 1 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractProgress(input string) int {
|
||||
pattern := `\((\d+)\%\)`
|
||||
re := regexp.MustCompile(pattern)
|
||||
matches := re.FindStringSubmatch(input)
|
||||
if len(matches) > 1 {
|
||||
return utils.IntValue(matches[1], 0)
|
||||
}
|
||||
return 100
|
||||
}
|
||||
|
||||
func extractHashFromFilename(filename string) string {
|
||||
if !strings.HasSuffix(filename, ".png") {
|
||||
return ""
|
||||
}
|
||||
|
||||
index := strings.LastIndex(filename, "_")
|
||||
if index != -1 {
|
||||
return filename[index+1 : len(filename)-4]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
144
api/service/mj/client.go
Normal file
144
api/service/mj/client.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package mj
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MidJourney client
|
||||
|
||||
type Client struct {
|
||||
client *req.Client
|
||||
config *types.MidJourneyConfig
|
||||
}
|
||||
|
||||
func NewClient(config *types.AppConfig) *Client {
|
||||
client := req.C().SetTimeout(10 * time.Second)
|
||||
// set proxy URL
|
||||
if config.ProxyURL != "" {
|
||||
client.SetProxyURL(config.ProxyURL)
|
||||
}
|
||||
return &Client{client: client, config: &config.MjConfig}
|
||||
}
|
||||
|
||||
func (c *Client) Imagine(prompt string) error {
|
||||
interactionsReq := &InteractionsRequest{
|
||||
Type: 2,
|
||||
ApplicationID: ApplicationID,
|
||||
GuildID: c.config.GuildId,
|
||||
ChannelID: c.config.ChanelId,
|
||||
SessionID: SessionID,
|
||||
Data: map[string]any{
|
||||
"version": "1118961510123847772",
|
||||
"id": "938956540159881230",
|
||||
"name": "imagine",
|
||||
"type": "1",
|
||||
"options": []map[string]any{
|
||||
{
|
||||
"type": 3,
|
||||
"name": "prompt",
|
||||
"value": prompt,
|
||||
},
|
||||
},
|
||||
"application_command": map[string]any{
|
||||
"id": "938956540159881230",
|
||||
"application_id": ApplicationID,
|
||||
"version": "1118961510123847772",
|
||||
"default_permission": true,
|
||||
"default_member_permissions": nil,
|
||||
"type": 1,
|
||||
"nsfw": false,
|
||||
"name": "imagine",
|
||||
"description": "Create images with Midjourney",
|
||||
"dm_permission": true,
|
||||
"options": []map[string]any{
|
||||
{
|
||||
"type": 3,
|
||||
"name": "prompt",
|
||||
"description": "The prompt to imagine",
|
||||
"required": true,
|
||||
},
|
||||
},
|
||||
"attachments": []any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
url := "https://discord.com/api/v9/interactions"
|
||||
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(interactionsReq).
|
||||
Post(url)
|
||||
|
||||
if err != nil || r.IsErrorState() {
|
||||
return fmt.Errorf("error with http request: %w%v", err, r.Err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Upscale 放大指定的图片
|
||||
func (c *Client) Upscale(index int, messageId string, hash string) error {
|
||||
flags := 0
|
||||
interactionsReq := &InteractionsRequest{
|
||||
Type: 3,
|
||||
ApplicationID: ApplicationID,
|
||||
GuildID: c.config.GuildId,
|
||||
ChannelID: c.config.ChanelId,
|
||||
MessageFlags: &flags,
|
||||
MessageID: &messageId,
|
||||
SessionID: SessionID,
|
||||
Data: map[string]any{
|
||||
"component_type": 2,
|
||||
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
|
||||
},
|
||||
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
}
|
||||
|
||||
url := "https://discord.com/api/v9/interactions"
|
||||
var res InteractionsResult
|
||||
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(interactionsReq).
|
||||
SetErrorResult(&res).
|
||||
Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||
func (c *Client) Variation(index int, messageId string, hash string) error {
|
||||
flags := 0
|
||||
interactionsReq := &InteractionsRequest{
|
||||
Type: 3,
|
||||
ApplicationID: ApplicationID,
|
||||
GuildID: c.config.GuildId,
|
||||
ChannelID: c.config.ChanelId,
|
||||
MessageFlags: &flags,
|
||||
MessageID: &messageId,
|
||||
SessionID: SessionID,
|
||||
Data: map[string]any{
|
||||
"component_type": 2,
|
||||
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
|
||||
},
|
||||
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
}
|
||||
|
||||
url := "https://discord.com/api/v9/interactions"
|
||||
var res InteractionsResult
|
||||
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(interactionsReq).
|
||||
SetErrorResult(&res).
|
||||
Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
249
api/service/mj/service.go
Normal file
249
api/service/mj/service.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package mj
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/oss"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MJ 绘画服务
|
||||
|
||||
const RunningJobKey = "MidJourney_Running_Job"
|
||||
|
||||
type Service struct {
|
||||
client *Client
|
||||
taskQueue *store.RedisQueue
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
uploadManager *oss.UploaderManager
|
||||
Clients *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息
|
||||
ChatClients *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
|
||||
return &Service{
|
||||
redis: redisCli,
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||
client: client,
|
||||
uploadManager: manager,
|
||||
Clients: types.NewLMap[string, *types.WsClient](),
|
||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||
proxyURL: config.ProxyURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Run() {
|
||||
logger.Info("Starting MidJourney job consumer.")
|
||||
ctx := context.Background()
|
||||
for {
|
||||
_, err := s.redis.Get(ctx, RunningJobKey).Result()
|
||||
if err == nil { // 队列串行执行
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
var task types.MjTask
|
||||
err = s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("Consuming Task: %+v", task)
|
||||
switch task.Type {
|
||||
case types.TaskImage:
|
||||
err = s.client.Imagine(task.Prompt)
|
||||
break
|
||||
case types.TaskUpscale:
|
||||
err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash)
|
||||
|
||||
break
|
||||
case types.TaskVariation:
|
||||
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
if task.RetryCount <= 5 {
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
task.RetryCount += 1
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务的执行状态
|
||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||||
// 锁定任务执行通道,直到任务超时(5分钟)
|
||||
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) PushTask(task types.MjTask) {
|
||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *Service) Notify(data CBReq) {
|
||||
taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result()
|
||||
if err != nil { // 过期任务,丢弃
|
||||
logger.Warn("任务已过期:", err)
|
||||
return
|
||||
}
|
||||
|
||||
var task types.MjTask
|
||||
err = utils.JsonDecode(taskString, &task)
|
||||
if err != nil { // 非标准任务,丢弃
|
||||
logger.Warn("任务解析失败:", err)
|
||||
return
|
||||
}
|
||||
|
||||
var job model.MidJourneyJob
|
||||
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
|
||||
if res.Error == nil && data.Status == Finished {
|
||||
logger.Warn("重复消息:", data.MessageId)
|
||||
return
|
||||
}
|
||||
|
||||
if task.Src == types.TaskSrcImg { // 绘画任务
|
||||
var job model.MidJourneyJob
|
||||
res := s.db.Where("id = ?", task.Id).First(&job)
|
||||
if res.Error != nil {
|
||||
logger.Warn("非法任务:", res.Error)
|
||||
return
|
||||
}
|
||||
job.MessageId = data.MessageId
|
||||
job.ReferenceId = data.ReferenceId
|
||||
job.Progress = data.Progress
|
||||
job.Prompt = data.Prompt
|
||||
job.Hash = data.Image.Hash
|
||||
|
||||
// 任务完成,将最终的图片下载下来
|
||||
if data.Progress == 100 {
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||
if err != nil {
|
||||
logger.Error("error with download img: ", err.Error())
|
||||
return
|
||||
}
|
||||
job.ImgURL = imgURL
|
||||
} else {
|
||||
// 临时图片直接保存,访问的时候使用代理进行转发
|
||||
job.ImgURL = data.Image.URL
|
||||
}
|
||||
res = s.db.Updates(&job)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update job: ", res.Error)
|
||||
return
|
||||
}
|
||||
|
||||
var jobVo vo.MidJourneyJob
|
||||
err := utils.CopyObject(job, &jobVo)
|
||||
if err == nil {
|
||||
if data.Progress < 100 {
|
||||
image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL)
|
||||
if err == nil {
|
||||
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
|
||||
// 推送任务到前端
|
||||
client := s.Clients.Get(task.SessionId)
|
||||
if client != nil {
|
||||
utils.ReplyChunkMessage(client, jobVo)
|
||||
}
|
||||
}
|
||||
|
||||
} else if task.Src == types.TaskSrcChat { // 聊天任务
|
||||
wsClient := s.ChatClients.Get(task.SessionId)
|
||||
if data.Status == Finished {
|
||||
if wsClient != nil && data.ReferenceId != "" {
|
||||
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
|
||||
utils.ReplyMessage(wsClient, content)
|
||||
}
|
||||
// download image
|
||||
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||
if err != nil {
|
||||
logger.Error("error with download image: ", err)
|
||||
if wsClient != nil && data.ReferenceId != "" {
|
||||
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
|
||||
utils.ReplyMessage(wsClient, content)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
data.Image.URL = imgURL
|
||||
message := model.HistoryMessage{
|
||||
UserId: uint(task.UserId),
|
||||
ChatId: task.ChatId,
|
||||
RoleId: uint(task.RoleId),
|
||||
Type: types.MjMsg,
|
||||
Icon: task.Icon,
|
||||
Content: utils.JsonEncode(data),
|
||||
Tokens: 0,
|
||||
UseContext: false,
|
||||
}
|
||||
res = tx.Create(&message)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
// save the job
|
||||
job.UserId = task.UserId
|
||||
job.Type = task.Type.String()
|
||||
job.MessageId = data.MessageId
|
||||
job.ReferenceId = data.ReferenceId
|
||||
job.Prompt = data.Prompt
|
||||
job.ImgURL = imgURL
|
||||
job.Progress = data.Progress
|
||||
job.Hash = data.Image.Hash
|
||||
job.CreatedAt = time.Now()
|
||||
res = tx.Create(&job)
|
||||
if res.Error != nil {
|
||||
logger.Error("error with update database: ", err)
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
if wsClient == nil { // 客户端断线,则丢弃
|
||||
logger.Errorf("Client is offline: %+v", data)
|
||||
return
|
||||
}
|
||||
|
||||
if data.Status == Finished {
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
|
||||
// 本次绘画完毕,移除客户端
|
||||
s.ChatClients.Delete(task.SessionId)
|
||||
} else {
|
||||
// 使用代理临时转发图片
|
||||
if data.Image.URL != "" {
|
||||
image, err := utils.DownloadImage(data.Image.URL, s.proxyURL)
|
||||
if err == nil {
|
||||
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||
}
|
||||
}
|
||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||
}
|
||||
}
|
||||
|
||||
// 更新用户剩余绘图次数
|
||||
// TODO: 放大图片是否需要消耗绘图次数?
|
||||
if data.Status == Finished {
|
||||
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||
// 解除任务锁定
|
||||
s.redis.Del(context.Background(), RunningJobKey)
|
||||
}
|
||||
|
||||
}
|
||||
34
api/service/mj/types.go
Normal file
34
api/service/mj/types.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package mj
|
||||
|
||||
const (
|
||||
ApplicationID string = "936929561302675456"
|
||||
SessionID string = "ea8816d857ba9ae2f74c59ae1a953afe"
|
||||
)
|
||||
|
||||
type InteractionsRequest struct {
|
||||
Type int `json:"type"`
|
||||
ApplicationID string `json:"application_id"`
|
||||
MessageFlags *int `json:"message_flags,omitempty"`
|
||||
MessageID *string `json:"message_id,omitempty"`
|
||||
GuildID string `json:"guild_id"`
|
||||
ChannelID string `json:"channel_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Data map[string]any `json:"data"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
}
|
||||
|
||||
type InteractionsResult struct {
|
||||
Code int `json:"code"`
|
||||
Message string
|
||||
Error map[string]any
|
||||
}
|
||||
|
||||
type CBReq struct {
|
||||
MessageId string `json:"message_id"`
|
||||
ReferenceId string `json:"reference_id"`
|
||||
Image Image `json:"image"`
|
||||
Content string `json:"content"`
|
||||
Prompt string `json:"prompt"`
|
||||
Status TaskStatus `json:"status"`
|
||||
Progress int `json:"progress"`
|
||||
}
|
||||
@@ -1,166 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
// MJ 绘画服务
|
||||
|
||||
const MjRunningJobKey = "MidJourney_Running_Job"
|
||||
|
||||
type MjService struct {
|
||||
config types.ChatPlusExtConfig
|
||||
client *req.Client
|
||||
taskQueue *store.RedisQueue
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewMjService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *MjService {
|
||||
return &MjService{
|
||||
config: appConfig.ExtConfig,
|
||||
redis: client,
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("midjourney_task_queue", client),
|
||||
client: req.C().SetTimeout(30 * time.Second)}
|
||||
}
|
||||
|
||||
func (s *MjService) Run() {
|
||||
logger.Info("Starting MidJourney job consumer.")
|
||||
ctx := context.Background()
|
||||
for {
|
||||
_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
|
||||
if err == nil { // 队列串行执行
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
var task types.MjTask
|
||||
err = s.taskQueue.LPop(&task)
|
||||
if err != nil {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("Consuming Task: %+v", task)
|
||||
switch task.Type {
|
||||
case types.TaskImage:
|
||||
err = s.image(task.Prompt)
|
||||
break
|
||||
case types.TaskUpscale:
|
||||
err = s.upscale(MjUpscaleReq{
|
||||
Index: task.Index,
|
||||
MessageId: task.MessageId,
|
||||
MessageHash: task.MessageHash,
|
||||
})
|
||||
break
|
||||
case types.TaskVariation:
|
||||
err = s.variation(MjVariationReq{
|
||||
Index: task.Index,
|
||||
MessageId: task.MessageId,
|
||||
MessageHash: task.MessageHash,
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
if task.RetryCount <= 5 {
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
task.RetryCount += 1
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新任务的执行状态
|
||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||||
// 锁定任务执行通道,直到任务超时(5分钟)
|
||||
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MjService) PushTask(task types.MjTask) {
|
||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *MjService) image(prompt string) error {
|
||||
logger.Infof("MJ 绘画参数:%+v", prompt)
|
||||
body := map[string]string{"prompt": prompt}
|
||||
url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("Authorization", s.config.Token).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return fmt.Errorf("%v%v", r.String(), err)
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
return errors.New(res.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type MjUpscaleReq struct {
|
||||
Index int32 `json:"index"`
|
||||
MessageId string `json:"message_id"`
|
||||
MessageHash string `json:"message_hash"`
|
||||
}
|
||||
|
||||
func (s *MjService) upscale(upReq MjUpscaleReq) error {
|
||||
url := fmt.Sprintf("%s/api/mj/upscale", s.config.ApiURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("Authorization", s.config.Token).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(upReq).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return fmt.Errorf("%v%v", r.String(), err)
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
return errors.New(res.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type MjVariationReq struct {
|
||||
Index int32 `json:"index"`
|
||||
MessageId string `json:"message_id"`
|
||||
MessageHash string `json:"message_hash"`
|
||||
}
|
||||
|
||||
func (s *MjService) variation(upReq MjVariationReq) error {
|
||||
url := fmt.Sprintf("%s/api/mj/variation", s.config.ApiURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("Authorization", s.config.Token).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(upReq).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return fmt.Errorf("%v%v", r.String(), err)
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
return errors.New(res.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
169
api/service/sd/client.go
Normal file
169
api/service/sd/client.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package sd
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
httpClient *req.Client
|
||||
config *types.StableDiffusionConfig
|
||||
}
|
||||
|
||||
func NewSdClient(config *types.AppConfig) *Client {
|
||||
return &Client{
|
||||
config: &config.SdConfig,
|
||||
httpClient: req.C(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Txt2Img(params types.SdTaskParams) error {
|
||||
var data []interface{}
|
||||
err := utils.JsonDecode(Text2ImgParamTemplate, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[ParamKeys["task_id"]] = params.TaskId
|
||||
data[ParamKeys["prompt"]] = params.Prompt
|
||||
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
|
||||
data[ParamKeys["steps"]] = params.Steps
|
||||
data[ParamKeys["sampler"]] = params.Sampler
|
||||
data[ParamKeys["face_fix"]] = params.FaceFix
|
||||
data[ParamKeys["cfg_scale"]] = params.CfgScale
|
||||
data[ParamKeys["seed"]] = params.Seed
|
||||
data[ParamKeys["height"]] = params.Height
|
||||
data[ParamKeys["width"]] = params.Width
|
||||
data[ParamKeys["hd_fix"]] = params.HdFix
|
||||
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
|
||||
data[ParamKeys["hd_scale"]] = params.HdScale
|
||||
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
|
||||
data[ParamKeys["hd_sample_num"]] = params.HdSampleNum
|
||||
task := TaskInfo{
|
||||
TaskId: params.TaskId,
|
||||
Data: data,
|
||||
EventData: nil,
|
||||
FnIndex: 494,
|
||||
SessionHash: "ycaxgzm9ah",
|
||||
}
|
||||
|
||||
go func() {
|
||||
c.runTask(task, c.httpClient)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) runTask(taskInfo TaskInfo, client *req.Client) {
|
||||
body := map[string]any{
|
||||
"data": taskInfo.Data,
|
||||
"event_data": taskInfo.EventData,
|
||||
"fn_index": taskInfo.FnIndex,
|
||||
"session_hash": taskInfo.SessionHash,
|
||||
}
|
||||
|
||||
var result = make(chan CBReq)
|
||||
go func() {
|
||||
var res struct {
|
||||
Data []interface{} `json:"data"`
|
||||
IsGenerating bool `json:"is_generating"`
|
||||
Duration float64 `json:"duration"`
|
||||
AverageDuration float64 `json:"average_duration"`
|
||||
}
|
||||
var cbReq = CBReq{TaskId: taskInfo.TaskId}
|
||||
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(c.config.ApiURL + "/run/predict")
|
||||
if err != nil {
|
||||
cbReq.Message = "error with send request: " + err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
if response.IsErrorState() {
|
||||
bytes, _ := io.ReadAll(response.Body)
|
||||
cbReq.Message = "error http status code: " + string(bytes)
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
var images []struct {
|
||||
Name string `json:"name"`
|
||||
Data interface{} `json:"data"`
|
||||
IsFile bool `json:"is_file"`
|
||||
}
|
||||
err = utils.ForceCovert(res.Data[0], &images)
|
||||
if err != nil {
|
||||
cbReq.Message = "error with decode image:" + err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
var info map[string]any
|
||||
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
|
||||
if err != nil {
|
||||
cbReq.Message = err.Error()
|
||||
cbReq.Success = false
|
||||
result <- cbReq
|
||||
return
|
||||
}
|
||||
|
||||
//for k, v := range info {
|
||||
// fmt.Println(k, " => ", v)
|
||||
//}
|
||||
cbReq.ImageName = images[0].Name
|
||||
cbReq.Seed = utils.InterfaceToString(info["seed"])
|
||||
cbReq.Success = true
|
||||
cbReq.Progress = 100
|
||||
result <- cbReq
|
||||
close(result)
|
||||
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case value := <-result:
|
||||
if value.Success {
|
||||
logger.Infof("%s/file=%s", c.config.ApiURL, value.ImageName)
|
||||
}
|
||||
return
|
||||
default:
|
||||
var progressReq = map[string]any{
|
||||
"id_task": taskInfo.TaskId,
|
||||
"id_live_preview": 1,
|
||||
}
|
||||
|
||||
var progressRes struct {
|
||||
Active bool `json:"active"`
|
||||
Queued bool `json:"queued"`
|
||||
Completed bool `json:"completed"`
|
||||
Progress float64 `json:"progress"`
|
||||
Eta float64 `json:"eta"`
|
||||
LivePreview string `json:"live_preview"`
|
||||
IDLivePreview int `json:"id_live_preview"`
|
||||
TextInfo interface{} `json:"textinfo"`
|
||||
}
|
||||
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(c.config.ApiURL + "/internal/progress")
|
||||
var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true}
|
||||
if err != nil { // TODO: 这里可以考虑设置失败重试次数
|
||||
logger.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if response.IsErrorState() {
|
||||
bytes, _ := io.ReadAll(response.Body)
|
||||
logger.Error(string(bytes))
|
||||
return
|
||||
}
|
||||
|
||||
cbReq.ImageData = progressRes.LivePreview
|
||||
cbReq.Progress = int(progressRes.Progress * 100)
|
||||
fmt.Println("Progress: ", progressRes.Progress)
|
||||
fmt.Println("Image: ", progressRes.LivePreview)
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,45 +1,42 @@
|
||||
package service
|
||||
package sd
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/service/mj"
|
||||
"chatplus/store"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SD 绘画服务
|
||||
|
||||
const SdRunningJobKey = "StableDiffusion_Running_Job"
|
||||
const RunningJobKey = "StableDiffusion_Running_Job"
|
||||
|
||||
type SdService struct {
|
||||
config types.ChatPlusExtConfig
|
||||
client *req.Client
|
||||
type Service struct {
|
||||
taskQueue *store.RedisQueue
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
Client *Client
|
||||
}
|
||||
|
||||
func NewSdService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *SdService {
|
||||
return &SdService{
|
||||
config: appConfig.ExtConfig,
|
||||
redis: client,
|
||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client) *Service {
|
||||
return &Service{
|
||||
redis: redisCli,
|
||||
db: db,
|
||||
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", client),
|
||||
client: req.C().SetTimeout(30 * time.Second)}
|
||||
Client: client,
|
||||
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SdService) Run() {
|
||||
func (s *Service) Run() {
|
||||
logger.Info("Starting StableDiffusion job consumer.")
|
||||
ctx := context.Background()
|
||||
for {
|
||||
_, err := s.redis.Get(ctx, SdRunningJobKey).Result()
|
||||
_, err := s.redis.Get(ctx, RunningJobKey).Result()
|
||||
if err == nil { // 队列串行执行
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
@@ -51,7 +48,7 @@ func (s *SdService) Run() {
|
||||
continue
|
||||
}
|
||||
logger.Infof("Consuming Task: %+v", task)
|
||||
err = s.txt2img(task.Params)
|
||||
err = s.Client.Txt2Img(task.Params)
|
||||
if err != nil {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
if task.RetryCount <= 5 {
|
||||
@@ -65,31 +62,11 @@ func (s *SdService) Run() {
|
||||
// 更新任务的执行状态
|
||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||||
// 锁定任务执行通道,直到任务超时(5分钟)
|
||||
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||||
s.redis.Set(ctx, mj.RunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SdService) PushTask(task types.SdTask) {
|
||||
func (s *Service) PushTask(task types.SdTask) {
|
||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
func (s *SdService) txt2img(params types.SdParams) error {
|
||||
logger.Infof("SD 绘画参数:%+v", params)
|
||||
url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("Authorization", s.config.Token).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(params).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return fmt.Errorf("%v%v", r.String(), err)
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
return errors.New(res.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
234
api/service/sd/types.go
Normal file
234
api/service/sd/types.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package sd
|
||||
|
||||
import logger2 "chatplus/logger"
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type TaskInfo struct {
|
||||
TaskId string `json:"task_id"`
|
||||
Data interface{} `json:"data"`
|
||||
EventData interface{} `json:"event_data"`
|
||||
FnIndex int `json:"fn_index"`
|
||||
SessionHash string `json:"session_hash"`
|
||||
}
|
||||
|
||||
type CBReq struct {
|
||||
TaskId string
|
||||
ImageName string
|
||||
ImageData string
|
||||
Progress int
|
||||
Seed string
|
||||
Success bool
|
||||
Message string
|
||||
}
|
||||
|
||||
var ParamKeys = map[string]int{
|
||||
"task_id": 0,
|
||||
"prompt": 1,
|
||||
"negative_prompt": 2,
|
||||
"steps": 4,
|
||||
"sampler": 5,
|
||||
"face_fix": 6,
|
||||
"cfg_scale": 10,
|
||||
"seed": 11,
|
||||
"height": 17,
|
||||
"width": 18,
|
||||
"hd_fix": 19,
|
||||
"hd_redraw_rate": 20, //高清修复重绘幅度
|
||||
"hd_scale": 21, // 高清修复放大倍数
|
||||
"hd_scale_alg": 22, // 高清修复放大算法
|
||||
"hd_sample_num": 23, // 高清修复采样次数
|
||||
}
|
||||
|
||||
const Text2ImgParamTemplate = `[
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
[],
|
||||
30,
|
||||
"DPM++ SDE Karras",
|
||||
false,
|
||||
false,
|
||||
1,
|
||||
1,
|
||||
7.5,
|
||||
-1,
|
||||
-1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
512,
|
||||
512,
|
||||
true,
|
||||
0.7,
|
||||
2,
|
||||
"Latent",
|
||||
10,
|
||||
0,
|
||||
0,
|
||||
"Use same sampler",
|
||||
"",
|
||||
"",
|
||||
[],
|
||||
"None",
|
||||
false,
|
||||
"MultiDiffusion",
|
||||
false,
|
||||
true,
|
||||
1024,
|
||||
1024,
|
||||
96,
|
||||
96,
|
||||
48,
|
||||
4,
|
||||
"None",
|
||||
2,
|
||||
false,
|
||||
10,
|
||||
1,
|
||||
1,
|
||||
64,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
0.4,
|
||||
0.4,
|
||||
0.2,
|
||||
0.2,
|
||||
"",
|
||||
"",
|
||||
"Background",
|
||||
0.2,
|
||||
-1,
|
||||
false,
|
||||
3072,
|
||||
192,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
"",
|
||||
0.5,
|
||||
true,
|
||||
false,
|
||||
"",
|
||||
"Lerp",
|
||||
false,
|
||||
"🔄",
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
"positive",
|
||||
"comma",
|
||||
0,
|
||||
false,
|
||||
false,
|
||||
"",
|
||||
"Seed",
|
||||
"",
|
||||
[],
|
||||
"Nothing",
|
||||
"",
|
||||
[],
|
||||
"Nothing",
|
||||
"",
|
||||
[],
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
false,
|
||||
50
|
||||
]`
|
||||
87
api/service/wx/bot.go
Normal file
87
api/service/wx/bot.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package wx
|
||||
|
||||
import (
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/store/model"
|
||||
"github.com/eatmoreapple/openwechat"
|
||||
"github.com/skip2/go-qrcode"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 微信收款机器人
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type Bot struct {
|
||||
bot *openwechat.Bot
|
||||
token string
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewWeChatBot(db *gorm.DB) *Bot {
|
||||
bot := openwechat.DefaultBot(openwechat.Desktop)
|
||||
return &Bot{
|
||||
bot: bot,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bot) Run() error {
|
||||
logger.Info("Starting WeChat Bot...")
|
||||
|
||||
// set message handler
|
||||
b.bot.MessageHandler = func(msg *openwechat.Message) {
|
||||
b.messageHandler(msg)
|
||||
}
|
||||
// scan code login callback
|
||||
b.bot.UUIDCallback = b.qrCodeCallBack
|
||||
|
||||
err := b.bot.Login()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("微信登录成功!")
|
||||
return nil
|
||||
}
|
||||
|
||||
// message handler
|
||||
func (b *Bot) messageHandler(msg *openwechat.Message) {
|
||||
sender, err := msg.Sender()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 只处理微信支付的推送消息
|
||||
if sender.NickName == "微信支付" ||
|
||||
msg.MsgType == openwechat.MsgTypeApp ||
|
||||
msg.AppMsgType == openwechat.AppMsgTypeUrl {
|
||||
// 解析支付金额
|
||||
message, err := parseTransactionMessage(msg.Content)
|
||||
if err == nil {
|
||||
transaction := extractTransaction(message)
|
||||
logger.Infof("解析到收款信息:%+v", transaction)
|
||||
var item model.Reward
|
||||
res := b.db.Where("tx_id = ?", transaction.TransId).First(&item)
|
||||
if res.Error == nil {
|
||||
logger.Error("当前交易 ID 己经存在!")
|
||||
return
|
||||
}
|
||||
|
||||
res = b.db.Create(&model.Reward{
|
||||
TxId: transaction.TransId,
|
||||
Amount: transaction.Amount,
|
||||
Remark: transaction.Remark,
|
||||
Status: false,
|
||||
})
|
||||
if res.Error != nil {
|
||||
logger.Errorf("交易保存失败: %v", res.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bot) qrCodeCallBack(uuid string) {
|
||||
logger.Info("请使用微信扫描下面二维码登录")
|
||||
q, _ := qrcode.New("https://login.weixin.qq.com/l/"+uuid, qrcode.Medium)
|
||||
logger.Info(q.ToString(true))
|
||||
}
|
||||
68
api/service/wx/tranaction.go
Normal file
68
api/service/wx/tranaction.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package wx
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Message 转账消息
|
||||
type Message struct {
|
||||
XMLName xml.Name `xml:"msg"`
|
||||
AppMsg struct {
|
||||
Des string `xml:"des"`
|
||||
Url string `xml:"url"`
|
||||
} `xml:"appmsg"`
|
||||
}
|
||||
|
||||
// Transaction 解析后的交易信息
|
||||
type Transaction struct {
|
||||
TransId string `json:"trans_id"` // 微信转账交易 ID
|
||||
Amount float64 `json:"amount"` // 微信转账交易金额
|
||||
Remark string `json:"remark"` // 转账备注
|
||||
}
|
||||
|
||||
// 解析微信转账消息
|
||||
func parseTransactionMessage(xmlData string) (*Message, error) {
|
||||
var msg Message
|
||||
if err := xml.Unmarshal([]byte(xmlData), &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
// 导出交易信息
|
||||
func extractTransaction(message *Message) Transaction {
|
||||
var tx = Transaction{}
|
||||
// 导出交易金额和备注
|
||||
lines := strings.Split(message.AppMsg.Des, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
// 解析收款金额
|
||||
prefix := "收款金额¥"
|
||||
if strings.HasPrefix(line, prefix) {
|
||||
if value, err := strconv.ParseFloat(line[len(prefix):], 64); err == nil {
|
||||
tx.Amount = value
|
||||
continue
|
||||
}
|
||||
}
|
||||
// 解析收款备注
|
||||
prefix = "付款方备注"
|
||||
if strings.HasPrefix(line, prefix) {
|
||||
tx.Remark = line[len(prefix):]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 解析交易 ID
|
||||
index := strings.Index(message.AppMsg.Url, "trans_id=")
|
||||
if index != -1 {
|
||||
end := strings.LastIndex(message.AppMsg.Url, "&")
|
||||
tx.TransId = strings.TrimSpace(message.AppMsg.Url[index+9 : end])
|
||||
}
|
||||
return tx
|
||||
}
|
||||
Reference in New Issue
Block a user