mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	feat: migrate the chatgpt-plus-ext project code to this project
This commit is contained in:
		@@ -2,7 +2,7 @@ package core
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/function"
 | 
			
		||||
	"chatplus/service/fun"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
@@ -33,11 +33,10 @@ type AppServer struct {
 | 
			
		||||
	ChatSession   *types.LMap[string, *types.ChatSession] //map[sessionId]UserId
 | 
			
		||||
	ChatClients   *types.LMap[string, *types.WsClient]    // map[sessionId]Websocket 连接集合
 | 
			
		||||
	ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
 | 
			
		||||
	Functions     map[string]function.Function
 | 
			
		||||
	MjTaskClients *types.LMap[string, *types.WsClient]
 | 
			
		||||
	Functions     map[string]fun.Function
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer {
 | 
			
		||||
func NewServer(appConfig *types.AppConfig, functions map[string]fun.Function) *AppServer {
 | 
			
		||||
	gin.SetMode(gin.ReleaseMode)
 | 
			
		||||
	gin.DefaultWriter = io.Discard
 | 
			
		||||
	return &AppServer{
 | 
			
		||||
@@ -48,7 +47,6 @@ func NewServer(appConfig *types.AppConfig, functions map[string]function.Functio
 | 
			
		||||
		ChatSession:   types.NewLMap[string, *types.ChatSession](),
 | 
			
		||||
		ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
 | 
			
		||||
		MjTaskClients: types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		Functions:     functions,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -26,7 +26,6 @@ func NewDefaultConfig() *types.AppConfig {
 | 
			
		||||
			MaxAge:    86400,
 | 
			
		||||
		},
 | 
			
		||||
		ApiConfig: types.ChatPlusApiConfig{},
 | 
			
		||||
		ExtConfig: types.ChatPlusExtConfig{Token: utils.RandString(32)},
 | 
			
		||||
		OSS: types.OSSConfig{
 | 
			
		||||
			Active: "local",
 | 
			
		||||
			Local: types.LocalStorageConfig{
 | 
			
		||||
@@ -34,6 +33,9 @@ func NewDefaultConfig() *types.AppConfig {
 | 
			
		||||
				BasePath: "./static/upload",
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		MjConfig:  types.MidJourneyConfig{Enabled: false},
 | 
			
		||||
		SdConfig:  types.StableDiffusionConfig{Enabled: false},
 | 
			
		||||
		WeChatBot: false,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -16,10 +16,11 @@ type AppConfig struct {
 | 
			
		||||
	Redis         RedisConfig       // redis 连接信息
 | 
			
		||||
	ApiConfig     ChatPlusApiConfig // ChatPlus API authorization configs
 | 
			
		||||
	AesEncryptKey string
 | 
			
		||||
	SmsConfig     AliYunSmsConfig   // AliYun send message service config
 | 
			
		||||
	ExtConfig     ChatPlusExtConfig // ChatPlus extensions callback api config
 | 
			
		||||
 | 
			
		||||
	OSS OSSConfig // OSS config
 | 
			
		||||
	SmsConfig     AliYunSmsConfig       // AliYun send message service config
 | 
			
		||||
	OSS           OSSConfig             // OSS config
 | 
			
		||||
	MjConfig      MidJourneyConfig      // mj 绘画配置
 | 
			
		||||
	WeChatBot     bool                  // 是否启用微信机器人
 | 
			
		||||
	SdConfig      StableDiffusionConfig // sd 绘画配置
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatPlusApiConfig struct {
 | 
			
		||||
@@ -28,9 +29,22 @@ type ChatPlusApiConfig struct {
 | 
			
		||||
	Token  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatPlusExtConfig struct {
 | 
			
		||||
	ApiURL string
 | 
			
		||||
	Token  string
 | 
			
		||||
type MidJourneyConfig struct {
 | 
			
		||||
	Enabled   bool
 | 
			
		||||
	UserToken string
 | 
			
		||||
	BotToken  string
 | 
			
		||||
	GuildId   string // Server ID
 | 
			
		||||
	ChanelId  string // Chanel ID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WeChatConfig struct {
 | 
			
		||||
	Enabled bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StableDiffusionConfig struct {
 | 
			
		||||
	Enabled bool
 | 
			
		||||
	ApiURL  string
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliYunSmsConfig struct {
 | 
			
		||||
 
 | 
			
		||||
@@ -33,14 +33,24 @@ type MjTask struct {
 | 
			
		||||
	ChatId      string   `json:"chat_id,omitempty"`
 | 
			
		||||
	RoleId      int      `json:"role_id,omitempty"`
 | 
			
		||||
	Icon        string   `json:"icon,omitempty"`
 | 
			
		||||
	Index       int32    `json:"index,omitempty"`
 | 
			
		||||
	Index       int      `json:"index,omitempty"`
 | 
			
		||||
	MessageId   string   `json:"message_id,omitempty"`
 | 
			
		||||
	MessageHash string   `json:"message_hash,omitempty"`
 | 
			
		||||
	RetryCount  int      `json:"retry_count"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SdParams stable diffusion 绘画参数
 | 
			
		||||
type SdParams struct {
 | 
			
		||||
type SdTask struct {
 | 
			
		||||
	Id         int          `json:"id"`
 | 
			
		||||
	SessionId  string       `json:"session_id"`
 | 
			
		||||
	Src        TaskSrc      `json:"src"`
 | 
			
		||||
	Type       TaskType     `json:"type"`
 | 
			
		||||
	UserId     int          `json:"user_id"`
 | 
			
		||||
	Prompt     string       `json:"prompt,omitempty"`
 | 
			
		||||
	Params     SdTaskParams `json:"params"`
 | 
			
		||||
	RetryCount int          `json:"retry_count"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SdTaskParams struct {
 | 
			
		||||
	TaskId         string  `json:"task_id"`
 | 
			
		||||
	Prompt         string  `json:"prompt"`
 | 
			
		||||
	NegativePrompt string  `json:"negative_prompt"`
 | 
			
		||||
@@ -57,14 +67,3 @@ type SdParams struct {
 | 
			
		||||
	HdScaleAlg     string  `json:"hd_scale_alg"`
 | 
			
		||||
	HdSampleNum    int     `json:"hd_sample_num"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SdTask struct {
 | 
			
		||||
	Id         int            `json:"id"`
 | 
			
		||||
	SessionId  string         `json:"session_id"`
 | 
			
		||||
	Src        types.TaskSrc  `json:"src"`
 | 
			
		||||
	Type       types.TaskType `json:"type"`
 | 
			
		||||
	UserId     int            `json:"user_id"`
 | 
			
		||||
	Prompt     string         `json:"prompt,omitempty"`
 | 
			
		||||
	Params     types.SdParams `json:"params"`
 | 
			
		||||
	RetryCount int            `json:"retry_count"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,9 @@ go 1.19
 | 
			
		||||
require (
 | 
			
		||||
	github.com/BurntSushi/toml v1.1.0
 | 
			
		||||
	github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
 | 
			
		||||
	github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
 | 
			
		||||
	github.com/bwmarrin/discordgo v0.27.1
 | 
			
		||||
	github.com/eatmoreapple/openwechat v1.2.1
 | 
			
		||||
	github.com/gin-gonic/gin v1.9.1
 | 
			
		||||
	github.com/go-redis/redis/v8 v8.11.5
 | 
			
		||||
	github.com/golang-jwt/jwt/v5 v5.0.0
 | 
			
		||||
@@ -14,6 +17,7 @@ require (
 | 
			
		||||
	github.com/minio/minio-go/v7 v7.0.62
 | 
			
		||||
	github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480
 | 
			
		||||
	github.com/qiniu/go-sdk/v7 v7.17.1
 | 
			
		||||
	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 | 
			
		||||
	github.com/syndtr/goleveldb v1.0.0
 | 
			
		||||
	go.uber.org/zap v1.23.0
 | 
			
		||||
	gopkg.in/natefinch/lumberjack.v2 v2.2.1
 | 
			
		||||
@@ -21,7 +25,6 @@ require (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible // indirect
 | 
			
		||||
	github.com/andybalholm/brotli v1.0.4 // indirect
 | 
			
		||||
	github.com/bytedance/sonic v1.9.1 // indirect
 | 
			
		||||
	github.com/cespare/xxhash/v2 v2.1.2 // indirect
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,8 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9
 | 
			
		||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
 | 
			
		||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
 | 
			
		||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
 | 
			
		||||
github.com/bwmarrin/discordgo v0.27.1 h1:ib9AIc/dom1E/fSIulrBwnez0CToJE113ZGt4HoliGY=
 | 
			
		||||
github.com/bwmarrin/discordgo v0.27.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
 | 
			
		||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 | 
			
		||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
 | 
			
		||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
 | 
			
		||||
@@ -25,6 +27,8 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0
 | 
			
		||||
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
 | 
			
		||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
 | 
			
		||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
 | 
			
		||||
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
 | 
			
		||||
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
 | 
			
		||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
 | 
			
		||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
 | 
			
		||||
@@ -75,6 +79,7 @@ github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9S
 | 
			
		||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
 | 
			
		||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
 | 
			
		||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
			
		||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 | 
			
		||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
 | 
			
		||||
@@ -168,6 +173,8 @@ github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
 | 
			
		||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
 | 
			
		||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
 | 
			
		||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 | 
			
		||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 | 
			
		||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
 | 
			
		||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 | 
			
		||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 | 
			
		||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
 | 
			
		||||
@@ -209,6 +216,7 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
 | 
			
		||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 | 
			
		||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
 | 
			
		||||
 
 | 
			
		||||
@@ -150,7 +150,7 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
					content := data
 | 
			
		||||
					if functionName == types.FuncMidJourney {
 | 
			
		||||
						content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
 | 
			
		||||
						h.App.MjTaskClients.Put(session.SessionId, ws)
 | 
			
		||||
						h.mjService.ChatClients.Put(session.SessionId, ws)
 | 
			
		||||
						// update user's img_calls
 | 
			
		||||
						h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
					}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
@@ -27,13 +28,14 @@ const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
 | 
			
		||||
 | 
			
		||||
type ChatHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	db      *gorm.DB
 | 
			
		||||
	leveldb *store.LevelDB
 | 
			
		||||
	redis   *redis.Client
 | 
			
		||||
	db        *gorm.DB
 | 
			
		||||
	leveldb   *store.LevelDB
 | 
			
		||||
	redis     *redis.Client
 | 
			
		||||
	mjService *mj.Service
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client) *ChatHandler {
 | 
			
		||||
	handler := ChatHandler{db: db, leveldb: levelDB, redis: redis}
 | 
			
		||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client, service *mj.Service) *ChatHandler {
 | 
			
		||||
	handler := ChatHandler{db: db, leveldb: levelDB, redis: redis, mjService: service}
 | 
			
		||||
	handler.App = app
 | 
			
		||||
	return &handler
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,7 @@ package handler
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
@@ -21,28 +21,11 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TaskStatus string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	Stopped  = TaskStatus("Stopped")
 | 
			
		||||
	Finished = TaskStatus("Finished")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Image struct {
 | 
			
		||||
	URL      string `json:"url"`
 | 
			
		||||
	ProxyURL string `json:"proxy_url"`
 | 
			
		||||
	Filename string `json:"filename"`
 | 
			
		||||
	Width    int    `json:"width"`
 | 
			
		||||
	Height   int    `json:"height"`
 | 
			
		||||
	Size     int    `json:"size"`
 | 
			
		||||
	Hash     string `json:"hash"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MidJourneyHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	redis           *redis.Client
 | 
			
		||||
	db              *gorm.DB
 | 
			
		||||
	mjService       *service.MjService
 | 
			
		||||
	mjService       *mj.Service
 | 
			
		||||
	uploaderManager *oss.UploaderManager
 | 
			
		||||
	lock            sync.Mutex
 | 
			
		||||
	clients         *types.LMap[string, *types.WsClient]
 | 
			
		||||
@@ -53,7 +36,7 @@ func NewMidJourneyHandler(
 | 
			
		||||
	client *redis.Client,
 | 
			
		||||
	db *gorm.DB,
 | 
			
		||||
	manager *oss.UploaderManager,
 | 
			
		||||
	mjService *service.MjService) *MidJourneyHandler {
 | 
			
		||||
	mjService *mj.Service) *MidJourneyHandler {
 | 
			
		||||
	h := MidJourneyHandler{
 | 
			
		||||
		redis:           client,
 | 
			
		||||
		db:              db,
 | 
			
		||||
@@ -66,16 +49,6 @@ func NewMidJourneyHandler(
 | 
			
		||||
	return &h
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mjNotifyData struct {
 | 
			
		||||
	MessageId   string     `json:"message_id"`
 | 
			
		||||
	ReferenceId string     `json:"reference_id"`
 | 
			
		||||
	Image       Image      `json:"image"`
 | 
			
		||||
	Content     string     `json:"content"`
 | 
			
		||||
	Prompt      string     `json:"prompt"`
 | 
			
		||||
	Status      TaskStatus `json:"status"`
 | 
			
		||||
	Progress    int        `json:"progress"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
			
		||||
func (h *MidJourneyHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
@@ -92,189 +65,6 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.ClientIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MidJourneyHandler) Notify(c *gin.Context) {
 | 
			
		||||
	token := c.GetHeader("Authorization")
 | 
			
		||||
	if token != h.App.Config.ExtConfig.Token {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var data mjNotifyData
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	logger.Debugf("收到 MidJourney 回调请求:%+v", data)
 | 
			
		||||
 | 
			
		||||
	h.lock.Lock()
 | 
			
		||||
	defer h.lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	err, finished := h.notifyHandler(c, data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 解除任务锁定
 | 
			
		||||
	if finished && (data.Status == Finished || data.Status == Stopped) {
 | 
			
		||||
		h.redis.Del(c, service.MjRunningJobKey)
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data mjNotifyData) (error, bool) {
 | 
			
		||||
	taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
 | 
			
		||||
	if err != nil { // 过期任务,丢弃
 | 
			
		||||
		logger.Warn("任务已过期:", err)
 | 
			
		||||
		return nil, true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var task types.MjTask
 | 
			
		||||
	err = utils.JsonDecode(taskString, &task)
 | 
			
		||||
	if err != nil { // 非标准任务,丢弃
 | 
			
		||||
		logger.Warn("任务解析失败:", err)
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var job model.MidJourneyJob
 | 
			
		||||
	res := h.db.Where("message_id = ?", data.MessageId).First(&job)
 | 
			
		||||
	if res.Error == nil && data.Status == Finished {
 | 
			
		||||
		logger.Warn("重复消息:", data.MessageId)
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if task.Src == types.TaskSrcImg { // 绘画任务
 | 
			
		||||
		var job model.MidJourneyJob
 | 
			
		||||
		res := h.db.Where("id = ?", task.Id).First(&job)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Warn("非法任务:", res.Error)
 | 
			
		||||
			return nil, false
 | 
			
		||||
		}
 | 
			
		||||
		job.MessageId = data.MessageId
 | 
			
		||||
		job.ReferenceId = data.ReferenceId
 | 
			
		||||
		job.Progress = data.Progress
 | 
			
		||||
		job.Prompt = data.Prompt
 | 
			
		||||
		job.Hash = data.Image.Hash
 | 
			
		||||
 | 
			
		||||
		// 任务完成,将最终的图片下载下来
 | 
			
		||||
		if data.Progress == 100 {
 | 
			
		||||
			imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download img: ", err.Error())
 | 
			
		||||
				return err, false
 | 
			
		||||
			}
 | 
			
		||||
			job.ImgURL = imgURL
 | 
			
		||||
		} else {
 | 
			
		||||
			// 临时图片直接保存,访问的时候使用代理进行转发
 | 
			
		||||
			job.ImgURL = data.Image.URL
 | 
			
		||||
		}
 | 
			
		||||
		res = h.db.Updates(&job)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update job: ", res.Error)
 | 
			
		||||
			return res.Error, false
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var jobVo vo.MidJourneyJob
 | 
			
		||||
		err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			if data.Progress < 100 {
 | 
			
		||||
				image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 推送任务到前端
 | 
			
		||||
			client := h.clients.Get(task.SessionId)
 | 
			
		||||
			if client != nil {
 | 
			
		||||
				utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	} else if task.Src == types.TaskSrcChat { // 聊天任务
 | 
			
		||||
		wsClient := h.App.MjTaskClients.Get(task.SessionId)
 | 
			
		||||
		if data.Status == Finished {
 | 
			
		||||
			if wsClient != nil && data.ReferenceId != "" {
 | 
			
		||||
				content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
 | 
			
		||||
				utils.ReplyMessage(wsClient, content)
 | 
			
		||||
			}
 | 
			
		||||
			// download image
 | 
			
		||||
			imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download image: ", err)
 | 
			
		||||
				if wsClient != nil && data.ReferenceId != "" {
 | 
			
		||||
					content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
 | 
			
		||||
					utils.ReplyMessage(wsClient, content)
 | 
			
		||||
				}
 | 
			
		||||
				return err, false
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			tx := h.db.Begin()
 | 
			
		||||
			data.Image.URL = imgURL
 | 
			
		||||
			message := model.HistoryMessage{
 | 
			
		||||
				UserId:     uint(task.UserId),
 | 
			
		||||
				ChatId:     task.ChatId,
 | 
			
		||||
				RoleId:     uint(task.RoleId),
 | 
			
		||||
				Type:       types.MjMsg,
 | 
			
		||||
				Icon:       task.Icon,
 | 
			
		||||
				Content:    utils.JsonEncode(data),
 | 
			
		||||
				Tokens:     0,
 | 
			
		||||
				UseContext: false,
 | 
			
		||||
			}
 | 
			
		||||
			res = tx.Create(&message)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				return res.Error, false
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// save the job
 | 
			
		||||
			job.UserId = task.UserId
 | 
			
		||||
			job.Type = task.Type.String()
 | 
			
		||||
			job.MessageId = data.MessageId
 | 
			
		||||
			job.ReferenceId = data.ReferenceId
 | 
			
		||||
			job.Prompt = data.Prompt
 | 
			
		||||
			job.ImgURL = imgURL
 | 
			
		||||
			job.Progress = data.Progress
 | 
			
		||||
			job.Hash = data.Image.Hash
 | 
			
		||||
			job.CreatedAt = time.Now()
 | 
			
		||||
			res = tx.Create(&job)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				tx.Rollback()
 | 
			
		||||
				return res.Error, false
 | 
			
		||||
			}
 | 
			
		||||
			tx.Commit()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if wsClient == nil { // 客户端断线,则丢弃
 | 
			
		||||
			logger.Errorf("Client is offline: %+v", data)
 | 
			
		||||
			return nil, true
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if data.Status == Finished {
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
 | 
			
		||||
			// 本次绘画完毕,移除客户端
 | 
			
		||||
			h.App.MjTaskClients.Delete(task.SessionId)
 | 
			
		||||
		} else {
 | 
			
		||||
			// 使用代理临时转发图片
 | 
			
		||||
			if data.Image.URL != "" {
 | 
			
		||||
				image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL)
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新用户剩余绘图次数
 | 
			
		||||
	// TODO: 放大图片是否需要消耗绘图次数?
 | 
			
		||||
	if data.Status == Finished {
 | 
			
		||||
		h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -376,7 +166,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
type reqVo struct {
 | 
			
		||||
	Src         string `json:"src"`
 | 
			
		||||
	Index       int32  `json:"index"`
 | 
			
		||||
	Index       int    `json:"index"`
 | 
			
		||||
	MessageId   string `json:"message_id"`
 | 
			
		||||
	MessageHash string `json:"message_hash"`
 | 
			
		||||
	SessionId   string `json:"session_id"`
 | 
			
		||||
@@ -443,15 +233,16 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
			
		||||
		MessageHash: data.MessageHash,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	wsClient := h.App.ChatClients.Get(data.SessionId)
 | 
			
		||||
	if wsClient != nil {
 | 
			
		||||
		content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
			
		||||
		utils.ReplyMessage(wsClient, content)
 | 
			
		||||
		if h.App.MjTaskClients.Get(data.SessionId) == nil {
 | 
			
		||||
			h.App.MjTaskClients.Put(data.SessionId, wsClient)
 | 
			
		||||
	if src == types.TaskSrcChat {
 | 
			
		||||
		wsClient := h.App.ChatClients.Get(data.SessionId)
 | 
			
		||||
		if wsClient != nil {
 | 
			
		||||
			content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
			
		||||
			utils.ReplyMessage(wsClient, content)
 | 
			
		||||
			if h.mjService.ChatClients.Get(data.SessionId) == nil {
 | 
			
		||||
				h.mjService.ChatClients.Put(data.SessionId, wsClient)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -513,13 +304,15 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
			
		||||
		MessageHash: data.MessageHash,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	// 从聊天窗口发送的请求,记录客户端信息
 | 
			
		||||
	wsClient := h.App.ChatClients.Get(data.SessionId)
 | 
			
		||||
	if wsClient != nil {
 | 
			
		||||
		content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
			
		||||
		utils.ReplyMessage(wsClient, content)
 | 
			
		||||
		if h.App.MjTaskClients.Get(data.SessionId) == nil {
 | 
			
		||||
			h.App.MjTaskClients.Put(data.SessionId, wsClient)
 | 
			
		||||
	if src == types.TaskSrcChat {
 | 
			
		||||
		// 从聊天窗口发送的请求,记录客户端信息
 | 
			
		||||
		wsClient := h.mjService.ChatClients.Get(data.SessionId)
 | 
			
		||||
		if wsClient != nil {
 | 
			
		||||
			content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
			
		||||
			utils.ReplyMessage(wsClient, content)
 | 
			
		||||
			if h.mjService.Clients.Get(data.SessionId) == nil {
 | 
			
		||||
				h.mjService.Clients.Put(data.SessionId, wsClient)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
 
 | 
			
		||||
@@ -150,7 +150,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
					content := data
 | 
			
		||||
					if functionName == types.FuncMidJourney {
 | 
			
		||||
						content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
 | 
			
		||||
						h.App.MjTaskClients.Put(session.SessionId, ws)
 | 
			
		||||
						h.mjService.ChatClients.Put(session.SessionId, ws)
 | 
			
		||||
						// update user's img_calls
 | 
			
		||||
						h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
					}
 | 
			
		||||
 
 | 
			
		||||
@@ -22,50 +22,6 @@ func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
 | 
			
		||||
	return &h
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *RewardHandler) Notify(c *gin.Context) {
 | 
			
		||||
	token := c.GetHeader("Authorization")
 | 
			
		||||
	if token != h.App.Config.ExtConfig.Token {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var data struct {
 | 
			
		||||
		TransId string  `json:"trans_id"` // 微信转账交易 ID
 | 
			
		||||
		Amount  float64 `json:"amount"`   // 微信转账交易金额
 | 
			
		||||
		Remark  string  `json:"remark"`   // 转账备注
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if data.Amount <= 0 {
 | 
			
		||||
		resp.ERROR(c, "Amount should not be 0")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Infof("收到众筹收款信息: %+v", data)
 | 
			
		||||
	var item model.Reward
 | 
			
		||||
	res := h.db.Where("tx_id = ?", data.TransId).First(&item)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		resp.ERROR(c, "当前交易 ID 己经存在!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res = h.db.Create(&model.Reward{
 | 
			
		||||
		TxId:   data.TransId,
 | 
			
		||||
		Amount: data.Amount,
 | 
			
		||||
		Remark: data.Remark,
 | 
			
		||||
		Status: false,
 | 
			
		||||
	})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Errorf("交易保存失败: %v", res.Error)
 | 
			
		||||
		resp.ERROR(c, "交易保存失败")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Verify 打赏码核销
 | 
			
		||||
func (h *RewardHandler) Verify(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,315 +1,316 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type SdJobHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	redis           *redis.Client
 | 
			
		||||
	db              *gorm.DB
 | 
			
		||||
	mjService       *service.MjService
 | 
			
		||||
	uploaderManager *oss.UploaderManager
 | 
			
		||||
	lock            sync.Mutex
 | 
			
		||||
	clients         *types.LMap[string, *types.WsClient]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSdJobHandler(
 | 
			
		||||
	app *core.AppServer,
 | 
			
		||||
	client *redis.Client,
 | 
			
		||||
	db *gorm.DB,
 | 
			
		||||
	manager *oss.UploaderManager,
 | 
			
		||||
	mjService *service.MjService) *MidJourneyHandler {
 | 
			
		||||
	h := MidJourneyHandler{
 | 
			
		||||
		redis:           client,
 | 
			
		||||
		db:              db,
 | 
			
		||||
		uploaderManager: manager,
 | 
			
		||||
		lock:            sync.Mutex{},
 | 
			
		||||
		mjService:       mjService,
 | 
			
		||||
		clients:         types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
	}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
			
		||||
func (h *SdJobHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sessionId := c.Query("session_id")
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	// 删除旧的连接
 | 
			
		||||
	h.clients.Delete(sessionId)
 | 
			
		||||
	h.clients.Put(sessionId, client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.ClientIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type sdNotifyData struct {
 | 
			
		||||
	TaskId    string
 | 
			
		||||
	ImageName string
 | 
			
		||||
	ImageData string
 | 
			
		||||
	Progress  int
 | 
			
		||||
	Seed      string
 | 
			
		||||
	Success   bool
 | 
			
		||||
	Message   string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *SdJobHandler) Notify(c *gin.Context) {
 | 
			
		||||
	token := c.GetHeader("Authorization")
 | 
			
		||||
	if token != h.App.Config.ExtConfig.Token {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var data sdNotifyData
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	logger.Debugf("收到 MidJourney 回调请求:%+v", data)
 | 
			
		||||
 | 
			
		||||
	h.lock.Lock()
 | 
			
		||||
	defer h.lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	err, finished := h.notifyHandler(c, data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 解除任务锁定
 | 
			
		||||
	if finished && (data.Progress == 100) {
 | 
			
		||||
		h.redis.Del(c, service.MjRunningJobKey)
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) {
 | 
			
		||||
	taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
 | 
			
		||||
	if err != nil { // 过期任务,丢弃
 | 
			
		||||
		logger.Warn("任务已过期:", err)
 | 
			
		||||
		return nil, true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var task types.SdTask
 | 
			
		||||
	err = utils.JsonDecode(taskString, &task)
 | 
			
		||||
	if err != nil { // 非标准任务,丢弃
 | 
			
		||||
		logger.Warn("任务解析失败:", err)
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var job model.SdJob
 | 
			
		||||
	res := h.db.Where("id = ?", task.Id).First(&job)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Warn("非法任务:", res.Error)
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
	job.Params = utils.JsonEncode(task.Params)
 | 
			
		||||
	job.ReferenceId = data.ImageData
 | 
			
		||||
	job.Progress = data.Progress
 | 
			
		||||
	job.Prompt = data.Prompt
 | 
			
		||||
	job.Hash = data.Image.Hash
 | 
			
		||||
 | 
			
		||||
	// 任务完成,将最终的图片下载下来
 | 
			
		||||
	if data.Progress == 100 {
 | 
			
		||||
		imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with download img: ", err.Error())
 | 
			
		||||
			return err, false
 | 
			
		||||
		}
 | 
			
		||||
		job.ImgURL = imgURL
 | 
			
		||||
	} else {
 | 
			
		||||
		// 临时图片直接保存,访问的时候使用代理进行转发
 | 
			
		||||
		job.ImgURL = data.Image.URL
 | 
			
		||||
	}
 | 
			
		||||
	res = h.db.Updates(&job)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("error with update job: ", res.Error)
 | 
			
		||||
		return res.Error, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var jobVo vo.MidJourneyJob
 | 
			
		||||
	err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		if data.Progress < 100 {
 | 
			
		||||
			image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 推送任务到前端
 | 
			
		||||
		client := h.clients.Get(task.SessionId)
 | 
			
		||||
		if client != nil {
 | 
			
		||||
			utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新用户剩余绘图次数
 | 
			
		||||
	if data.Progress == 100 {
 | 
			
		||||
		h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.ImgCalls <= 0 {
 | 
			
		||||
		resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return true
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Image 创建一个绘画任务
 | 
			
		||||
func (h *SdJobHandler) Image(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		SessionId string  `json:"session_id"`
 | 
			
		||||
		Prompt    string  `json:"prompt"`
 | 
			
		||||
		Rate      string  `json:"rate"`
 | 
			
		||||
		Model     string  `json:"model"`
 | 
			
		||||
		Chaos     int     `json:"chaos"`
 | 
			
		||||
		Raw       bool    `json:"raw"`
 | 
			
		||||
		Seed      int64   `json:"seed"`
 | 
			
		||||
		Stylize   int     `json:"stylize"`
 | 
			
		||||
		Img       string  `json:"img"`
 | 
			
		||||
		Weight    float32 `json:"weight"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if !h.checkLimits(c) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var prompt = data.Prompt
 | 
			
		||||
	if data.Rate != "" && !strings.Contains(prompt, "--ar") {
 | 
			
		||||
		prompt += " --ar " + data.Rate
 | 
			
		||||
	}
 | 
			
		||||
	if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
 | 
			
		||||
		prompt += fmt.Sprintf(" --seed %d", data.Seed)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
 | 
			
		||||
		prompt += fmt.Sprintf(" --s %d", data.Stylize)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
 | 
			
		||||
		prompt += fmt.Sprintf(" --c %d", data.Chaos)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Img != "" {
 | 
			
		||||
		prompt = fmt.Sprintf("%s %s", data.Img, prompt)
 | 
			
		||||
		if data.Weight > 0 {
 | 
			
		||||
			prompt += fmt.Sprintf(" --iw %f", data.Weight)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if data.Raw {
 | 
			
		||||
		prompt += " --style raw"
 | 
			
		||||
	}
 | 
			
		||||
	if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
 | 
			
		||||
		prompt += data.Model
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
			
		||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
			
		||||
	job := model.MidJourneyJob{
 | 
			
		||||
		Type:      service.Image.String(),
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
		Progress:  0,
 | 
			
		||||
		Prompt:    prompt,
 | 
			
		||||
		CreatedAt: time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
	if res := h.db.Create(&job); res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "添加任务失败:"+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.mjService.PushTask(service.MjTask{
 | 
			
		||||
		Id:        int(job.Id),
 | 
			
		||||
		SessionId: data.SessionId,
 | 
			
		||||
		Src:       service.TaskSrcImg,
 | 
			
		||||
		Type:      service.Image,
 | 
			
		||||
		Prompt:    prompt,
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	var jobVo vo.MidJourneyJob
 | 
			
		||||
	err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		// 推送任务到前端
 | 
			
		||||
		client := h.clients.Get(data.SessionId)
 | 
			
		||||
		if client != nil {
 | 
			
		||||
			utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取 MJ 任务列表
 | 
			
		||||
func (h *SdJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
	status := h.GetInt(c, "status", 0)
 | 
			
		||||
	var items []model.MidJourneyJob
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	userId, _ := c.Get(types.LoginUserID)
 | 
			
		||||
	if status == 1 {
 | 
			
		||||
		res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items)
 | 
			
		||||
	} else {
 | 
			
		||||
		res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items)
 | 
			
		||||
	}
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, types.NoData)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var jobs = make([]vo.MidJourneyJob, 0)
 | 
			
		||||
	for _, item := range items {
 | 
			
		||||
		var job vo.MidJourneyJob
 | 
			
		||||
		err := utils.CopyObject(item, &job)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if item.Progress < 100 {
 | 
			
		||||
			// 30 分钟还没完成的任务直接删除
 | 
			
		||||
			if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
 | 
			
		||||
				h.db.Delete(&item)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
 | 
			
		||||
				image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		jobs = append(jobs, job)
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, jobs)
 | 
			
		||||
}
 | 
			
		||||
//
 | 
			
		||||
//import (
 | 
			
		||||
//	"chatplus/core"
 | 
			
		||||
//	"chatplus/core/types"
 | 
			
		||||
//	"chatplus/service"
 | 
			
		||||
//	"chatplus/service/oss"
 | 
			
		||||
//	"chatplus/store/model"
 | 
			
		||||
//	"chatplus/store/vo"
 | 
			
		||||
//	"chatplus/utils"
 | 
			
		||||
//	"chatplus/utils/resp"
 | 
			
		||||
//	"encoding/base64"
 | 
			
		||||
//	"fmt"
 | 
			
		||||
//	"github.com/gin-gonic/gin"
 | 
			
		||||
//	"github.com/go-redis/redis/v8"
 | 
			
		||||
//	"github.com/gorilla/websocket"
 | 
			
		||||
//	"gorm.io/gorm"
 | 
			
		||||
//	"net/http"
 | 
			
		||||
//	"strings"
 | 
			
		||||
//	"sync"
 | 
			
		||||
//	"time"
 | 
			
		||||
//)
 | 
			
		||||
//
 | 
			
		||||
//type SdJobHandler struct {
 | 
			
		||||
//	BaseHandler
 | 
			
		||||
//	redis           *redis.Client
 | 
			
		||||
//	db              *gorm.DB
 | 
			
		||||
//	mjService       *service.MjService
 | 
			
		||||
//	uploaderManager *oss.UploaderManager
 | 
			
		||||
//	lock            sync.Mutex
 | 
			
		||||
//	clients         *types.LMap[string, *types.WsClient]
 | 
			
		||||
//}
 | 
			
		||||
//
 | 
			
		||||
//func NewSdJobHandler(
 | 
			
		||||
//	app *core.AppServer,
 | 
			
		||||
//	client *redis.Client,
 | 
			
		||||
//	db *gorm.DB,
 | 
			
		||||
//	manager *oss.UploaderManager,
 | 
			
		||||
//	mjService *service.MjService) *MidJourneyHandler {
 | 
			
		||||
//	h := MidJourneyHandler{
 | 
			
		||||
//		redis:           client,
 | 
			
		||||
//		db:              db,
 | 
			
		||||
//		uploaderManager: manager,
 | 
			
		||||
//		lock:            sync.Mutex{},
 | 
			
		||||
//		mjService:       mjService,
 | 
			
		||||
//		clients:         types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
//	}
 | 
			
		||||
//	h.App = app
 | 
			
		||||
//	return &h
 | 
			
		||||
//}
 | 
			
		||||
//
 | 
			
		||||
//// Client WebSocket 客户端,用于通知任务状态变更
 | 
			
		||||
//func (h *SdJobHandler) Client(c *gin.Context) {
 | 
			
		||||
//	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		logger.Error(err)
 | 
			
		||||
//		return
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	sessionId := c.Query("session_id")
 | 
			
		||||
//	client := types.NewWsClient(ws)
 | 
			
		||||
//	// 删除旧的连接
 | 
			
		||||
//	h.clients.Delete(sessionId)
 | 
			
		||||
//	h.clients.Put(sessionId, client)
 | 
			
		||||
//	logger.Infof("New websocket connected, IP: %s", c.ClientIP())
 | 
			
		||||
//}
 | 
			
		||||
//
 | 
			
		||||
//type sdNotifyData struct {
 | 
			
		||||
//	TaskId    string
 | 
			
		||||
//	ImageName string
 | 
			
		||||
//	ImageData string
 | 
			
		||||
//	Progress  int
 | 
			
		||||
//	Seed      string
 | 
			
		||||
//	Success   bool
 | 
			
		||||
//	Message   string
 | 
			
		||||
//}
 | 
			
		||||
//
 | 
			
		||||
//func (h *SdJobHandler) Notify(c *gin.Context) {
 | 
			
		||||
//	token := c.GetHeader("Authorization")
 | 
			
		||||
//	if token != h.App.Config.ExtConfig.Token {
 | 
			
		||||
//		resp.NotAuth(c)
 | 
			
		||||
//		return
 | 
			
		||||
//	}
 | 
			
		||||
//	var data sdNotifyData
 | 
			
		||||
//	if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" {
 | 
			
		||||
//		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
//		return
 | 
			
		||||
//	}
 | 
			
		||||
//	logger.Debugf("收到 MidJourney 回调请求:%+v", data)
 | 
			
		||||
//
 | 
			
		||||
//	h.lock.Lock()
 | 
			
		||||
//	defer h.lock.Unlock()
 | 
			
		||||
//
 | 
			
		||||
//	err, finished := h.notifyHandler(c, data)
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		resp.ERROR(c, err.Error())
 | 
			
		||||
//		return
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	// 解除任务锁定
 | 
			
		||||
//	if finished && (data.Progress == 100) {
 | 
			
		||||
//		h.redis.Del(c, service.MjRunningJobKey)
 | 
			
		||||
//	}
 | 
			
		||||
//	resp.SUCCESS(c)
 | 
			
		||||
//
 | 
			
		||||
//}
 | 
			
		||||
//
 | 
			
		||||
//func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) {
 | 
			
		||||
//	taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
 | 
			
		||||
//	if err != nil { // 过期任务,丢弃
 | 
			
		||||
//		logger.Warn("任务已过期:", err)
 | 
			
		||||
//		return nil, true
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	var task types.SdTask
 | 
			
		||||
//	err = utils.JsonDecode(taskString, &task)
 | 
			
		||||
//	if err != nil { // 非标准任务,丢弃
 | 
			
		||||
//		logger.Warn("任务解析失败:", err)
 | 
			
		||||
//		return nil, false
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	var job model.SdJob
 | 
			
		||||
//	res := h.db.Where("id = ?", task.Id).First(&job)
 | 
			
		||||
//	if res.Error != nil {
 | 
			
		||||
//		logger.Warn("非法任务:", res.Error)
 | 
			
		||||
//		return nil, false
 | 
			
		||||
//	}
 | 
			
		||||
//	job.Params = utils.JsonEncode(task.Params)
 | 
			
		||||
//	job.ReferenceId = data.ImageData
 | 
			
		||||
//	job.Progress = data.Progress
 | 
			
		||||
//	job.Prompt = data.Prompt
 | 
			
		||||
//	job.Hash = data.Image.Hash
 | 
			
		||||
//
 | 
			
		||||
//	// 任务完成,将最终的图片下载下来
 | 
			
		||||
//	if data.Progress == 100 {
 | 
			
		||||
//		imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
 | 
			
		||||
//		if err != nil {
 | 
			
		||||
//			logger.Error("error with download img: ", err.Error())
 | 
			
		||||
//			return err, false
 | 
			
		||||
//		}
 | 
			
		||||
//		job.ImgURL = imgURL
 | 
			
		||||
//	} else {
 | 
			
		||||
//		// 临时图片直接保存,访问的时候使用代理进行转发
 | 
			
		||||
//		job.ImgURL = data.Image.URL
 | 
			
		||||
//	}
 | 
			
		||||
//	res = h.db.Updates(&job)
 | 
			
		||||
//	if res.Error != nil {
 | 
			
		||||
//		logger.Error("error with update job: ", res.Error)
 | 
			
		||||
//		return res.Error, false
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	var jobVo vo.MidJourneyJob
 | 
			
		||||
//	err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
//	if err == nil {
 | 
			
		||||
//		if data.Progress < 100 {
 | 
			
		||||
//			image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
 | 
			
		||||
//			if err == nil {
 | 
			
		||||
//				jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
//			}
 | 
			
		||||
//		}
 | 
			
		||||
//
 | 
			
		||||
//		// 推送任务到前端
 | 
			
		||||
//		client := h.clients.Get(task.SessionId)
 | 
			
		||||
//		if client != nil {
 | 
			
		||||
//			utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
//		}
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	// 更新用户剩余绘图次数
 | 
			
		||||
//	if data.Progress == 100 {
 | 
			
		||||
//		h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	return nil, true
 | 
			
		||||
//}
 | 
			
		||||
//
 | 
			
		||||
//func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
 | 
			
		||||
//	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		resp.NotAuth(c)
 | 
			
		||||
//		return false
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	if user.ImgCalls <= 0 {
 | 
			
		||||
//		resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
 | 
			
		||||
//		return false
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	return true
 | 
			
		||||
//
 | 
			
		||||
//}
 | 
			
		||||
//
 | 
			
		||||
//// Image 创建一个绘画任务
 | 
			
		||||
//func (h *SdJobHandler) Image(c *gin.Context) {
 | 
			
		||||
//	var data struct {
 | 
			
		||||
//		SessionId string  `json:"session_id"`
 | 
			
		||||
//		Prompt    string  `json:"prompt"`
 | 
			
		||||
//		Rate      string  `json:"rate"`
 | 
			
		||||
//		Model     string  `json:"model"`
 | 
			
		||||
//		Chaos     int     `json:"chaos"`
 | 
			
		||||
//		Raw       bool    `json:"raw"`
 | 
			
		||||
//		Seed      int64   `json:"seed"`
 | 
			
		||||
//		Stylize   int     `json:"stylize"`
 | 
			
		||||
//		Img       string  `json:"img"`
 | 
			
		||||
//		Weight    float32 `json:"weight"`
 | 
			
		||||
//	}
 | 
			
		||||
//	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
//		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
//		return
 | 
			
		||||
//	}
 | 
			
		||||
//	if !h.checkLimits(c) {
 | 
			
		||||
//		return
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	var prompt = data.Prompt
 | 
			
		||||
//	if data.Rate != "" && !strings.Contains(prompt, "--ar") {
 | 
			
		||||
//		prompt += " --ar " + data.Rate
 | 
			
		||||
//	}
 | 
			
		||||
//	if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
 | 
			
		||||
//		prompt += fmt.Sprintf(" --seed %d", data.Seed)
 | 
			
		||||
//	}
 | 
			
		||||
//	if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
 | 
			
		||||
//		prompt += fmt.Sprintf(" --s %d", data.Stylize)
 | 
			
		||||
//	}
 | 
			
		||||
//	if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
 | 
			
		||||
//		prompt += fmt.Sprintf(" --c %d", data.Chaos)
 | 
			
		||||
//	}
 | 
			
		||||
//	if data.Img != "" {
 | 
			
		||||
//		prompt = fmt.Sprintf("%s %s", data.Img, prompt)
 | 
			
		||||
//		if data.Weight > 0 {
 | 
			
		||||
//			prompt += fmt.Sprintf(" --iw %f", data.Weight)
 | 
			
		||||
//		}
 | 
			
		||||
//	}
 | 
			
		||||
//	if data.Raw {
 | 
			
		||||
//		prompt += " --style raw"
 | 
			
		||||
//	}
 | 
			
		||||
//	if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
 | 
			
		||||
//		prompt += data.Model
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	idValue, _ := c.Get(types.LoginUserID)
 | 
			
		||||
//	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
			
		||||
//	job := model.MidJourneyJob{
 | 
			
		||||
//		Type:      service.Image.String(),
 | 
			
		||||
//		UserId:    userId,
 | 
			
		||||
//		Progress:  0,
 | 
			
		||||
//		Prompt:    prompt,
 | 
			
		||||
//		CreatedAt: time.Now(),
 | 
			
		||||
//	}
 | 
			
		||||
//	if res := h.db.Create(&job); res.Error != nil {
 | 
			
		||||
//		resp.ERROR(c, "添加任务失败:"+res.Error.Error())
 | 
			
		||||
//		return
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	h.mjService.PushTask(service.MjTask{
 | 
			
		||||
//		Id:        int(job.Id),
 | 
			
		||||
//		SessionId: data.SessionId,
 | 
			
		||||
//		Src:       service.TaskSrcImg,
 | 
			
		||||
//		Type:      service.Image,
 | 
			
		||||
//		Prompt:    prompt,
 | 
			
		||||
//		UserId:    userId,
 | 
			
		||||
//	})
 | 
			
		||||
//
 | 
			
		||||
//	var jobVo vo.MidJourneyJob
 | 
			
		||||
//	err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
//	if err == nil {
 | 
			
		||||
//		// 推送任务到前端
 | 
			
		||||
//		client := h.clients.Get(data.SessionId)
 | 
			
		||||
//		if client != nil {
 | 
			
		||||
//			utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
//		}
 | 
			
		||||
//	}
 | 
			
		||||
//	resp.SUCCESS(c)
 | 
			
		||||
//}
 | 
			
		||||
//
 | 
			
		||||
//// JobList 获取 MJ 任务列表
 | 
			
		||||
//func (h *SdJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
//	status := h.GetInt(c, "status", 0)
 | 
			
		||||
//	var items []model.MidJourneyJob
 | 
			
		||||
//	var res *gorm.DB
 | 
			
		||||
//	userId, _ := c.Get(types.LoginUserID)
 | 
			
		||||
//	if status == 1 {
 | 
			
		||||
//		res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items)
 | 
			
		||||
//	} else {
 | 
			
		||||
//		res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items)
 | 
			
		||||
//	}
 | 
			
		||||
//	if res.Error != nil {
 | 
			
		||||
//		resp.ERROR(c, types.NoData)
 | 
			
		||||
//		return
 | 
			
		||||
//	}
 | 
			
		||||
//
 | 
			
		||||
//	var jobs = make([]vo.MidJourneyJob, 0)
 | 
			
		||||
//	for _, item := range items {
 | 
			
		||||
//		var job vo.MidJourneyJob
 | 
			
		||||
//		err := utils.CopyObject(item, &job)
 | 
			
		||||
//		if err != nil {
 | 
			
		||||
//			continue
 | 
			
		||||
//		}
 | 
			
		||||
//		if item.Progress < 100 {
 | 
			
		||||
//			// 30 分钟还没完成的任务直接删除
 | 
			
		||||
//			if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
 | 
			
		||||
//				h.db.Delete(&item)
 | 
			
		||||
//				continue
 | 
			
		||||
//			}
 | 
			
		||||
//			if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
 | 
			
		||||
//				image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
 | 
			
		||||
//				if err == nil {
 | 
			
		||||
//					job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
//				}
 | 
			
		||||
//			}
 | 
			
		||||
//		}
 | 
			
		||||
//		jobs = append(jobs, job)
 | 
			
		||||
//	}
 | 
			
		||||
//	resp.SUCCESS(c, jobs)
 | 
			
		||||
//}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										35
									
								
								api/main.go
									
									
									
									
									
								
							
							
						
						
									
										35
									
								
								api/main.go
									
									
									
									
									
								
							@@ -7,8 +7,10 @@ import (
 | 
			
		||||
	"chatplus/handler/admin"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/service/function"
 | 
			
		||||
	"chatplus/service/fun"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/service/wx"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"context"
 | 
			
		||||
	"embed"
 | 
			
		||||
@@ -107,7 +109,7 @@ func main() {
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 创建函数
 | 
			
		||||
		fx.Provide(function.NewFunctions),
 | 
			
		||||
		fx.Provide(fun.NewFunctions),
 | 
			
		||||
 | 
			
		||||
		// 创建控制器
 | 
			
		||||
		fx.Provide(handler.NewChatRoleHandler),
 | 
			
		||||
@@ -135,13 +137,36 @@ func main() {
 | 
			
		||||
			return service.NewCaptchaService(config.ApiConfig)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(oss.NewUploaderManager),
 | 
			
		||||
		fx.Provide(service.NewMjService),
 | 
			
		||||
		fx.Invoke(func(mjService *service.MjService) {
 | 
			
		||||
		fx.Provide(mj.NewService),
 | 
			
		||||
		fx.Invoke(func(mjService *mj.Service) {
 | 
			
		||||
			go func() {
 | 
			
		||||
				mjService.Run()
 | 
			
		||||
			}()
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 微信机器人服务
 | 
			
		||||
		fx.Provide(wx.NewWeChatBot),
 | 
			
		||||
		fx.Invoke(func(config *types.AppConfig, bot *wx.Bot) {
 | 
			
		||||
			if config.WeChatBot {
 | 
			
		||||
				err := bot.Run()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error("微信登录失败:", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// MidJourney 机器人
 | 
			
		||||
		fx.Provide(mj.NewBot),
 | 
			
		||||
		fx.Provide(mj.NewClient),
 | 
			
		||||
		fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) {
 | 
			
		||||
			if config.MjConfig.Enabled {
 | 
			
		||||
				err := bot.Run()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Fatal("MidJourney 服务启动失败:", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 注册路由
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/role/")
 | 
			
		||||
@@ -185,12 +210,10 @@ func main() {
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/reward/")
 | 
			
		||||
			group.POST("notify", h.Notify)
 | 
			
		||||
			group.POST("verify", h.Verify)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/mj/")
 | 
			
		||||
			group.POST("notify", h.Notify)
 | 
			
		||||
			group.POST("image", h.Image)
 | 
			
		||||
			group.POST("upscale", h.Upscale)
 | 
			
		||||
			group.POST("variation", h.Variation)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,8 @@
 | 
			
		||||
package function
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -9,10 +10,10 @@ import (
 | 
			
		||||
 | 
			
		||||
type FuncMidJourney struct {
 | 
			
		||||
	name    string
 | 
			
		||||
	service *service.MjService
 | 
			
		||||
	service *mj.Service
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney {
 | 
			
		||||
func NewMidJourneyFunc(mjService *mj.Service) FuncMidJourney {
 | 
			
		||||
	return FuncMidJourney{
 | 
			
		||||
		name:    "MidJourney AI 绘画",
 | 
			
		||||
		service: mjService}
 | 
			
		||||
@@ -21,10 +22,10 @@ func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney {
 | 
			
		||||
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
 | 
			
		||||
	logger.Infof("MJ 绘画参数:%+v", params)
 | 
			
		||||
	prompt := utils.InterfaceToString(params["prompt"])
 | 
			
		||||
	f.service.PushTask(service.MjTask{
 | 
			
		||||
	f.service.PushTask(types.MjTask{
 | 
			
		||||
		SessionId: utils.InterfaceToString(params["session_id"]),
 | 
			
		||||
		Src:       service.TaskSrcChat,
 | 
			
		||||
		Type:      service.Image,
 | 
			
		||||
		Src:       types.TaskSrcChat,
 | 
			
		||||
		Type:      types.TaskImage,
 | 
			
		||||
		Prompt:    prompt,
 | 
			
		||||
		UserId:    utils.IntValue(utils.InterfaceToString(params["user_id"]), 0),
 | 
			
		||||
		RoleId:    utils.IntValue(utils.InterfaceToString(params["role_id"]), 0),
 | 
			
		||||
@@ -1,9 +1,9 @@
 | 
			
		||||
package function
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Function interface {
 | 
			
		||||
@@ -29,7 +29,7 @@ type dataItem struct {
 | 
			
		||||
	Remark string `json:"remark"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewFunctions(config *types.AppConfig, mjService *service.MjService) map[string]Function {
 | 
			
		||||
func NewFunctions(config *types.AppConfig, mjService *mj.Service) map[string]Function {
 | 
			
		||||
	return map[string]Function{
 | 
			
		||||
		types.FuncZaoBao:     NewZaoBao(config.ApiConfig),
 | 
			
		||||
		types.FuncWeibo:      NewWeiboHot(config.ApiConfig),
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package function
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package function
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package function
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
							
								
								
									
										213
									
								
								api/service/mj/bot.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										213
									
								
								api/service/mj/bot.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,213 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"github.com/bwmarrin/discordgo"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MidJourney 机器人
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type Bot struct {
 | 
			
		||||
	config  *types.MidJourneyConfig
 | 
			
		||||
	bot     *discordgo.Session
 | 
			
		||||
	service *Service
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
 | 
			
		||||
	discord, err := discordgo.New("Bot " + config.MjConfig.BotToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config.ProxyURL != "" {
 | 
			
		||||
		proxy, _ := url.Parse(config.ProxyURL)
 | 
			
		||||
		discord.Client = &http.Client{
 | 
			
		||||
			Transport: &http.Transport{
 | 
			
		||||
				Proxy: http.ProxyURL(proxy),
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
		discord.Dialer = &websocket.Dialer{
 | 
			
		||||
			Proxy: http.ProxyURL(proxy),
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &Bot{
 | 
			
		||||
		config:  &config.MjConfig,
 | 
			
		||||
		bot:     discord,
 | 
			
		||||
		service: service,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) Run() error {
 | 
			
		||||
	b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
 | 
			
		||||
	b.bot.AddHandler(b.messageCreate)
 | 
			
		||||
	b.bot.AddHandler(b.messageUpdate)
 | 
			
		||||
 | 
			
		||||
	logger.Info("Starting MidJourney Bot...")
 | 
			
		||||
	err := b.bot.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("Error opening Discord connection:", err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info("Starting MidJourney Bot successfully!")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TaskStatus string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	Start    = TaskStatus("Started")
 | 
			
		||||
	Running  = TaskStatus("Running")
 | 
			
		||||
	Stopped  = TaskStatus("Stopped")
 | 
			
		||||
	Finished = TaskStatus("Finished")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Image struct {
 | 
			
		||||
	URL      string `json:"url"`
 | 
			
		||||
	ProxyURL string `json:"proxy_url"`
 | 
			
		||||
	Filename string `json:"filename"`
 | 
			
		||||
	Width    int    `json:"width"`
 | 
			
		||||
	Height   int    `json:"height"`
 | 
			
		||||
	Size     int    `json:"size"`
 | 
			
		||||
	Hash     string `json:"hash"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
 | 
			
		||||
	// ignore messages for other channels
 | 
			
		||||
	if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// ignore messages for self
 | 
			
		||||
	if m.Author.ID == s.State.User.ID {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debugf("CREATE: %s", utils.JsonEncode(m))
 | 
			
		||||
	var referenceId = ""
 | 
			
		||||
	if m.ReferencedMessage != nil {
 | 
			
		||||
		referenceId = m.ReferencedMessage.ID
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
 | 
			
		||||
		// parse content
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			MessageId:   m.ID,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Prompt:      extractPrompt(m.Content),
 | 
			
		||||
			Content:     m.Content,
 | 
			
		||||
			Progress:    0,
 | 
			
		||||
			Status:      Start}
 | 
			
		||||
		b.service.Notify(req)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
 | 
			
		||||
	// ignore messages for other channels
 | 
			
		||||
	if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// ignore messages for self
 | 
			
		||||
	if m.Author.ID == s.State.User.ID {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
 | 
			
		||||
 | 
			
		||||
	var referenceId = ""
 | 
			
		||||
	if m.ReferencedMessage != nil {
 | 
			
		||||
		referenceId = m.ReferencedMessage.ID
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(m.Content, "(Stopped)") {
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			MessageId:   m.ID,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Prompt:      extractPrompt(m.Content),
 | 
			
		||||
			Content:     m.Content,
 | 
			
		||||
			Progress:    extractProgress(m.Content),
 | 
			
		||||
			Status:      Stopped}
 | 
			
		||||
		b.service.Notify(req)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
 | 
			
		||||
	progress := extractProgress(content)
 | 
			
		||||
	var status TaskStatus
 | 
			
		||||
	if progress == 100 {
 | 
			
		||||
		status = Finished
 | 
			
		||||
	} else {
 | 
			
		||||
		status = Running
 | 
			
		||||
	}
 | 
			
		||||
	for _, attachment := range attachments {
 | 
			
		||||
		if attachment.Width == 0 || attachment.Height == 0 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		image := Image{
 | 
			
		||||
			URL:      attachment.URL,
 | 
			
		||||
			Height:   attachment.Height,
 | 
			
		||||
			ProxyURL: attachment.ProxyURL,
 | 
			
		||||
			Width:    attachment.Width,
 | 
			
		||||
			Size:     attachment.Size,
 | 
			
		||||
			Filename: attachment.Filename,
 | 
			
		||||
			Hash:     extractHashFromFilename(attachment.Filename),
 | 
			
		||||
		}
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			MessageId:   messageId,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Image:       image,
 | 
			
		||||
			Prompt:      extractPrompt(content),
 | 
			
		||||
			Content:     content,
 | 
			
		||||
			Progress:    progress,
 | 
			
		||||
			Status:      status,
 | 
			
		||||
		}
 | 
			
		||||
		b.service.Notify(req)
 | 
			
		||||
		break // only get one image
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// extract prompt from string
 | 
			
		||||
func extractPrompt(input string) string {
 | 
			
		||||
	pattern := `\*\*(.*?)\*\*`
 | 
			
		||||
	re := regexp.MustCompile(pattern)
 | 
			
		||||
	matches := re.FindStringSubmatch(input)
 | 
			
		||||
	if len(matches) > 1 {
 | 
			
		||||
		return strings.TrimSpace(matches[1])
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func extractProgress(input string) int {
 | 
			
		||||
	pattern := `\((\d+)\%\)`
 | 
			
		||||
	re := regexp.MustCompile(pattern)
 | 
			
		||||
	matches := re.FindStringSubmatch(input)
 | 
			
		||||
	if len(matches) > 1 {
 | 
			
		||||
		return utils.IntValue(matches[1], 0)
 | 
			
		||||
	}
 | 
			
		||||
	return 100
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func extractHashFromFilename(filename string) string {
 | 
			
		||||
	if !strings.HasSuffix(filename, ".png") {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	index := strings.LastIndex(filename, "_")
 | 
			
		||||
	if index != -1 {
 | 
			
		||||
		return filename[index+1 : len(filename)-4]
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										144
									
								
								api/service/mj/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								api/service/mj/client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,144 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MidJourney client
 | 
			
		||||
 | 
			
		||||
type Client struct {
 | 
			
		||||
	client *req.Client
 | 
			
		||||
	config *types.MidJourneyConfig
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewClient(config *types.AppConfig) *Client {
 | 
			
		||||
	client := req.C().SetTimeout(10 * time.Second)
 | 
			
		||||
	// set proxy URL
 | 
			
		||||
	if config.ProxyURL != "" {
 | 
			
		||||
		client.SetProxyURL(config.ProxyURL)
 | 
			
		||||
	}
 | 
			
		||||
	return &Client{client: client, config: &config.MjConfig}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) Imagine(prompt string) error {
 | 
			
		||||
	interactionsReq := &InteractionsRequest{
 | 
			
		||||
		Type:          2,
 | 
			
		||||
		ApplicationID: ApplicationID,
 | 
			
		||||
		GuildID:       c.config.GuildId,
 | 
			
		||||
		ChannelID:     c.config.ChanelId,
 | 
			
		||||
		SessionID:     SessionID,
 | 
			
		||||
		Data: map[string]any{
 | 
			
		||||
			"version": "1118961510123847772",
 | 
			
		||||
			"id":      "938956540159881230",
 | 
			
		||||
			"name":    "imagine",
 | 
			
		||||
			"type":    "1",
 | 
			
		||||
			"options": []map[string]any{
 | 
			
		||||
				{
 | 
			
		||||
					"type":  3,
 | 
			
		||||
					"name":  "prompt",
 | 
			
		||||
					"value": prompt,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			"application_command": map[string]any{
 | 
			
		||||
				"id":                         "938956540159881230",
 | 
			
		||||
				"application_id":             ApplicationID,
 | 
			
		||||
				"version":                    "1118961510123847772",
 | 
			
		||||
				"default_permission":         true,
 | 
			
		||||
				"default_member_permissions": nil,
 | 
			
		||||
				"type":                       1,
 | 
			
		||||
				"nsfw":                       false,
 | 
			
		||||
				"name":                       "imagine",
 | 
			
		||||
				"description":                "Create images with Midjourney",
 | 
			
		||||
				"dm_permission":              true,
 | 
			
		||||
				"options": []map[string]any{
 | 
			
		||||
					{
 | 
			
		||||
						"type":        3,
 | 
			
		||||
						"name":        "prompt",
 | 
			
		||||
						"description": "The prompt to imagine",
 | 
			
		||||
						"required":    true,
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
				"attachments": []any{},
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := "https://discord.com/api/v9/interactions"
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(interactionsReq).
 | 
			
		||||
		Post(url)
 | 
			
		||||
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("error with http request: %w%v", err, r.Err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale 放大指定的图片
 | 
			
		||||
func (c *Client) Upscale(index int, messageId string, hash string) error {
 | 
			
		||||
	flags := 0
 | 
			
		||||
	interactionsReq := &InteractionsRequest{
 | 
			
		||||
		Type:          3,
 | 
			
		||||
		ApplicationID: ApplicationID,
 | 
			
		||||
		GuildID:       c.config.GuildId,
 | 
			
		||||
		ChannelID:     c.config.ChanelId,
 | 
			
		||||
		MessageFlags:  &flags,
 | 
			
		||||
		MessageID:     &messageId,
 | 
			
		||||
		SessionID:     SessionID,
 | 
			
		||||
		Data: map[string]any{
 | 
			
		||||
			"component_type": 2,
 | 
			
		||||
			"custom_id":      fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
 | 
			
		||||
		},
 | 
			
		||||
		Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := "https://discord.com/api/v9/interactions"
 | 
			
		||||
	var res InteractionsResult
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(interactionsReq).
 | 
			
		||||
		SetErrorResult(&res).
 | 
			
		||||
		Post(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
			
		||||
func (c *Client) Variation(index int, messageId string, hash string) error {
 | 
			
		||||
	flags := 0
 | 
			
		||||
	interactionsReq := &InteractionsRequest{
 | 
			
		||||
		Type:          3,
 | 
			
		||||
		ApplicationID: ApplicationID,
 | 
			
		||||
		GuildID:       c.config.GuildId,
 | 
			
		||||
		ChannelID:     c.config.ChanelId,
 | 
			
		||||
		MessageFlags:  &flags,
 | 
			
		||||
		MessageID:     &messageId,
 | 
			
		||||
		SessionID:     SessionID,
 | 
			
		||||
		Data: map[string]any{
 | 
			
		||||
			"component_type": 2,
 | 
			
		||||
			"custom_id":      fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
 | 
			
		||||
		},
 | 
			
		||||
		Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := "https://discord.com/api/v9/interactions"
 | 
			
		||||
	var res InteractionsResult
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(interactionsReq).
 | 
			
		||||
		SetErrorResult(&res).
 | 
			
		||||
		Post(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										249
									
								
								api/service/mj/service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										249
									
								
								api/service/mj/service.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,249 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MJ 绘画服务
 | 
			
		||||
 | 
			
		||||
const RunningJobKey = "MidJourney_Running_Job"
 | 
			
		||||
 | 
			
		||||
type Service struct {
 | 
			
		||||
	client        *Client
 | 
			
		||||
	taskQueue     *store.RedisQueue
 | 
			
		||||
	redis         *redis.Client
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	Clients       *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息
 | 
			
		||||
	ChatClients   *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息
 | 
			
		||||
	proxyURL      string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		redis:         redisCli,
 | 
			
		||||
		db:            db,
 | 
			
		||||
		taskQueue:     store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
 | 
			
		||||
		client:        client,
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
		Clients:       types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		proxyURL:      config.ProxyURL,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	logger.Info("Starting MidJourney job consumer.")
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	for {
 | 
			
		||||
		_, err := s.redis.Get(ctx, RunningJobKey).Result()
 | 
			
		||||
		if err == nil { // 队列串行执行
 | 
			
		||||
			time.Sleep(time.Second * 3)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		var task types.MjTask
 | 
			
		||||
		err = s.taskQueue.LPop(&task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("Consuming Task: %+v", task)
 | 
			
		||||
		switch task.Type {
 | 
			
		||||
		case types.TaskImage:
 | 
			
		||||
			err = s.client.Imagine(task.Prompt)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskUpscale:
 | 
			
		||||
			err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash)
 | 
			
		||||
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskVariation:
 | 
			
		||||
			err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("绘画任务执行失败:", err)
 | 
			
		||||
			if task.RetryCount <= 5 {
 | 
			
		||||
				s.taskQueue.RPush(task)
 | 
			
		||||
			}
 | 
			
		||||
			task.RetryCount += 1
 | 
			
		||||
			time.Sleep(time.Second * 3)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 更新任务的执行状态
 | 
			
		||||
		s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
 | 
			
		||||
		// 锁定任务执行通道,直到任务超时(5分钟)
 | 
			
		||||
		s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) PushTask(task types.MjTask) {
 | 
			
		||||
	logger.Infof("add a new MidJourney Task: %+v", task)
 | 
			
		||||
	s.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Notify(data CBReq) {
 | 
			
		||||
	taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result()
 | 
			
		||||
	if err != nil { // 过期任务,丢弃
 | 
			
		||||
		logger.Warn("任务已过期:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var task types.MjTask
 | 
			
		||||
	err = utils.JsonDecode(taskString, &task)
 | 
			
		||||
	if err != nil { // 非标准任务,丢弃
 | 
			
		||||
		logger.Warn("任务解析失败:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var job model.MidJourneyJob
 | 
			
		||||
	res := s.db.Where("message_id = ?", data.MessageId).First(&job)
 | 
			
		||||
	if res.Error == nil && data.Status == Finished {
 | 
			
		||||
		logger.Warn("重复消息:", data.MessageId)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if task.Src == types.TaskSrcImg { // 绘画任务
 | 
			
		||||
		var job model.MidJourneyJob
 | 
			
		||||
		res := s.db.Where("id = ?", task.Id).First(&job)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Warn("非法任务:", res.Error)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		job.MessageId = data.MessageId
 | 
			
		||||
		job.ReferenceId = data.ReferenceId
 | 
			
		||||
		job.Progress = data.Progress
 | 
			
		||||
		job.Prompt = data.Prompt
 | 
			
		||||
		job.Hash = data.Image.Hash
 | 
			
		||||
 | 
			
		||||
		// 任务完成,将最终的图片下载下来
 | 
			
		||||
		if data.Progress == 100 {
 | 
			
		||||
			imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download img: ", err.Error())
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			job.ImgURL = imgURL
 | 
			
		||||
		} else {
 | 
			
		||||
			// 临时图片直接保存,访问的时候使用代理进行转发
 | 
			
		||||
			job.ImgURL = data.Image.URL
 | 
			
		||||
		}
 | 
			
		||||
		res = s.db.Updates(&job)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update job: ", res.Error)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var jobVo vo.MidJourneyJob
 | 
			
		||||
		err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			if data.Progress < 100 {
 | 
			
		||||
				image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL)
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 推送任务到前端
 | 
			
		||||
			client := s.Clients.Get(task.SessionId)
 | 
			
		||||
			if client != nil {
 | 
			
		||||
				utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	} else if task.Src == types.TaskSrcChat { // 聊天任务
 | 
			
		||||
		wsClient := s.ChatClients.Get(task.SessionId)
 | 
			
		||||
		if data.Status == Finished {
 | 
			
		||||
			if wsClient != nil && data.ReferenceId != "" {
 | 
			
		||||
				content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
 | 
			
		||||
				utils.ReplyMessage(wsClient, content)
 | 
			
		||||
			}
 | 
			
		||||
			// download image
 | 
			
		||||
			imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download image: ", err)
 | 
			
		||||
				if wsClient != nil && data.ReferenceId != "" {
 | 
			
		||||
					content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
 | 
			
		||||
					utils.ReplyMessage(wsClient, content)
 | 
			
		||||
				}
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			tx := s.db.Begin()
 | 
			
		||||
			data.Image.URL = imgURL
 | 
			
		||||
			message := model.HistoryMessage{
 | 
			
		||||
				UserId:     uint(task.UserId),
 | 
			
		||||
				ChatId:     task.ChatId,
 | 
			
		||||
				RoleId:     uint(task.RoleId),
 | 
			
		||||
				Type:       types.MjMsg,
 | 
			
		||||
				Icon:       task.Icon,
 | 
			
		||||
				Content:    utils.JsonEncode(data),
 | 
			
		||||
				Tokens:     0,
 | 
			
		||||
				UseContext: false,
 | 
			
		||||
			}
 | 
			
		||||
			res = tx.Create(&message)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("error with update database: ", err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// save the job
 | 
			
		||||
			job.UserId = task.UserId
 | 
			
		||||
			job.Type = task.Type.String()
 | 
			
		||||
			job.MessageId = data.MessageId
 | 
			
		||||
			job.ReferenceId = data.ReferenceId
 | 
			
		||||
			job.Prompt = data.Prompt
 | 
			
		||||
			job.ImgURL = imgURL
 | 
			
		||||
			job.Progress = data.Progress
 | 
			
		||||
			job.Hash = data.Image.Hash
 | 
			
		||||
			job.CreatedAt = time.Now()
 | 
			
		||||
			res = tx.Create(&job)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("error with update database: ", err)
 | 
			
		||||
				tx.Rollback()
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			tx.Commit()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if wsClient == nil { // 客户端断线,则丢弃
 | 
			
		||||
			logger.Errorf("Client is offline: %+v", data)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if data.Status == Finished {
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
 | 
			
		||||
			// 本次绘画完毕,移除客户端
 | 
			
		||||
			s.ChatClients.Delete(task.SessionId)
 | 
			
		||||
		} else {
 | 
			
		||||
			// 使用代理临时转发图片
 | 
			
		||||
			if data.Image.URL != "" {
 | 
			
		||||
				image, err := utils.DownloadImage(data.Image.URL, s.proxyURL)
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新用户剩余绘图次数
 | 
			
		||||
	// TODO: 放大图片是否需要消耗绘图次数?
 | 
			
		||||
	if data.Status == Finished {
 | 
			
		||||
		s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
		// 解除任务锁定
 | 
			
		||||
		s.redis.Del(context.Background(), RunningJobKey)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										34
									
								
								api/service/mj/types.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								api/service/mj/types.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,34 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ApplicationID string = "936929561302675456"
 | 
			
		||||
	SessionID     string = "ea8816d857ba9ae2f74c59ae1a953afe"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type InteractionsRequest struct {
 | 
			
		||||
	Type          int            `json:"type"`
 | 
			
		||||
	ApplicationID string         `json:"application_id"`
 | 
			
		||||
	MessageFlags  *int           `json:"message_flags,omitempty"`
 | 
			
		||||
	MessageID     *string        `json:"message_id,omitempty"`
 | 
			
		||||
	GuildID       string         `json:"guild_id"`
 | 
			
		||||
	ChannelID     string         `json:"channel_id"`
 | 
			
		||||
	SessionID     string         `json:"session_id"`
 | 
			
		||||
	Data          map[string]any `json:"data"`
 | 
			
		||||
	Nonce         string         `json:"nonce,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type InteractionsResult struct {
 | 
			
		||||
	Code    int `json:"code"`
 | 
			
		||||
	Message string
 | 
			
		||||
	Error   map[string]any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CBReq struct {
 | 
			
		||||
	MessageId   string     `json:"message_id"`
 | 
			
		||||
	ReferenceId string     `json:"reference_id"`
 | 
			
		||||
	Image       Image      `json:"image"`
 | 
			
		||||
	Content     string     `json:"content"`
 | 
			
		||||
	Prompt      string     `json:"prompt"`
 | 
			
		||||
	Status      TaskStatus `json:"status"`
 | 
			
		||||
	Progress    int        `json:"progress"`
 | 
			
		||||
}
 | 
			
		||||
@@ -1,166 +0,0 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
// MJ 绘画服务
 | 
			
		||||
 | 
			
		||||
const MjRunningJobKey = "MidJourney_Running_Job"
 | 
			
		||||
 | 
			
		||||
type MjService struct {
 | 
			
		||||
	config    types.ChatPlusExtConfig
 | 
			
		||||
	client    *req.Client
 | 
			
		||||
	taskQueue *store.RedisQueue
 | 
			
		||||
	redis     *redis.Client
 | 
			
		||||
	db        *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMjService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *MjService {
 | 
			
		||||
	return &MjService{
 | 
			
		||||
		config:    appConfig.ExtConfig,
 | 
			
		||||
		redis:     client,
 | 
			
		||||
		db:        db,
 | 
			
		||||
		taskQueue: store.NewRedisQueue("midjourney_task_queue", client),
 | 
			
		||||
		client:    req.C().SetTimeout(30 * time.Second)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *MjService) Run() {
 | 
			
		||||
	logger.Info("Starting MidJourney job consumer.")
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	for {
 | 
			
		||||
		_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
 | 
			
		||||
		if err == nil { // 队列串行执行
 | 
			
		||||
			time.Sleep(time.Second * 3)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		var task types.MjTask
 | 
			
		||||
		err = s.taskQueue.LPop(&task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("Consuming Task: %+v", task)
 | 
			
		||||
		switch task.Type {
 | 
			
		||||
		case types.TaskImage:
 | 
			
		||||
			err = s.image(task.Prompt)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskUpscale:
 | 
			
		||||
			err = s.upscale(MjUpscaleReq{
 | 
			
		||||
				Index:       task.Index,
 | 
			
		||||
				MessageId:   task.MessageId,
 | 
			
		||||
				MessageHash: task.MessageHash,
 | 
			
		||||
			})
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskVariation:
 | 
			
		||||
			err = s.variation(MjVariationReq{
 | 
			
		||||
				Index:       task.Index,
 | 
			
		||||
				MessageId:   task.MessageId,
 | 
			
		||||
				MessageHash: task.MessageHash,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("绘画任务执行失败:", err)
 | 
			
		||||
			if task.RetryCount <= 5 {
 | 
			
		||||
				s.taskQueue.RPush(task)
 | 
			
		||||
			}
 | 
			
		||||
			task.RetryCount += 1
 | 
			
		||||
			time.Sleep(time.Second * 3)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 更新任务的执行状态
 | 
			
		||||
		s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
 | 
			
		||||
		// 锁定任务执行通道,直到任务超时(5分钟)
 | 
			
		||||
		s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *MjService) PushTask(task types.MjTask) {
 | 
			
		||||
	logger.Infof("add a new MidJourney Task: %+v", task)
 | 
			
		||||
	s.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *MjService) image(prompt string) error {
 | 
			
		||||
	logger.Infof("MJ 绘画参数:%+v", prompt)
 | 
			
		||||
	body := map[string]string{"prompt": prompt}
 | 
			
		||||
	url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
 | 
			
		||||
	var res types.BizVo
 | 
			
		||||
	r, err := s.client.R().
 | 
			
		||||
		SetHeader("Authorization", s.config.Token).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).Post(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("%v%v", r.String(), err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return errors.New(res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MjUpscaleReq struct {
 | 
			
		||||
	Index       int32  `json:"index"`
 | 
			
		||||
	MessageId   string `json:"message_id"`
 | 
			
		||||
	MessageHash string `json:"message_hash"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *MjService) upscale(upReq MjUpscaleReq) error {
 | 
			
		||||
	url := fmt.Sprintf("%s/api/mj/upscale", s.config.ApiURL)
 | 
			
		||||
	var res types.BizVo
 | 
			
		||||
	r, err := s.client.R().
 | 
			
		||||
		SetHeader("Authorization", s.config.Token).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(upReq).
 | 
			
		||||
		SetSuccessResult(&res).Post(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("%v%v", r.String(), err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return errors.New(res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MjVariationReq struct {
 | 
			
		||||
	Index       int32  `json:"index"`
 | 
			
		||||
	MessageId   string `json:"message_id"`
 | 
			
		||||
	MessageHash string `json:"message_hash"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *MjService) variation(upReq MjVariationReq) error {
 | 
			
		||||
	url := fmt.Sprintf("%s/api/mj/variation", s.config.ApiURL)
 | 
			
		||||
	var res types.BizVo
 | 
			
		||||
	r, err := s.client.R().
 | 
			
		||||
		SetHeader("Authorization", s.config.Token).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(upReq).
 | 
			
		||||
		SetSuccessResult(&res).Post(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("%v%v", r.String(), err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return errors.New(res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										169
									
								
								api/service/sd/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										169
									
								
								api/service/sd/client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,169 @@
 | 
			
		||||
package sd
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"io"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Client struct {
 | 
			
		||||
	httpClient *req.Client
 | 
			
		||||
	config     *types.StableDiffusionConfig
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSdClient(config *types.AppConfig) *Client {
 | 
			
		||||
	return &Client{
 | 
			
		||||
		config:     &config.SdConfig,
 | 
			
		||||
		httpClient: req.C(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) Txt2Img(params types.SdTaskParams) error {
 | 
			
		||||
	var data []interface{}
 | 
			
		||||
	err := utils.JsonDecode(Text2ImgParamTemplate, &data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	data[ParamKeys["task_id"]] = params.TaskId
 | 
			
		||||
	data[ParamKeys["prompt"]] = params.Prompt
 | 
			
		||||
	data[ParamKeys["negative_prompt"]] = params.NegativePrompt
 | 
			
		||||
	data[ParamKeys["steps"]] = params.Steps
 | 
			
		||||
	data[ParamKeys["sampler"]] = params.Sampler
 | 
			
		||||
	data[ParamKeys["face_fix"]] = params.FaceFix
 | 
			
		||||
	data[ParamKeys["cfg_scale"]] = params.CfgScale
 | 
			
		||||
	data[ParamKeys["seed"]] = params.Seed
 | 
			
		||||
	data[ParamKeys["height"]] = params.Height
 | 
			
		||||
	data[ParamKeys["width"]] = params.Width
 | 
			
		||||
	data[ParamKeys["hd_fix"]] = params.HdFix
 | 
			
		||||
	data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
 | 
			
		||||
	data[ParamKeys["hd_scale"]] = params.HdScale
 | 
			
		||||
	data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
 | 
			
		||||
	data[ParamKeys["hd_sample_num"]] = params.HdSampleNum
 | 
			
		||||
	task := TaskInfo{
 | 
			
		||||
		TaskId:      params.TaskId,
 | 
			
		||||
		Data:        data,
 | 
			
		||||
		EventData:   nil,
 | 
			
		||||
		FnIndex:     494,
 | 
			
		||||
		SessionHash: "ycaxgzm9ah",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		c.runTask(task, c.httpClient)
 | 
			
		||||
	}()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) runTask(taskInfo TaskInfo, client *req.Client) {
 | 
			
		||||
	body := map[string]any{
 | 
			
		||||
		"data":         taskInfo.Data,
 | 
			
		||||
		"event_data":   taskInfo.EventData,
 | 
			
		||||
		"fn_index":     taskInfo.FnIndex,
 | 
			
		||||
		"session_hash": taskInfo.SessionHash,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var result = make(chan CBReq)
 | 
			
		||||
	go func() {
 | 
			
		||||
		var res struct {
 | 
			
		||||
			Data            []interface{} `json:"data"`
 | 
			
		||||
			IsGenerating    bool          `json:"is_generating"`
 | 
			
		||||
			Duration        float64       `json:"duration"`
 | 
			
		||||
			AverageDuration float64       `json:"average_duration"`
 | 
			
		||||
		}
 | 
			
		||||
		var cbReq = CBReq{TaskId: taskInfo.TaskId}
 | 
			
		||||
		response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(c.config.ApiURL + "/run/predict")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			cbReq.Message = "error with send request: " + err.Error()
 | 
			
		||||
			cbReq.Success = false
 | 
			
		||||
			result <- cbReq
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if response.IsErrorState() {
 | 
			
		||||
			bytes, _ := io.ReadAll(response.Body)
 | 
			
		||||
			cbReq.Message = "error http status code: " + string(bytes)
 | 
			
		||||
			cbReq.Success = false
 | 
			
		||||
			result <- cbReq
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var images []struct {
 | 
			
		||||
			Name   string      `json:"name"`
 | 
			
		||||
			Data   interface{} `json:"data"`
 | 
			
		||||
			IsFile bool        `json:"is_file"`
 | 
			
		||||
		}
 | 
			
		||||
		err = utils.ForceCovert(res.Data[0], &images)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			cbReq.Message = "error with decode image:" + err.Error()
 | 
			
		||||
			cbReq.Success = false
 | 
			
		||||
			result <- cbReq
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var info map[string]any
 | 
			
		||||
		err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			cbReq.Message = err.Error()
 | 
			
		||||
			cbReq.Success = false
 | 
			
		||||
			result <- cbReq
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		//for k, v := range info {
 | 
			
		||||
		//	fmt.Println(k, " => ", v)
 | 
			
		||||
		//}
 | 
			
		||||
		cbReq.ImageName = images[0].Name
 | 
			
		||||
		cbReq.Seed = utils.InterfaceToString(info["seed"])
 | 
			
		||||
		cbReq.Success = true
 | 
			
		||||
		cbReq.Progress = 100
 | 
			
		||||
		result <- cbReq
 | 
			
		||||
		close(result)
 | 
			
		||||
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case value := <-result:
 | 
			
		||||
			if value.Success {
 | 
			
		||||
				logger.Infof("%s/file=%s", c.config.ApiURL, value.ImageName)
 | 
			
		||||
			}
 | 
			
		||||
			return
 | 
			
		||||
		default:
 | 
			
		||||
			var progressReq = map[string]any{
 | 
			
		||||
				"id_task":         taskInfo.TaskId,
 | 
			
		||||
				"id_live_preview": 1,
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var progressRes struct {
 | 
			
		||||
				Active        bool        `json:"active"`
 | 
			
		||||
				Queued        bool        `json:"queued"`
 | 
			
		||||
				Completed     bool        `json:"completed"`
 | 
			
		||||
				Progress      float64     `json:"progress"`
 | 
			
		||||
				Eta           float64     `json:"eta"`
 | 
			
		||||
				LivePreview   string      `json:"live_preview"`
 | 
			
		||||
				IDLivePreview int         `json:"id_live_preview"`
 | 
			
		||||
				TextInfo      interface{} `json:"textinfo"`
 | 
			
		||||
			}
 | 
			
		||||
			response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(c.config.ApiURL + "/internal/progress")
 | 
			
		||||
			var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true}
 | 
			
		||||
			if err != nil { // TODO: 这里可以考虑设置失败重试次数
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if response.IsErrorState() {
 | 
			
		||||
				bytes, _ := io.ReadAll(response.Body)
 | 
			
		||||
				logger.Error(string(bytes))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			cbReq.ImageData = progressRes.LivePreview
 | 
			
		||||
			cbReq.Progress = int(progressRes.Progress * 100)
 | 
			
		||||
			fmt.Println("Progress: ", progressRes.Progress)
 | 
			
		||||
			fmt.Println("Image: ", progressRes.LivePreview)
 | 
			
		||||
			time.Sleep(time.Second)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,45 +1,42 @@
 | 
			
		||||
package service
 | 
			
		||||
package sd
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// SD 绘画服务
 | 
			
		||||
 | 
			
		||||
const SdRunningJobKey = "StableDiffusion_Running_Job"
 | 
			
		||||
const RunningJobKey = "StableDiffusion_Running_Job"
 | 
			
		||||
 | 
			
		||||
type SdService struct {
 | 
			
		||||
	config    types.ChatPlusExtConfig
 | 
			
		||||
	client    *req.Client
 | 
			
		||||
type Service struct {
 | 
			
		||||
	taskQueue *store.RedisQueue
 | 
			
		||||
	redis     *redis.Client
 | 
			
		||||
	db        *gorm.DB
 | 
			
		||||
	Client    *Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSdService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *SdService {
 | 
			
		||||
	return &SdService{
 | 
			
		||||
		config:    appConfig.ExtConfig,
 | 
			
		||||
		redis:     client,
 | 
			
		||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		redis:     redisCli,
 | 
			
		||||
		db:        db,
 | 
			
		||||
		taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", client),
 | 
			
		||||
		client:    req.C().SetTimeout(30 * time.Second)}
 | 
			
		||||
		Client:    client,
 | 
			
		||||
		taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *SdService) Run() {
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	logger.Info("Starting StableDiffusion job consumer.")
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	for {
 | 
			
		||||
		_, err := s.redis.Get(ctx, SdRunningJobKey).Result()
 | 
			
		||||
		_, err := s.redis.Get(ctx, RunningJobKey).Result()
 | 
			
		||||
		if err == nil { // 队列串行执行
 | 
			
		||||
			time.Sleep(time.Second * 3)
 | 
			
		||||
			continue
 | 
			
		||||
@@ -51,7 +48,7 @@ func (s *SdService) Run() {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("Consuming Task: %+v", task)
 | 
			
		||||
		err = s.txt2img(task.Params)
 | 
			
		||||
		err = s.Client.Txt2Img(task.Params)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("绘画任务执行失败:", err)
 | 
			
		||||
			if task.RetryCount <= 5 {
 | 
			
		||||
@@ -65,31 +62,11 @@ func (s *SdService) Run() {
 | 
			
		||||
		// 更新任务的执行状态
 | 
			
		||||
		s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
 | 
			
		||||
		// 锁定任务执行通道,直到任务超时(5分钟)
 | 
			
		||||
		s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5)
 | 
			
		||||
		s.redis.Set(ctx, mj.RunningJobKey, utils.JsonEncode(task), time.Minute*5)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *SdService) PushTask(task types.SdTask) {
 | 
			
		||||
func (s *Service) PushTask(task types.SdTask) {
 | 
			
		||||
	logger.Infof("add a new MidJourney Task: %+v", task)
 | 
			
		||||
	s.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *SdService) txt2img(params types.SdParams) error {
 | 
			
		||||
	logger.Infof("SD 绘画参数:%+v", params)
 | 
			
		||||
	url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
 | 
			
		||||
	var res types.BizVo
 | 
			
		||||
	r, err := s.client.R().
 | 
			
		||||
		SetHeader("Authorization", s.config.Token).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(params).
 | 
			
		||||
		SetSuccessResult(&res).Post(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("%v%v", r.String(), err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return errors.New(res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										234
									
								
								api/service/sd/types.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										234
									
								
								api/service/sd/types.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,234 @@
 | 
			
		||||
package sd
 | 
			
		||||
 | 
			
		||||
import logger2 "chatplus/logger"
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type TaskInfo struct {
 | 
			
		||||
	TaskId      string      `json:"task_id"`
 | 
			
		||||
	Data        interface{} `json:"data"`
 | 
			
		||||
	EventData   interface{} `json:"event_data"`
 | 
			
		||||
	FnIndex     int         `json:"fn_index"`
 | 
			
		||||
	SessionHash string      `json:"session_hash"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CBReq struct {
 | 
			
		||||
	TaskId    string
 | 
			
		||||
	ImageName string
 | 
			
		||||
	ImageData string
 | 
			
		||||
	Progress  int
 | 
			
		||||
	Seed      string
 | 
			
		||||
	Success   bool
 | 
			
		||||
	Message   string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var ParamKeys = map[string]int{
 | 
			
		||||
	"task_id":         0,
 | 
			
		||||
	"prompt":          1,
 | 
			
		||||
	"negative_prompt": 2,
 | 
			
		||||
	"steps":           4,
 | 
			
		||||
	"sampler":         5,
 | 
			
		||||
	"face_fix":        6,
 | 
			
		||||
	"cfg_scale":       10,
 | 
			
		||||
	"seed":            11,
 | 
			
		||||
	"height":          17,
 | 
			
		||||
	"width":           18,
 | 
			
		||||
	"hd_fix":          19,
 | 
			
		||||
	"hd_redraw_rate":  20, //高清修复重绘幅度
 | 
			
		||||
	"hd_scale":        21, // 高清修复放大倍数
 | 
			
		||||
	"hd_scale_alg":    22, // 高清修复放大算法
 | 
			
		||||
	"hd_sample_num":   23, // 高清修复采样次数
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const Text2ImgParamTemplate = `[
 | 
			
		||||
"",
 | 
			
		||||
"",
 | 
			
		||||
"",
 | 
			
		||||
[],
 | 
			
		||||
30,
 | 
			
		||||
"DPM++ SDE Karras",
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
1,
 | 
			
		||||
1,
 | 
			
		||||
7.5,
 | 
			
		||||
-1,
 | 
			
		||||
-1,
 | 
			
		||||
0,
 | 
			
		||||
0,
 | 
			
		||||
0,
 | 
			
		||||
false,
 | 
			
		||||
512,
 | 
			
		||||
512,
 | 
			
		||||
true,
 | 
			
		||||
0.7,
 | 
			
		||||
2,
 | 
			
		||||
"Latent",
 | 
			
		||||
10,
 | 
			
		||||
0,
 | 
			
		||||
0,
 | 
			
		||||
"Use same sampler",
 | 
			
		||||
"",
 | 
			
		||||
"",
 | 
			
		||||
[],
 | 
			
		||||
"None",
 | 
			
		||||
false,
 | 
			
		||||
"MultiDiffusion",
 | 
			
		||||
false,
 | 
			
		||||
true,
 | 
			
		||||
1024,
 | 
			
		||||
1024,
 | 
			
		||||
96,
 | 
			
		||||
96,
 | 
			
		||||
48,
 | 
			
		||||
4,
 | 
			
		||||
"None",
 | 
			
		||||
2,
 | 
			
		||||
false,
 | 
			
		||||
10,
 | 
			
		||||
1,
 | 
			
		||||
1,
 | 
			
		||||
64,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
0.4,
 | 
			
		||||
0.4,
 | 
			
		||||
0.2,
 | 
			
		||||
0.2,
 | 
			
		||||
"",
 | 
			
		||||
"",
 | 
			
		||||
"Background",
 | 
			
		||||
0.2,
 | 
			
		||||
-1,
 | 
			
		||||
false,
 | 
			
		||||
0.4,
 | 
			
		||||
0.4,
 | 
			
		||||
0.2,
 | 
			
		||||
0.2,
 | 
			
		||||
"",
 | 
			
		||||
"",
 | 
			
		||||
"Background",
 | 
			
		||||
0.2,
 | 
			
		||||
-1,
 | 
			
		||||
false,
 | 
			
		||||
0.4,
 | 
			
		||||
0.4,
 | 
			
		||||
0.2,
 | 
			
		||||
0.2,
 | 
			
		||||
"",
 | 
			
		||||
"",
 | 
			
		||||
"Background",
 | 
			
		||||
0.2,
 | 
			
		||||
-1,
 | 
			
		||||
false,
 | 
			
		||||
0.4,
 | 
			
		||||
0.4,
 | 
			
		||||
0.2,
 | 
			
		||||
0.2,
 | 
			
		||||
"",
 | 
			
		||||
"",
 | 
			
		||||
"Background",
 | 
			
		||||
0.2,
 | 
			
		||||
-1,
 | 
			
		||||
false,
 | 
			
		||||
0.4,
 | 
			
		||||
0.4,
 | 
			
		||||
0.2,
 | 
			
		||||
0.2,
 | 
			
		||||
"",
 | 
			
		||||
"",
 | 
			
		||||
"Background",
 | 
			
		||||
0.2,
 | 
			
		||||
-1,
 | 
			
		||||
false,
 | 
			
		||||
0.4,
 | 
			
		||||
0.4,
 | 
			
		||||
0.2,
 | 
			
		||||
0.2,
 | 
			
		||||
"",
 | 
			
		||||
"",
 | 
			
		||||
"Background",
 | 
			
		||||
0.2,
 | 
			
		||||
-1,
 | 
			
		||||
false,
 | 
			
		||||
0.4,
 | 
			
		||||
0.4,
 | 
			
		||||
0.2,
 | 
			
		||||
0.2,
 | 
			
		||||
"",
 | 
			
		||||
"",
 | 
			
		||||
"Background",
 | 
			
		||||
0.2,
 | 
			
		||||
-1,
 | 
			
		||||
false,
 | 
			
		||||
0.4,
 | 
			
		||||
0.4,
 | 
			
		||||
0.2,
 | 
			
		||||
0.2,
 | 
			
		||||
"",
 | 
			
		||||
"",
 | 
			
		||||
"Background",
 | 
			
		||||
0.2,
 | 
			
		||||
-1,
 | 
			
		||||
false,
 | 
			
		||||
3072,
 | 
			
		||||
192,
 | 
			
		||||
true,
 | 
			
		||||
true,
 | 
			
		||||
true,
 | 
			
		||||
false,
 | 
			
		||||
null,
 | 
			
		||||
null,
 | 
			
		||||
null,
 | 
			
		||||
false,
 | 
			
		||||
"",
 | 
			
		||||
0.5,
 | 
			
		||||
true,
 | 
			
		||||
false,
 | 
			
		||||
"",
 | 
			
		||||
"Lerp",
 | 
			
		||||
false,
 | 
			
		||||
"🔄",
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
"positive",
 | 
			
		||||
"comma",
 | 
			
		||||
0,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
"",
 | 
			
		||||
"Seed",
 | 
			
		||||
"",
 | 
			
		||||
[],
 | 
			
		||||
"Nothing",
 | 
			
		||||
"",
 | 
			
		||||
[],
 | 
			
		||||
"Nothing",
 | 
			
		||||
"",
 | 
			
		||||
[],
 | 
			
		||||
true,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
false,
 | 
			
		||||
0,
 | 
			
		||||
null,
 | 
			
		||||
null,
 | 
			
		||||
false,
 | 
			
		||||
null,
 | 
			
		||||
null,
 | 
			
		||||
false,
 | 
			
		||||
null,
 | 
			
		||||
null,
 | 
			
		||||
false,
 | 
			
		||||
50
 | 
			
		||||
]`
 | 
			
		||||
							
								
								
									
										87
									
								
								api/service/wx/bot.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								api/service/wx/bot.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,87 @@
 | 
			
		||||
package wx
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"github.com/eatmoreapple/openwechat"
 | 
			
		||||
	"github.com/skip2/go-qrcode"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 微信收款机器人
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type Bot struct {
 | 
			
		||||
	bot   *openwechat.Bot
 | 
			
		||||
	token string
 | 
			
		||||
	db    *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWeChatBot(db *gorm.DB) *Bot {
 | 
			
		||||
	bot := openwechat.DefaultBot(openwechat.Desktop)
 | 
			
		||||
	return &Bot{
 | 
			
		||||
		bot: bot,
 | 
			
		||||
		db:  db,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) Run() error {
 | 
			
		||||
	logger.Info("Starting WeChat Bot...")
 | 
			
		||||
 | 
			
		||||
	// set message handler
 | 
			
		||||
	b.bot.MessageHandler = func(msg *openwechat.Message) {
 | 
			
		||||
		b.messageHandler(msg)
 | 
			
		||||
	}
 | 
			
		||||
	// scan code login callback
 | 
			
		||||
	b.bot.UUIDCallback = b.qrCodeCallBack
 | 
			
		||||
 | 
			
		||||
	err := b.bot.Login()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Info("微信登录成功!")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// message handler
 | 
			
		||||
func (b *Bot) messageHandler(msg *openwechat.Message) {
 | 
			
		||||
	sender, err := msg.Sender()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 只处理微信支付的推送消息
 | 
			
		||||
	if sender.NickName == "微信支付" ||
 | 
			
		||||
		msg.MsgType == openwechat.MsgTypeApp ||
 | 
			
		||||
		msg.AppMsgType == openwechat.AppMsgTypeUrl {
 | 
			
		||||
		// 解析支付金额
 | 
			
		||||
		message, err := parseTransactionMessage(msg.Content)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			transaction := extractTransaction(message)
 | 
			
		||||
			logger.Infof("解析到收款信息:%+v", transaction)
 | 
			
		||||
			var item model.Reward
 | 
			
		||||
			res := b.db.Where("tx_id = ?", transaction.TransId).First(&item)
 | 
			
		||||
			if res.Error == nil {
 | 
			
		||||
				logger.Error("当前交易 ID 己经存在!")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			res = b.db.Create(&model.Reward{
 | 
			
		||||
				TxId:   transaction.TransId,
 | 
			
		||||
				Amount: transaction.Amount,
 | 
			
		||||
				Remark: transaction.Remark,
 | 
			
		||||
				Status: false,
 | 
			
		||||
			})
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Errorf("交易保存失败: %v", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) qrCodeCallBack(uuid string) {
 | 
			
		||||
	logger.Info("请使用微信扫描下面二维码登录")
 | 
			
		||||
	q, _ := qrcode.New("https://login.weixin.qq.com/l/"+uuid, qrcode.Medium)
 | 
			
		||||
	logger.Info(q.ToString(true))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										68
									
								
								api/service/wx/tranaction.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								api/service/wx/tranaction.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,68 @@
 | 
			
		||||
package wx
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/xml"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Message 转账消息
 | 
			
		||||
type Message struct {
 | 
			
		||||
	XMLName xml.Name `xml:"msg"`
 | 
			
		||||
	AppMsg  struct {
 | 
			
		||||
		Des string `xml:"des"`
 | 
			
		||||
		Url string `xml:"url"`
 | 
			
		||||
	} `xml:"appmsg"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Transaction 解析后的交易信息
 | 
			
		||||
type Transaction struct {
 | 
			
		||||
	TransId string  `json:"trans_id"` // 微信转账交易 ID
 | 
			
		||||
	Amount  float64 `json:"amount"`   // 微信转账交易金额
 | 
			
		||||
	Remark  string  `json:"remark"`   // 转账备注
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 解析微信转账消息
 | 
			
		||||
func parseTransactionMessage(xmlData string) (*Message, error) {
 | 
			
		||||
	var msg Message
 | 
			
		||||
	if err := xml.Unmarshal([]byte(xmlData), &msg); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &msg, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 导出交易信息
 | 
			
		||||
func extractTransaction(message *Message) Transaction {
 | 
			
		||||
	var tx = Transaction{}
 | 
			
		||||
	// 导出交易金额和备注
 | 
			
		||||
	lines := strings.Split(message.AppMsg.Des, "\n")
 | 
			
		||||
	for _, line := range lines {
 | 
			
		||||
		line = strings.TrimSpace(line)
 | 
			
		||||
		if len(line) == 0 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		// 解析收款金额
 | 
			
		||||
		prefix := "收款金额¥"
 | 
			
		||||
		if strings.HasPrefix(line, prefix) {
 | 
			
		||||
			if value, err := strconv.ParseFloat(line[len(prefix):], 64); err == nil {
 | 
			
		||||
				tx.Amount = value
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		// 解析收款备注
 | 
			
		||||
		prefix = "付款方备注"
 | 
			
		||||
		if strings.HasPrefix(line, prefix) {
 | 
			
		||||
			tx.Remark = line[len(prefix):]
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 解析交易 ID
 | 
			
		||||
	index := strings.Index(message.AppMsg.Url, "trans_id=")
 | 
			
		||||
	if index != -1 {
 | 
			
		||||
		end := strings.LastIndex(message.AppMsg.Url, "&")
 | 
			
		||||
		tx.TransId = strings.TrimSpace(message.AppMsg.Url[index+9 : end])
 | 
			
		||||
	}
 | 
			
		||||
	return tx
 | 
			
		||||
}
 | 
			
		||||
@@ -6,14 +6,14 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type SdJob struct {
 | 
			
		||||
	Id        uint           `json:"id"`
 | 
			
		||||
	Type      string         `json:"type"`
 | 
			
		||||
	UserId    int            `json:"user_id"`
 | 
			
		||||
	TaskId    string         `json:"task_id"`
 | 
			
		||||
	ImgURL    string         `json:"img_url"`
 | 
			
		||||
	Params    types.SdParams `json:"params"`
 | 
			
		||||
	Progress  int            `json:"progress"`
 | 
			
		||||
	Prompt    string         `json:"prompt"`
 | 
			
		||||
	CreatedAt time.Time      `json:"created_at"`
 | 
			
		||||
	Started   bool           `json:"started"`
 | 
			
		||||
	Id        uint               `json:"id"`
 | 
			
		||||
	Type      string             `json:"type"`
 | 
			
		||||
	UserId    int                `json:"user_id"`
 | 
			
		||||
	TaskId    string             `json:"task_id"`
 | 
			
		||||
	ImgURL    string             `json:"img_url"`
 | 
			
		||||
	Params    types.SdTaskParams `json:"params"`
 | 
			
		||||
	Progress  int                `json:"progress"`
 | 
			
		||||
	Prompt    string             `json:"prompt"`
 | 
			
		||||
	CreatedAt time.Time          `json:"created_at"`
 | 
			
		||||
	Started   bool               `json:"started"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -138,3 +138,15 @@ func IntValue(str string, defaultValue int) int {
 | 
			
		||||
	}
 | 
			
		||||
	return value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ForceCovert(src any, dst interface{}) error {
 | 
			
		||||
	bytes, err := json.Marshal(src)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(bytes, dst)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user