diff --git a/api/core/app_server.go b/api/core/app_server.go index 9da122d4..aa7016bd 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -2,7 +2,7 @@ package core import ( "chatplus/core/types" - "chatplus/service/function" + "chatplus/service/fun" "chatplus/store/model" "chatplus/utils" "chatplus/utils/resp" @@ -33,11 +33,10 @@ type AppServer struct { ChatSession *types.LMap[string, *types.ChatSession] //map[sessionId]UserId ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合 ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function - Functions map[string]function.Function - MjTaskClients *types.LMap[string, *types.WsClient] + Functions map[string]fun.Function } -func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer { +func NewServer(appConfig *types.AppConfig, functions map[string]fun.Function) *AppServer { gin.SetMode(gin.ReleaseMode) gin.DefaultWriter = io.Discard return &AppServer{ @@ -48,7 +47,6 @@ func NewServer(appConfig *types.AppConfig, functions map[string]function.Functio ChatSession: types.NewLMap[string, *types.ChatSession](), ChatClients: types.NewLMap[string, *types.WsClient](), ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), - MjTaskClients: types.NewLMap[string, *types.WsClient](), Functions: functions, } } diff --git a/api/core/config.go b/api/core/config.go index a6a5d956..28be3cad 100644 --- a/api/core/config.go +++ b/api/core/config.go @@ -26,7 +26,6 @@ func NewDefaultConfig() *types.AppConfig { MaxAge: 86400, }, ApiConfig: types.ChatPlusApiConfig{}, - ExtConfig: types.ChatPlusExtConfig{Token: utils.RandString(32)}, OSS: types.OSSConfig{ Active: "local", Local: types.LocalStorageConfig{ @@ -34,6 +33,9 @@ func NewDefaultConfig() *types.AppConfig { BasePath: "./static/upload", }, }, + MjConfig: types.MidJourneyConfig{Enabled: false}, + SdConfig: types.StableDiffusionConfig{Enabled: false}, + WeChatBot: false, } } diff --git a/api/core/types/config.go b/api/core/types/config.go index ac1e2295..10aaa04a 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -16,10 +16,11 @@ type AppConfig struct { Redis RedisConfig // redis 连接信息 ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs AesEncryptKey string - SmsConfig AliYunSmsConfig // AliYun send message service config - ExtConfig ChatPlusExtConfig // ChatPlus extensions callback api config - - OSS OSSConfig // OSS config + SmsConfig AliYunSmsConfig // AliYun send message service config + OSS OSSConfig // OSS config + MjConfig MidJourneyConfig // mj 绘画配置 + WeChatBot bool // 是否启用微信机器人 + SdConfig StableDiffusionConfig // sd 绘画配置 } type ChatPlusApiConfig struct { @@ -28,9 +29,22 @@ type ChatPlusApiConfig struct { Token string } -type ChatPlusExtConfig struct { - ApiURL string - Token string +type MidJourneyConfig struct { + Enabled bool + UserToken string + BotToken string + GuildId string // Server ID + ChanelId string // Chanel ID +} + +type WeChatConfig struct { + Enabled bool +} + +type StableDiffusionConfig struct { + Enabled bool + ApiURL string + ApiKey string } type AliYunSmsConfig struct { diff --git a/api/core/types/task.go b/api/core/types/task.go index 4ced8ec7..a91d30cb 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -33,14 +33,24 @@ type MjTask struct { ChatId string `json:"chat_id,omitempty"` RoleId int `json:"role_id,omitempty"` Icon string `json:"icon,omitempty"` - Index int32 `json:"index,omitempty"` + Index int `json:"index,omitempty"` MessageId string `json:"message_id,omitempty"` MessageHash string `json:"message_hash,omitempty"` RetryCount int `json:"retry_count"` } -// SdParams stable diffusion 绘画参数 -type SdParams struct { +type SdTask struct { + Id int `json:"id"` + SessionId string `json:"session_id"` + Src TaskSrc `json:"src"` + Type TaskType `json:"type"` + UserId int `json:"user_id"` + Prompt string `json:"prompt,omitempty"` + Params SdTaskParams `json:"params"` + RetryCount int `json:"retry_count"` +} + +type SdTaskParams struct { TaskId string `json:"task_id"` Prompt string `json:"prompt"` NegativePrompt string `json:"negative_prompt"` @@ -57,14 +67,3 @@ type SdParams struct { HdScaleAlg string `json:"hd_scale_alg"` HdSampleNum int `json:"hd_sample_num"` } - -type SdTask struct { - Id int `json:"id"` - SessionId string `json:"session_id"` - Src types.TaskSrc `json:"src"` - Type types.TaskType `json:"type"` - UserId int `json:"user_id"` - Prompt string `json:"prompt,omitempty"` - Params types.SdParams `json:"params"` - RetryCount int `json:"retry_count"` -} diff --git a/api/go.mod b/api/go.mod index 2b6d41f1..2795d11c 100644 --- a/api/go.mod +++ b/api/go.mod @@ -5,6 +5,9 @@ go 1.19 require ( github.com/BurntSushi/toml v1.1.0 github.com/aliyun/alibaba-cloud-sdk-go v1.62.405 + github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible + github.com/bwmarrin/discordgo v0.27.1 + github.com/eatmoreapple/openwechat v1.2.1 github.com/gin-gonic/gin v1.9.1 github.com/go-redis/redis/v8 v8.11.5 github.com/golang-jwt/jwt/v5 v5.0.0 @@ -14,6 +17,7 @@ require ( github.com/minio/minio-go/v7 v7.0.62 github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 github.com/qiniu/go-sdk/v7 v7.17.1 + github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/syndtr/goleveldb v1.0.0 go.uber.org/zap v1.23.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -21,7 +25,6 @@ require ( ) require ( - github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible // indirect github.com/andybalholm/brotli v1.0.4 // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect diff --git a/api/go.sum b/api/go.sum index 0a70646e..108f5b59 100644 --- a/api/go.sum +++ b/api/go.sum @@ -7,6 +7,8 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9 github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= +github.com/bwmarrin/discordgo v0.27.1 h1:ib9AIc/dom1E/fSIulrBwnez0CToJE113ZGt4HoliGY= +github.com/bwmarrin/discordgo v0.27.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -25,6 +27,8 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0 github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU= +github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= @@ -75,6 +79,7 @@ github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9S github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -168,6 +173,8 @@ github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -209,6 +216,7 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= diff --git a/api/handler/azure_handler.go b/api/handler/azure_handler.go index a87f8e00..9b855bea 100644 --- a/api/handler/azure_handler.go +++ b/api/handler/azure_handler.go @@ -150,7 +150,7 @@ func (h *ChatHandler) sendAzureMessage( content := data if functionName == types.FuncMidJourney { content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data) - h.App.MjTaskClients.Put(session.SessionId, ws) + h.mjService.ChatClients.Put(session.SessionId, ws) // update user's img_calls h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) } diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index 0ea51f2a..b9241b1e 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -4,6 +4,7 @@ import ( "bytes" "chatplus/core" "chatplus/core/types" + "chatplus/service/mj" "chatplus/store" "chatplus/store/model" "chatplus/store/vo" @@ -27,13 +28,14 @@ const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。" type ChatHandler struct { BaseHandler - db *gorm.DB - leveldb *store.LevelDB - redis *redis.Client + db *gorm.DB + leveldb *store.LevelDB + redis *redis.Client + mjService *mj.Service } -func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client) *ChatHandler { - handler := ChatHandler{db: db, leveldb: levelDB, redis: redis} +func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client, service *mj.Service) *ChatHandler { + handler := ChatHandler{db: db, leveldb: levelDB, redis: redis, mjService: service} handler.App = app return &handler } diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index d0b18fc5..a0c270b2 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -3,7 +3,7 @@ package handler import ( "chatplus/core" "chatplus/core/types" - "chatplus/service" + "chatplus/service/mj" "chatplus/service/oss" "chatplus/store/model" "chatplus/store/vo" @@ -21,28 +21,11 @@ import ( "time" ) -type TaskStatus string - -const ( - Stopped = TaskStatus("Stopped") - Finished = TaskStatus("Finished") -) - -type Image struct { - URL string `json:"url"` - ProxyURL string `json:"proxy_url"` - Filename string `json:"filename"` - Width int `json:"width"` - Height int `json:"height"` - Size int `json:"size"` - Hash string `json:"hash"` -} - type MidJourneyHandler struct { BaseHandler redis *redis.Client db *gorm.DB - mjService *service.MjService + mjService *mj.Service uploaderManager *oss.UploaderManager lock sync.Mutex clients *types.LMap[string, *types.WsClient] @@ -53,7 +36,7 @@ func NewMidJourneyHandler( client *redis.Client, db *gorm.DB, manager *oss.UploaderManager, - mjService *service.MjService) *MidJourneyHandler { + mjService *mj.Service) *MidJourneyHandler { h := MidJourneyHandler{ redis: client, db: db, @@ -66,16 +49,6 @@ func NewMidJourneyHandler( return &h } -type mjNotifyData struct { - MessageId string `json:"message_id"` - ReferenceId string `json:"reference_id"` - Image Image `json:"image"` - Content string `json:"content"` - Prompt string `json:"prompt"` - Status TaskStatus `json:"status"` - Progress int `json:"progress"` -} - // Client WebSocket 客户端,用于通知任务状态变更 func (h *MidJourneyHandler) Client(c *gin.Context) { ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) @@ -92,189 +65,6 @@ func (h *MidJourneyHandler) Client(c *gin.Context) { logger.Infof("New websocket connected, IP: %s", c.ClientIP()) } -func (h *MidJourneyHandler) Notify(c *gin.Context) { - token := c.GetHeader("Authorization") - if token != h.App.Config.ExtConfig.Token { - resp.NotAuth(c) - return - } - var data mjNotifyData - if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { - resp.ERROR(c, types.InvalidArgs) - return - } - logger.Debugf("收到 MidJourney 回调请求:%+v", data) - - h.lock.Lock() - defer h.lock.Unlock() - - err, finished := h.notifyHandler(c, data) - if err != nil { - resp.ERROR(c, err.Error()) - return - } - - // 解除任务锁定 - if finished && (data.Status == Finished || data.Status == Stopped) { - h.redis.Del(c, service.MjRunningJobKey) - } - resp.SUCCESS(c) - -} - -func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data mjNotifyData) (error, bool) { - taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() - if err != nil { // 过期任务,丢弃 - logger.Warn("任务已过期:", err) - return nil, true - } - - var task types.MjTask - err = utils.JsonDecode(taskString, &task) - if err != nil { // 非标准任务,丢弃 - logger.Warn("任务解析失败:", err) - return nil, false - } - - var job model.MidJourneyJob - res := h.db.Where("message_id = ?", data.MessageId).First(&job) - if res.Error == nil && data.Status == Finished { - logger.Warn("重复消息:", data.MessageId) - return nil, false - } - - if task.Src == types.TaskSrcImg { // 绘画任务 - var job model.MidJourneyJob - res := h.db.Where("id = ?", task.Id).First(&job) - if res.Error != nil { - logger.Warn("非法任务:", res.Error) - return nil, false - } - job.MessageId = data.MessageId - job.ReferenceId = data.ReferenceId - job.Progress = data.Progress - job.Prompt = data.Prompt - job.Hash = data.Image.Hash - - // 任务完成,将最终的图片下载下来 - if data.Progress == 100 { - imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) - if err != nil { - logger.Error("error with download img: ", err.Error()) - return err, false - } - job.ImgURL = imgURL - } else { - // 临时图片直接保存,访问的时候使用代理进行转发 - job.ImgURL = data.Image.URL - } - res = h.db.Updates(&job) - if res.Error != nil { - logger.Error("error with update job: ", res.Error) - return res.Error, false - } - - var jobVo vo.MidJourneyJob - err := utils.CopyObject(job, &jobVo) - if err == nil { - if data.Progress < 100 { - image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL) - if err == nil { - jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) - } - } - - // 推送任务到前端 - client := h.clients.Get(task.SessionId) - if client != nil { - utils.ReplyChunkMessage(client, jobVo) - } - } - - } else if task.Src == types.TaskSrcChat { // 聊天任务 - wsClient := h.App.MjTaskClients.Get(task.SessionId) - if data.Status == Finished { - if wsClient != nil && data.ReferenceId != "" { - content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt) - utils.ReplyMessage(wsClient, content) - } - // download image - imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) - if err != nil { - logger.Error("error with download image: ", err) - if wsClient != nil && data.ReferenceId != "" { - content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error()) - utils.ReplyMessage(wsClient, content) - } - return err, false - } - - tx := h.db.Begin() - data.Image.URL = imgURL - message := model.HistoryMessage{ - UserId: uint(task.UserId), - ChatId: task.ChatId, - RoleId: uint(task.RoleId), - Type: types.MjMsg, - Icon: task.Icon, - Content: utils.JsonEncode(data), - Tokens: 0, - UseContext: false, - } - res = tx.Create(&message) - if res.Error != nil { - return res.Error, false - } - - // save the job - job.UserId = task.UserId - job.Type = task.Type.String() - job.MessageId = data.MessageId - job.ReferenceId = data.ReferenceId - job.Prompt = data.Prompt - job.ImgURL = imgURL - job.Progress = data.Progress - job.Hash = data.Image.Hash - job.CreatedAt = time.Now() - res = tx.Create(&job) - if res.Error != nil { - tx.Rollback() - return res.Error, false - } - tx.Commit() - } - - if wsClient == nil { // 客户端断线,则丢弃 - logger.Errorf("Client is offline: %+v", data) - return nil, true - } - - if data.Status == Finished { - utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) - utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) - // 本次绘画完毕,移除客户端 - h.App.MjTaskClients.Delete(task.SessionId) - } else { - // 使用代理临时转发图片 - if data.Image.URL != "" { - image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL) - if err == nil { - data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) - } - } - utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) - } - } - - // 更新用户剩余绘图次数 - // TODO: 放大图片是否需要消耗绘图次数? - if data.Status == Finished { - h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) - } - - return nil, true -} - func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool { user, err := utils.GetLoginUser(c, h.db) if err != nil { @@ -376,7 +166,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { type reqVo struct { Src string `json:"src"` - Index int32 `json:"index"` + Index int `json:"index"` MessageId string `json:"message_id"` MessageHash string `json:"message_hash"` SessionId string `json:"session_id"` @@ -443,15 +233,16 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) { MessageHash: data.MessageHash, }) - wsClient := h.App.ChatClients.Get(data.SessionId) - if wsClient != nil { - content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) - utils.ReplyMessage(wsClient, content) - if h.App.MjTaskClients.Get(data.SessionId) == nil { - h.App.MjTaskClients.Put(data.SessionId, wsClient) + if src == types.TaskSrcChat { + wsClient := h.App.ChatClients.Get(data.SessionId) + if wsClient != nil { + content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) + utils.ReplyMessage(wsClient, content) + if h.mjService.ChatClients.Get(data.SessionId) == nil { + h.mjService.ChatClients.Put(data.SessionId, wsClient) + } } } - resp.SUCCESS(c) } @@ -513,13 +304,15 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) { MessageHash: data.MessageHash, }) - // 从聊天窗口发送的请求,记录客户端信息 - wsClient := h.App.ChatClients.Get(data.SessionId) - if wsClient != nil { - content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) - utils.ReplyMessage(wsClient, content) - if h.App.MjTaskClients.Get(data.SessionId) == nil { - h.App.MjTaskClients.Put(data.SessionId, wsClient) + if src == types.TaskSrcChat { + // 从聊天窗口发送的请求,记录客户端信息 + wsClient := h.mjService.ChatClients.Get(data.SessionId) + if wsClient != nil { + content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt) + utils.ReplyMessage(wsClient, content) + if h.mjService.Clients.Get(data.SessionId) == nil { + h.mjService.Clients.Put(data.SessionId, wsClient) + } } } resp.SUCCESS(c) diff --git a/api/handler/openai_handler.go b/api/handler/openai_handler.go index dd7d7bac..905e78bc 100644 --- a/api/handler/openai_handler.go +++ b/api/handler/openai_handler.go @@ -150,7 +150,7 @@ func (h *ChatHandler) sendOpenAiMessage( content := data if functionName == types.FuncMidJourney { content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data) - h.App.MjTaskClients.Put(session.SessionId, ws) + h.mjService.ChatClients.Put(session.SessionId, ws) // update user's img_calls h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) } diff --git a/api/handler/reward_handler.go b/api/handler/reward_handler.go index 6a8fc831..ed603d30 100644 --- a/api/handler/reward_handler.go +++ b/api/handler/reward_handler.go @@ -22,50 +22,6 @@ func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler { return &h } -func (h *RewardHandler) Notify(c *gin.Context) { - token := c.GetHeader("Authorization") - if token != h.App.Config.ExtConfig.Token { - resp.NotAuth(c) - return - } - - var data struct { - TransId string `json:"trans_id"` // 微信转账交易 ID - Amount float64 `json:"amount"` // 微信转账交易金额 - Remark string `json:"remark"` // 转账备注 - } - if err := c.ShouldBindJSON(&data); err != nil { - resp.ERROR(c, types.InvalidArgs) - return - } - - if data.Amount <= 0 { - resp.ERROR(c, "Amount should not be 0") - return - } - - logger.Infof("收到众筹收款信息: %+v", data) - var item model.Reward - res := h.db.Where("tx_id = ?", data.TransId).First(&item) - if res.Error == nil { - resp.ERROR(c, "当前交易 ID 己经存在!") - return - } - - res = h.db.Create(&model.Reward{ - TxId: data.TransId, - Amount: data.Amount, - Remark: data.Remark, - Status: false, - }) - if res.Error != nil { - logger.Errorf("交易保存失败: %v", res.Error) - resp.ERROR(c, "交易保存失败") - return - } - resp.SUCCESS(c) -} - // Verify 打赏码核销 func (h *RewardHandler) Verify(c *gin.Context) { var data struct { diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index fbeca078..ac79c40d 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -1,315 +1,316 @@ package handler -import ( - "chatplus/core" - "chatplus/core/types" - "chatplus/service" - "chatplus/service/oss" - "chatplus/store/model" - "chatplus/store/vo" - "chatplus/utils" - "chatplus/utils/resp" - "encoding/base64" - "fmt" - "github.com/gin-gonic/gin" - "github.com/go-redis/redis/v8" - "github.com/gorilla/websocket" - "gorm.io/gorm" - "net/http" - "strings" - "sync" - "time" -) - -type SdJobHandler struct { - BaseHandler - redis *redis.Client - db *gorm.DB - mjService *service.MjService - uploaderManager *oss.UploaderManager - lock sync.Mutex - clients *types.LMap[string, *types.WsClient] -} - -func NewSdJobHandler( - app *core.AppServer, - client *redis.Client, - db *gorm.DB, - manager *oss.UploaderManager, - mjService *service.MjService) *MidJourneyHandler { - h := MidJourneyHandler{ - redis: client, - db: db, - uploaderManager: manager, - lock: sync.Mutex{}, - mjService: mjService, - clients: types.NewLMap[string, *types.WsClient](), - } - h.App = app - return &h -} - -// Client WebSocket 客户端,用于通知任务状态变更 -func (h *SdJobHandler) Client(c *gin.Context) { - ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) - if err != nil { - logger.Error(err) - return - } - - sessionId := c.Query("session_id") - client := types.NewWsClient(ws) - // 删除旧的连接 - h.clients.Delete(sessionId) - h.clients.Put(sessionId, client) - logger.Infof("New websocket connected, IP: %s", c.ClientIP()) -} - -type sdNotifyData struct { - TaskId string - ImageName string - ImageData string - Progress int - Seed string - Success bool - Message string -} - -func (h *SdJobHandler) Notify(c *gin.Context) { - token := c.GetHeader("Authorization") - if token != h.App.Config.ExtConfig.Token { - resp.NotAuth(c) - return - } - var data sdNotifyData - if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" { - resp.ERROR(c, types.InvalidArgs) - return - } - logger.Debugf("收到 MidJourney 回调请求:%+v", data) - - h.lock.Lock() - defer h.lock.Unlock() - - err, finished := h.notifyHandler(c, data) - if err != nil { - resp.ERROR(c, err.Error()) - return - } - - // 解除任务锁定 - if finished && (data.Progress == 100) { - h.redis.Del(c, service.MjRunningJobKey) - } - resp.SUCCESS(c) - -} - -func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) { - taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() - if err != nil { // 过期任务,丢弃 - logger.Warn("任务已过期:", err) - return nil, true - } - - var task types.SdTask - err = utils.JsonDecode(taskString, &task) - if err != nil { // 非标准任务,丢弃 - logger.Warn("任务解析失败:", err) - return nil, false - } - - var job model.SdJob - res := h.db.Where("id = ?", task.Id).First(&job) - if res.Error != nil { - logger.Warn("非法任务:", res.Error) - return nil, false - } - job.Params = utils.JsonEncode(task.Params) - job.ReferenceId = data.ImageData - job.Progress = data.Progress - job.Prompt = data.Prompt - job.Hash = data.Image.Hash - - // 任务完成,将最终的图片下载下来 - if data.Progress == 100 { - imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) - if err != nil { - logger.Error("error with download img: ", err.Error()) - return err, false - } - job.ImgURL = imgURL - } else { - // 临时图片直接保存,访问的时候使用代理进行转发 - job.ImgURL = data.Image.URL - } - res = h.db.Updates(&job) - if res.Error != nil { - logger.Error("error with update job: ", res.Error) - return res.Error, false - } - - var jobVo vo.MidJourneyJob - err := utils.CopyObject(job, &jobVo) - if err == nil { - if data.Progress < 100 { - image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL) - if err == nil { - jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) - } - } - - // 推送任务到前端 - client := h.clients.Get(task.SessionId) - if client != nil { - utils.ReplyChunkMessage(client, jobVo) - } - } - - // 更新用户剩余绘图次数 - if data.Progress == 100 { - h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) - } - - return nil, true -} - -func (h *SdJobHandler) checkLimits(c *gin.Context) bool { - user, err := utils.GetLoginUser(c, h.db) - if err != nil { - resp.NotAuth(c) - return false - } - - if user.ImgCalls <= 0 { - resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!") - return false - } - - return true - -} - -// Image 创建一个绘画任务 -func (h *SdJobHandler) Image(c *gin.Context) { - var data struct { - SessionId string `json:"session_id"` - Prompt string `json:"prompt"` - Rate string `json:"rate"` - Model string `json:"model"` - Chaos int `json:"chaos"` - Raw bool `json:"raw"` - Seed int64 `json:"seed"` - Stylize int `json:"stylize"` - Img string `json:"img"` - Weight float32 `json:"weight"` - } - if err := c.ShouldBindJSON(&data); err != nil { - resp.ERROR(c, types.InvalidArgs) - return - } - if !h.checkLimits(c) { - return - } - - var prompt = data.Prompt - if data.Rate != "" && !strings.Contains(prompt, "--ar") { - prompt += " --ar " + data.Rate - } - if data.Seed > 0 && !strings.Contains(prompt, "--seed") { - prompt += fmt.Sprintf(" --seed %d", data.Seed) - } - if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") { - prompt += fmt.Sprintf(" --s %d", data.Stylize) - } - if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") { - prompt += fmt.Sprintf(" --c %d", data.Chaos) - } - if data.Img != "" { - prompt = fmt.Sprintf("%s %s", data.Img, prompt) - if data.Weight > 0 { - prompt += fmt.Sprintf(" --iw %f", data.Weight) - } - } - if data.Raw { - prompt += " --style raw" - } - if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") { - prompt += data.Model - } - - idValue, _ := c.Get(types.LoginUserID) - userId := utils.IntValue(utils.InterfaceToString(idValue), 0) - job := model.MidJourneyJob{ - Type: service.Image.String(), - UserId: userId, - Progress: 0, - Prompt: prompt, - CreatedAt: time.Now(), - } - if res := h.db.Create(&job); res.Error != nil { - resp.ERROR(c, "添加任务失败:"+res.Error.Error()) - return - } - - h.mjService.PushTask(service.MjTask{ - Id: int(job.Id), - SessionId: data.SessionId, - Src: service.TaskSrcImg, - Type: service.Image, - Prompt: prompt, - UserId: userId, - }) - - var jobVo vo.MidJourneyJob - err := utils.CopyObject(job, &jobVo) - if err == nil { - // 推送任务到前端 - client := h.clients.Get(data.SessionId) - if client != nil { - utils.ReplyChunkMessage(client, jobVo) - } - } - resp.SUCCESS(c) -} - -// JobList 获取 MJ 任务列表 -func (h *SdJobHandler) JobList(c *gin.Context) { - status := h.GetInt(c, "status", 0) - var items []model.MidJourneyJob - var res *gorm.DB - userId, _ := c.Get(types.LoginUserID) - if status == 1 { - res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items) - } else { - res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items) - } - if res.Error != nil { - resp.ERROR(c, types.NoData) - return - } - - var jobs = make([]vo.MidJourneyJob, 0) - for _, item := range items { - var job vo.MidJourneyJob - err := utils.CopyObject(item, &job) - if err != nil { - continue - } - if item.Progress < 100 { - // 30 分钟还没完成的任务直接删除 - if time.Now().Sub(item.CreatedAt) > time.Minute*30 { - h.db.Delete(&item) - continue - } - if item.ImgURL != "" { // 正在运行中任务使用代理访问图片 - image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL) - if err == nil { - job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) - } - } - } - jobs = append(jobs, job) - } - resp.SUCCESS(c, jobs) -} +// +//import ( +// "chatplus/core" +// "chatplus/core/types" +// "chatplus/service" +// "chatplus/service/oss" +// "chatplus/store/model" +// "chatplus/store/vo" +// "chatplus/utils" +// "chatplus/utils/resp" +// "encoding/base64" +// "fmt" +// "github.com/gin-gonic/gin" +// "github.com/go-redis/redis/v8" +// "github.com/gorilla/websocket" +// "gorm.io/gorm" +// "net/http" +// "strings" +// "sync" +// "time" +//) +// +//type SdJobHandler struct { +// BaseHandler +// redis *redis.Client +// db *gorm.DB +// mjService *service.MjService +// uploaderManager *oss.UploaderManager +// lock sync.Mutex +// clients *types.LMap[string, *types.WsClient] +//} +// +//func NewSdJobHandler( +// app *core.AppServer, +// client *redis.Client, +// db *gorm.DB, +// manager *oss.UploaderManager, +// mjService *service.MjService) *MidJourneyHandler { +// h := MidJourneyHandler{ +// redis: client, +// db: db, +// uploaderManager: manager, +// lock: sync.Mutex{}, +// mjService: mjService, +// clients: types.NewLMap[string, *types.WsClient](), +// } +// h.App = app +// return &h +//} +// +//// Client WebSocket 客户端,用于通知任务状态变更 +//func (h *SdJobHandler) Client(c *gin.Context) { +// ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) +// if err != nil { +// logger.Error(err) +// return +// } +// +// sessionId := c.Query("session_id") +// client := types.NewWsClient(ws) +// // 删除旧的连接 +// h.clients.Delete(sessionId) +// h.clients.Put(sessionId, client) +// logger.Infof("New websocket connected, IP: %s", c.ClientIP()) +//} +// +//type sdNotifyData struct { +// TaskId string +// ImageName string +// ImageData string +// Progress int +// Seed string +// Success bool +// Message string +//} +// +//func (h *SdJobHandler) Notify(c *gin.Context) { +// token := c.GetHeader("Authorization") +// if token != h.App.Config.ExtConfig.Token { +// resp.NotAuth(c) +// return +// } +// var data sdNotifyData +// if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" { +// resp.ERROR(c, types.InvalidArgs) +// return +// } +// logger.Debugf("收到 MidJourney 回调请求:%+v", data) +// +// h.lock.Lock() +// defer h.lock.Unlock() +// +// err, finished := h.notifyHandler(c, data) +// if err != nil { +// resp.ERROR(c, err.Error()) +// return +// } +// +// // 解除任务锁定 +// if finished && (data.Progress == 100) { +// h.redis.Del(c, service.MjRunningJobKey) +// } +// resp.SUCCESS(c) +// +//} +// +//func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) { +// taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result() +// if err != nil { // 过期任务,丢弃 +// logger.Warn("任务已过期:", err) +// return nil, true +// } +// +// var task types.SdTask +// err = utils.JsonDecode(taskString, &task) +// if err != nil { // 非标准任务,丢弃 +// logger.Warn("任务解析失败:", err) +// return nil, false +// } +// +// var job model.SdJob +// res := h.db.Where("id = ?", task.Id).First(&job) +// if res.Error != nil { +// logger.Warn("非法任务:", res.Error) +// return nil, false +// } +// job.Params = utils.JsonEncode(task.Params) +// job.ReferenceId = data.ImageData +// job.Progress = data.Progress +// job.Prompt = data.Prompt +// job.Hash = data.Image.Hash +// +// // 任务完成,将最终的图片下载下来 +// if data.Progress == 100 { +// imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL) +// if err != nil { +// logger.Error("error with download img: ", err.Error()) +// return err, false +// } +// job.ImgURL = imgURL +// } else { +// // 临时图片直接保存,访问的时候使用代理进行转发 +// job.ImgURL = data.Image.URL +// } +// res = h.db.Updates(&job) +// if res.Error != nil { +// logger.Error("error with update job: ", res.Error) +// return res.Error, false +// } +// +// var jobVo vo.MidJourneyJob +// err := utils.CopyObject(job, &jobVo) +// if err == nil { +// if data.Progress < 100 { +// image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL) +// if err == nil { +// jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) +// } +// } +// +// // 推送任务到前端 +// client := h.clients.Get(task.SessionId) +// if client != nil { +// utils.ReplyChunkMessage(client, jobVo) +// } +// } +// +// // 更新用户剩余绘图次数 +// if data.Progress == 100 { +// h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) +// } +// +// return nil, true +//} +// +//func (h *SdJobHandler) checkLimits(c *gin.Context) bool { +// user, err := utils.GetLoginUser(c, h.db) +// if err != nil { +// resp.NotAuth(c) +// return false +// } +// +// if user.ImgCalls <= 0 { +// resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!") +// return false +// } +// +// return true +// +//} +// +//// Image 创建一个绘画任务 +//func (h *SdJobHandler) Image(c *gin.Context) { +// var data struct { +// SessionId string `json:"session_id"` +// Prompt string `json:"prompt"` +// Rate string `json:"rate"` +// Model string `json:"model"` +// Chaos int `json:"chaos"` +// Raw bool `json:"raw"` +// Seed int64 `json:"seed"` +// Stylize int `json:"stylize"` +// Img string `json:"img"` +// Weight float32 `json:"weight"` +// } +// if err := c.ShouldBindJSON(&data); err != nil { +// resp.ERROR(c, types.InvalidArgs) +// return +// } +// if !h.checkLimits(c) { +// return +// } +// +// var prompt = data.Prompt +// if data.Rate != "" && !strings.Contains(prompt, "--ar") { +// prompt += " --ar " + data.Rate +// } +// if data.Seed > 0 && !strings.Contains(prompt, "--seed") { +// prompt += fmt.Sprintf(" --seed %d", data.Seed) +// } +// if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") { +// prompt += fmt.Sprintf(" --s %d", data.Stylize) +// } +// if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") { +// prompt += fmt.Sprintf(" --c %d", data.Chaos) +// } +// if data.Img != "" { +// prompt = fmt.Sprintf("%s %s", data.Img, prompt) +// if data.Weight > 0 { +// prompt += fmt.Sprintf(" --iw %f", data.Weight) +// } +// } +// if data.Raw { +// prompt += " --style raw" +// } +// if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") { +// prompt += data.Model +// } +// +// idValue, _ := c.Get(types.LoginUserID) +// userId := utils.IntValue(utils.InterfaceToString(idValue), 0) +// job := model.MidJourneyJob{ +// Type: service.Image.String(), +// UserId: userId, +// Progress: 0, +// Prompt: prompt, +// CreatedAt: time.Now(), +// } +// if res := h.db.Create(&job); res.Error != nil { +// resp.ERROR(c, "添加任务失败:"+res.Error.Error()) +// return +// } +// +// h.mjService.PushTask(service.MjTask{ +// Id: int(job.Id), +// SessionId: data.SessionId, +// Src: service.TaskSrcImg, +// Type: service.Image, +// Prompt: prompt, +// UserId: userId, +// }) +// +// var jobVo vo.MidJourneyJob +// err := utils.CopyObject(job, &jobVo) +// if err == nil { +// // 推送任务到前端 +// client := h.clients.Get(data.SessionId) +// if client != nil { +// utils.ReplyChunkMessage(client, jobVo) +// } +// } +// resp.SUCCESS(c) +//} +// +//// JobList 获取 MJ 任务列表 +//func (h *SdJobHandler) JobList(c *gin.Context) { +// status := h.GetInt(c, "status", 0) +// var items []model.MidJourneyJob +// var res *gorm.DB +// userId, _ := c.Get(types.LoginUserID) +// if status == 1 { +// res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items) +// } else { +// res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items) +// } +// if res.Error != nil { +// resp.ERROR(c, types.NoData) +// return +// } +// +// var jobs = make([]vo.MidJourneyJob, 0) +// for _, item := range items { +// var job vo.MidJourneyJob +// err := utils.CopyObject(item, &job) +// if err != nil { +// continue +// } +// if item.Progress < 100 { +// // 30 分钟还没完成的任务直接删除 +// if time.Now().Sub(item.CreatedAt) > time.Minute*30 { +// h.db.Delete(&item) +// continue +// } +// if item.ImgURL != "" { // 正在运行中任务使用代理访问图片 +// image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL) +// if err == nil { +// job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) +// } +// } +// } +// jobs = append(jobs, job) +// } +// resp.SUCCESS(c, jobs) +//} diff --git a/api/main.go b/api/main.go index eb6feb77..d30a5b10 100644 --- a/api/main.go +++ b/api/main.go @@ -7,8 +7,10 @@ import ( "chatplus/handler/admin" logger2 "chatplus/logger" "chatplus/service" - "chatplus/service/function" + "chatplus/service/fun" + "chatplus/service/mj" "chatplus/service/oss" + "chatplus/service/wx" "chatplus/store" "context" "embed" @@ -107,7 +109,7 @@ func main() { }), // 创建函数 - fx.Provide(function.NewFunctions), + fx.Provide(fun.NewFunctions), // 创建控制器 fx.Provide(handler.NewChatRoleHandler), @@ -135,13 +137,36 @@ func main() { return service.NewCaptchaService(config.ApiConfig) }), fx.Provide(oss.NewUploaderManager), - fx.Provide(service.NewMjService), - fx.Invoke(func(mjService *service.MjService) { + fx.Provide(mj.NewService), + fx.Invoke(func(mjService *mj.Service) { go func() { mjService.Run() }() }), + // 微信机器人服务 + fx.Provide(wx.NewWeChatBot), + fx.Invoke(func(config *types.AppConfig, bot *wx.Bot) { + if config.WeChatBot { + err := bot.Run() + if err != nil { + logger.Error("微信登录失败:", err) + } + } + }), + + // MidJourney 机器人 + fx.Provide(mj.NewBot), + fx.Provide(mj.NewClient), + fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) { + if config.MjConfig.Enabled { + err := bot.Run() + if err != nil { + log.Fatal("MidJourney 服务启动失败:", err) + } + } + }), + // 注册路由 fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) { group := s.Engine.Group("/api/role/") @@ -185,12 +210,10 @@ func main() { }), fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) { group := s.Engine.Group("/api/reward/") - group.POST("notify", h.Notify) group.POST("verify", h.Verify) }), fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { group := s.Engine.Group("/api/mj/") - group.POST("notify", h.Notify) group.POST("image", h.Image) group.POST("upscale", h.Upscale) group.POST("variation", h.Variation) diff --git a/api/service/function/func_mj.go b/api/service/fun/func_mj.go similarity index 77% rename from api/service/function/func_mj.go rename to api/service/fun/func_mj.go index 4ecf0f1a..0b5899db 100644 --- a/api/service/function/func_mj.go +++ b/api/service/fun/func_mj.go @@ -1,7 +1,8 @@ -package function +package fun import ( - "chatplus/service" + "chatplus/core/types" + "chatplus/service/mj" "chatplus/utils" ) @@ -9,10 +10,10 @@ import ( type FuncMidJourney struct { name string - service *service.MjService + service *mj.Service } -func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney { +func NewMidJourneyFunc(mjService *mj.Service) FuncMidJourney { return FuncMidJourney{ name: "MidJourney AI 绘画", service: mjService} @@ -21,10 +22,10 @@ func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney { func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) { logger.Infof("MJ 绘画参数:%+v", params) prompt := utils.InterfaceToString(params["prompt"]) - f.service.PushTask(service.MjTask{ + f.service.PushTask(types.MjTask{ SessionId: utils.InterfaceToString(params["session_id"]), - Src: service.TaskSrcChat, - Type: service.Image, + Src: types.TaskSrcChat, + Type: types.TaskImage, Prompt: prompt, UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0), RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0), diff --git a/api/service/function/function.go b/api/service/fun/function.go similarity index 85% rename from api/service/function/function.go rename to api/service/fun/function.go index 5b558feb..76e523bd 100644 --- a/api/service/function/function.go +++ b/api/service/fun/function.go @@ -1,9 +1,9 @@ -package function +package fun import ( "chatplus/core/types" logger2 "chatplus/logger" - "chatplus/service" + "chatplus/service/mj" ) type Function interface { @@ -29,7 +29,7 @@ type dataItem struct { Remark string `json:"remark"` } -func NewFunctions(config *types.AppConfig, mjService *service.MjService) map[string]Function { +func NewFunctions(config *types.AppConfig, mjService *mj.Service) map[string]Function { return map[string]Function{ types.FuncZaoBao: NewZaoBao(config.ApiConfig), types.FuncWeibo: NewWeiboHot(config.ApiConfig), diff --git a/api/service/function/tou_tiao.go b/api/service/fun/tou_tiao.go similarity index 98% rename from api/service/function/tou_tiao.go rename to api/service/fun/tou_tiao.go index c77e2141..7cdc1f11 100644 --- a/api/service/function/tou_tiao.go +++ b/api/service/fun/tou_tiao.go @@ -1,4 +1,4 @@ -package function +package fun import ( "chatplus/core/types" diff --git a/api/service/function/weibo_hot.go b/api/service/fun/weibo_hot.go similarity index 98% rename from api/service/function/weibo_hot.go rename to api/service/fun/weibo_hot.go index 95fccc27..fac19684 100644 --- a/api/service/function/weibo_hot.go +++ b/api/service/fun/weibo_hot.go @@ -1,4 +1,4 @@ -package function +package fun import ( "chatplus/core/types" diff --git a/api/service/function/zao_bao.go b/api/service/fun/zao_bao.go similarity index 98% rename from api/service/function/zao_bao.go rename to api/service/fun/zao_bao.go index 174c81c4..9c507faf 100644 --- a/api/service/function/zao_bao.go +++ b/api/service/fun/zao_bao.go @@ -1,4 +1,4 @@ -package function +package fun import ( "chatplus/core/types" diff --git a/api/service/mj/bot.go b/api/service/mj/bot.go new file mode 100644 index 00000000..4463901e --- /dev/null +++ b/api/service/mj/bot.go @@ -0,0 +1,213 @@ +package mj + +import ( + "chatplus/core/types" + logger2 "chatplus/logger" + "chatplus/utils" + "github.com/bwmarrin/discordgo" + "github.com/gorilla/websocket" + "net/http" + "net/url" + "regexp" + "strings" +) + +// MidJourney 机器人 + +var logger = logger2.GetLogger() + +type Bot struct { + config *types.MidJourneyConfig + bot *discordgo.Session + service *Service +} + +func NewBot(config *types.AppConfig, service *Service) (*Bot, error) { + discord, err := discordgo.New("Bot " + config.MjConfig.BotToken) + if err != nil { + return nil, err + } + + if config.ProxyURL != "" { + proxy, _ := url.Parse(config.ProxyURL) + discord.Client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxy), + }, + } + discord.Dialer = &websocket.Dialer{ + Proxy: http.ProxyURL(proxy), + } + } + + return &Bot{ + config: &config.MjConfig, + bot: discord, + service: service, + }, nil +} + +func (b *Bot) Run() error { + b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent + b.bot.AddHandler(b.messageCreate) + b.bot.AddHandler(b.messageUpdate) + + logger.Info("Starting MidJourney Bot...") + err := b.bot.Open() + if err != nil { + logger.Error("Error opening Discord connection:", err) + return err + } + logger.Info("Starting MidJourney Bot successfully!") + return nil +} + +type TaskStatus string + +const ( + Start = TaskStatus("Started") + Running = TaskStatus("Running") + Stopped = TaskStatus("Stopped") + Finished = TaskStatus("Finished") +) + +type Image struct { + URL string `json:"url"` + ProxyURL string `json:"proxy_url"` + Filename string `json:"filename"` + Width int `json:"width"` + Height int `json:"height"` + Size int `json:"size"` + Hash string `json:"hash"` +} + +func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { + // ignore messages for other channels + if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId { + return + } + // ignore messages for self + if m.Author.ID == s.State.User.ID { + return + } + + logger.Debugf("CREATE: %s", utils.JsonEncode(m)) + var referenceId = "" + if m.ReferencedMessage != nil { + referenceId = m.ReferencedMessage.ID + } + if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") { + // parse content + req := CBReq{ + MessageId: m.ID, + ReferenceId: referenceId, + Prompt: extractPrompt(m.Content), + Content: m.Content, + Progress: 0, + Status: Start} + b.service.Notify(req) + return + } + + b.addAttachment(m.ID, referenceId, m.Content, m.Attachments) +} + +func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) { + // ignore messages for other channels + if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId { + return + } + // ignore messages for self + if m.Author.ID == s.State.User.ID { + return + } + + logger.Debugf("UPDATE: %s", utils.JsonEncode(m)) + + var referenceId = "" + if m.ReferencedMessage != nil { + referenceId = m.ReferencedMessage.ID + } + if strings.Contains(m.Content, "(Stopped)") { + req := CBReq{ + MessageId: m.ID, + ReferenceId: referenceId, + Prompt: extractPrompt(m.Content), + Content: m.Content, + Progress: extractProgress(m.Content), + Status: Stopped} + b.service.Notify(req) + return + } + + b.addAttachment(m.ID, referenceId, m.Content, m.Attachments) + +} + +func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) { + progress := extractProgress(content) + var status TaskStatus + if progress == 100 { + status = Finished + } else { + status = Running + } + for _, attachment := range attachments { + if attachment.Width == 0 || attachment.Height == 0 { + continue + } + image := Image{ + URL: attachment.URL, + Height: attachment.Height, + ProxyURL: attachment.ProxyURL, + Width: attachment.Width, + Size: attachment.Size, + Filename: attachment.Filename, + Hash: extractHashFromFilename(attachment.Filename), + } + req := CBReq{ + MessageId: messageId, + ReferenceId: referenceId, + Image: image, + Prompt: extractPrompt(content), + Content: content, + Progress: progress, + Status: status, + } + b.service.Notify(req) + break // only get one image + } +} + +// extract prompt from string +func extractPrompt(input string) string { + pattern := `\*\*(.*?)\*\*` + re := regexp.MustCompile(pattern) + matches := re.FindStringSubmatch(input) + if len(matches) > 1 { + return strings.TrimSpace(matches[1]) + } + return "" +} + +func extractProgress(input string) int { + pattern := `\((\d+)\%\)` + re := regexp.MustCompile(pattern) + matches := re.FindStringSubmatch(input) + if len(matches) > 1 { + return utils.IntValue(matches[1], 0) + } + return 100 +} + +func extractHashFromFilename(filename string) string { + if !strings.HasSuffix(filename, ".png") { + return "" + } + + index := strings.LastIndex(filename, "_") + if index != -1 { + return filename[index+1 : len(filename)-4] + } + return "" +} diff --git a/api/service/mj/client.go b/api/service/mj/client.go new file mode 100644 index 00000000..844c3c6a --- /dev/null +++ b/api/service/mj/client.go @@ -0,0 +1,144 @@ +package mj + +import ( + "chatplus/core/types" + "fmt" + "github.com/imroc/req/v3" + "time" +) + +// MidJourney client + +type Client struct { + client *req.Client + config *types.MidJourneyConfig +} + +func NewClient(config *types.AppConfig) *Client { + client := req.C().SetTimeout(10 * time.Second) + // set proxy URL + if config.ProxyURL != "" { + client.SetProxyURL(config.ProxyURL) + } + return &Client{client: client, config: &config.MjConfig} +} + +func (c *Client) Imagine(prompt string) error { + interactionsReq := &InteractionsRequest{ + Type: 2, + ApplicationID: ApplicationID, + GuildID: c.config.GuildId, + ChannelID: c.config.ChanelId, + SessionID: SessionID, + Data: map[string]any{ + "version": "1118961510123847772", + "id": "938956540159881230", + "name": "imagine", + "type": "1", + "options": []map[string]any{ + { + "type": 3, + "name": "prompt", + "value": prompt, + }, + }, + "application_command": map[string]any{ + "id": "938956540159881230", + "application_id": ApplicationID, + "version": "1118961510123847772", + "default_permission": true, + "default_member_permissions": nil, + "type": 1, + "nsfw": false, + "name": "imagine", + "description": "Create images with Midjourney", + "dm_permission": true, + "options": []map[string]any{ + { + "type": 3, + "name": "prompt", + "description": "The prompt to imagine", + "required": true, + }, + }, + "attachments": []any{}, + }, + }, + } + + url := "https://discord.com/api/v9/interactions" + r, err := c.client.R().SetHeader("Authorization", c.config.UserToken). + SetHeader("Content-Type", "application/json"). + SetBody(interactionsReq). + Post(url) + + if err != nil || r.IsErrorState() { + return fmt.Errorf("error with http request: %w%v", err, r.Err) + } + + return nil +} + +// Upscale 放大指定的图片 +func (c *Client) Upscale(index int, messageId string, hash string) error { + flags := 0 + interactionsReq := &InteractionsRequest{ + Type: 3, + ApplicationID: ApplicationID, + GuildID: c.config.GuildId, + ChannelID: c.config.ChanelId, + MessageFlags: &flags, + MessageID: &messageId, + SessionID: SessionID, + Data: map[string]any{ + "component_type": 2, + "custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash), + }, + Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), + } + + url := "https://discord.com/api/v9/interactions" + var res InteractionsResult + r, err := c.client.R().SetHeader("Authorization", c.config.UserToken). + SetHeader("Content-Type", "application/json"). + SetBody(interactionsReq). + SetErrorResult(&res). + Post(url) + if err != nil || r.IsErrorState() { + return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message) + } + + return nil +} + +// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 +func (c *Client) Variation(index int, messageId string, hash string) error { + flags := 0 + interactionsReq := &InteractionsRequest{ + Type: 3, + ApplicationID: ApplicationID, + GuildID: c.config.GuildId, + ChannelID: c.config.ChanelId, + MessageFlags: &flags, + MessageID: &messageId, + SessionID: SessionID, + Data: map[string]any{ + "component_type": 2, + "custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash), + }, + Nonce: fmt.Sprintf("%d", time.Now().UnixNano()), + } + + url := "https://discord.com/api/v9/interactions" + var res InteractionsResult + r, err := c.client.R().SetHeader("Authorization", c.config.UserToken). + SetHeader("Content-Type", "application/json"). + SetBody(interactionsReq). + SetErrorResult(&res). + Post(url) + if err != nil || r.IsErrorState() { + return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message) + } + + return nil +} diff --git a/api/service/mj/service.go b/api/service/mj/service.go new file mode 100644 index 00000000..25c147bf --- /dev/null +++ b/api/service/mj/service.go @@ -0,0 +1,249 @@ +package mj + +import ( + "chatplus/core/types" + "chatplus/service/oss" + "chatplus/store" + "chatplus/store/model" + "chatplus/store/vo" + "chatplus/utils" + "context" + "encoding/base64" + "fmt" + "github.com/go-redis/redis/v8" + "gorm.io/gorm" + "time" +) + +// MJ 绘画服务 + +const RunningJobKey = "MidJourney_Running_Job" + +type Service struct { + client *Client + taskQueue *store.RedisQueue + redis *redis.Client + db *gorm.DB + uploadManager *oss.UploaderManager + Clients *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息 + ChatClients *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息 + proxyURL string +} + +func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service { + return &Service{ + redis: redisCli, + db: db, + taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli), + client: client, + uploadManager: manager, + Clients: types.NewLMap[string, *types.WsClient](), + ChatClients: types.NewLMap[string, *types.WsClient](), + proxyURL: config.ProxyURL, + } +} + +func (s *Service) Run() { + logger.Info("Starting MidJourney job consumer.") + ctx := context.Background() + for { + _, err := s.redis.Get(ctx, RunningJobKey).Result() + if err == nil { // 队列串行执行 + time.Sleep(time.Second * 3) + continue + } + var task types.MjTask + err = s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + logger.Infof("Consuming Task: %+v", task) + switch task.Type { + case types.TaskImage: + err = s.client.Imagine(task.Prompt) + break + case types.TaskUpscale: + err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash) + + break + case types.TaskVariation: + err = s.client.Variation(task.Index, task.MessageId, task.MessageHash) + } + if err != nil { + logger.Error("绘画任务执行失败:", err) + if task.RetryCount <= 5 { + s.taskQueue.RPush(task) + } + task.RetryCount += 1 + time.Sleep(time.Second * 3) + continue + } + + // 更新任务的执行状态 + s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) + // 锁定任务执行通道,直到任务超时(5分钟) + s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5) + } +} + +func (s *Service) PushTask(task types.MjTask) { + logger.Infof("add a new MidJourney Task: %+v", task) + s.taskQueue.RPush(task) +} + +func (s *Service) Notify(data CBReq) { + taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result() + if err != nil { // 过期任务,丢弃 + logger.Warn("任务已过期:", err) + return + } + + var task types.MjTask + err = utils.JsonDecode(taskString, &task) + if err != nil { // 非标准任务,丢弃 + logger.Warn("任务解析失败:", err) + return + } + + var job model.MidJourneyJob + res := s.db.Where("message_id = ?", data.MessageId).First(&job) + if res.Error == nil && data.Status == Finished { + logger.Warn("重复消息:", data.MessageId) + return + } + + if task.Src == types.TaskSrcImg { // 绘画任务 + var job model.MidJourneyJob + res := s.db.Where("id = ?", task.Id).First(&job) + if res.Error != nil { + logger.Warn("非法任务:", res.Error) + return + } + job.MessageId = data.MessageId + job.ReferenceId = data.ReferenceId + job.Progress = data.Progress + job.Prompt = data.Prompt + job.Hash = data.Image.Hash + + // 任务完成,将最终的图片下载下来 + if data.Progress == 100 { + imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL) + if err != nil { + logger.Error("error with download img: ", err.Error()) + return + } + job.ImgURL = imgURL + } else { + // 临时图片直接保存,访问的时候使用代理进行转发 + job.ImgURL = data.Image.URL + } + res = s.db.Updates(&job) + if res.Error != nil { + logger.Error("error with update job: ", res.Error) + return + } + + var jobVo vo.MidJourneyJob + err := utils.CopyObject(job, &jobVo) + if err == nil { + if data.Progress < 100 { + image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL) + if err == nil { + jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + } + } + + // 推送任务到前端 + client := s.Clients.Get(task.SessionId) + if client != nil { + utils.ReplyChunkMessage(client, jobVo) + } + } + + } else if task.Src == types.TaskSrcChat { // 聊天任务 + wsClient := s.ChatClients.Get(task.SessionId) + if data.Status == Finished { + if wsClient != nil && data.ReferenceId != "" { + content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt) + utils.ReplyMessage(wsClient, content) + } + // download image + imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL) + if err != nil { + logger.Error("error with download image: ", err) + if wsClient != nil && data.ReferenceId != "" { + content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error()) + utils.ReplyMessage(wsClient, content) + } + return + } + + tx := s.db.Begin() + data.Image.URL = imgURL + message := model.HistoryMessage{ + UserId: uint(task.UserId), + ChatId: task.ChatId, + RoleId: uint(task.RoleId), + Type: types.MjMsg, + Icon: task.Icon, + Content: utils.JsonEncode(data), + Tokens: 0, + UseContext: false, + } + res = tx.Create(&message) + if res.Error != nil { + logger.Error("error with update database: ", err) + return + } + + // save the job + job.UserId = task.UserId + job.Type = task.Type.String() + job.MessageId = data.MessageId + job.ReferenceId = data.ReferenceId + job.Prompt = data.Prompt + job.ImgURL = imgURL + job.Progress = data.Progress + job.Hash = data.Image.Hash + job.CreatedAt = time.Now() + res = tx.Create(&job) + if res.Error != nil { + logger.Error("error with update database: ", err) + tx.Rollback() + return + } + tx.Commit() + } + + if wsClient == nil { // 客户端断线,则丢弃 + logger.Errorf("Client is offline: %+v", data) + return + } + + if data.Status == Finished { + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd}) + // 本次绘画完毕,移除客户端 + s.ChatClients.Delete(task.SessionId) + } else { + // 使用代理临时转发图片 + if data.Image.URL != "" { + image, err := utils.DownloadImage(data.Image.URL, s.proxyURL) + if err == nil { + data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + } + } + utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data}) + } + } + + // 更新用户剩余绘图次数 + // TODO: 放大图片是否需要消耗绘图次数? + if data.Status == Finished { + s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1)) + // 解除任务锁定 + s.redis.Del(context.Background(), RunningJobKey) + } + +} diff --git a/api/service/mj/types.go b/api/service/mj/types.go new file mode 100644 index 00000000..7de8c0c1 --- /dev/null +++ b/api/service/mj/types.go @@ -0,0 +1,34 @@ +package mj + +const ( + ApplicationID string = "936929561302675456" + SessionID string = "ea8816d857ba9ae2f74c59ae1a953afe" +) + +type InteractionsRequest struct { + Type int `json:"type"` + ApplicationID string `json:"application_id"` + MessageFlags *int `json:"message_flags,omitempty"` + MessageID *string `json:"message_id,omitempty"` + GuildID string `json:"guild_id"` + ChannelID string `json:"channel_id"` + SessionID string `json:"session_id"` + Data map[string]any `json:"data"` + Nonce string `json:"nonce,omitempty"` +} + +type InteractionsResult struct { + Code int `json:"code"` + Message string + Error map[string]any +} + +type CBReq struct { + MessageId string `json:"message_id"` + ReferenceId string `json:"reference_id"` + Image Image `json:"image"` + Content string `json:"content"` + Prompt string `json:"prompt"` + Status TaskStatus `json:"status"` + Progress int `json:"progress"` +} diff --git a/api/service/mj_service.go b/api/service/mj_service.go deleted file mode 100644 index 393e8e91..00000000 --- a/api/service/mj_service.go +++ /dev/null @@ -1,166 +0,0 @@ -package service - -import ( - "chatplus/core/types" - logger2 "chatplus/logger" - "chatplus/store" - "chatplus/store/model" - "chatplus/utils" - "context" - "errors" - "fmt" - "github.com/go-redis/redis/v8" - "github.com/imroc/req/v3" - "gorm.io/gorm" - "time" -) - -var logger = logger2.GetLogger() - -// MJ 绘画服务 - -const MjRunningJobKey = "MidJourney_Running_Job" - -type MjService struct { - config types.ChatPlusExtConfig - client *req.Client - taskQueue *store.RedisQueue - redis *redis.Client - db *gorm.DB -} - -func NewMjService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *MjService { - return &MjService{ - config: appConfig.ExtConfig, - redis: client, - db: db, - taskQueue: store.NewRedisQueue("midjourney_task_queue", client), - client: req.C().SetTimeout(30 * time.Second)} -} - -func (s *MjService) Run() { - logger.Info("Starting MidJourney job consumer.") - ctx := context.Background() - for { - _, err := s.redis.Get(ctx, MjRunningJobKey).Result() - if err == nil { // 队列串行执行 - time.Sleep(time.Second * 3) - continue - } - var task types.MjTask - err = s.taskQueue.LPop(&task) - if err != nil { - logger.Errorf("taking task with error: %v", err) - continue - } - logger.Infof("Consuming Task: %+v", task) - switch task.Type { - case types.TaskImage: - err = s.image(task.Prompt) - break - case types.TaskUpscale: - err = s.upscale(MjUpscaleReq{ - Index: task.Index, - MessageId: task.MessageId, - MessageHash: task.MessageHash, - }) - break - case types.TaskVariation: - err = s.variation(MjVariationReq{ - Index: task.Index, - MessageId: task.MessageId, - MessageHash: task.MessageHash, - }) - } - if err != nil { - logger.Error("绘画任务执行失败:", err) - if task.RetryCount <= 5 { - s.taskQueue.RPush(task) - } - task.RetryCount += 1 - time.Sleep(time.Second * 3) - continue - } - - // 更新任务的执行状态 - s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) - // 锁定任务执行通道,直到任务超时(5分钟) - s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5) - } -} - -func (s *MjService) PushTask(task types.MjTask) { - logger.Infof("add a new MidJourney Task: %+v", task) - s.taskQueue.RPush(task) -} - -func (s *MjService) image(prompt string) error { - logger.Infof("MJ 绘画参数:%+v", prompt) - body := map[string]string{"prompt": prompt} - url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL) - var res types.BizVo - r, err := s.client.R(). - SetHeader("Authorization", s.config.Token). - SetHeader("Content-Type", "application/json"). - SetBody(body). - SetSuccessResult(&res).Post(url) - if err != nil || r.IsErrorState() { - return fmt.Errorf("%v%v", r.String(), err) - } - - if res.Code != types.Success { - return errors.New(res.Message) - } - - return nil -} - -type MjUpscaleReq struct { - Index int32 `json:"index"` - MessageId string `json:"message_id"` - MessageHash string `json:"message_hash"` -} - -func (s *MjService) upscale(upReq MjUpscaleReq) error { - url := fmt.Sprintf("%s/api/mj/upscale", s.config.ApiURL) - var res types.BizVo - r, err := s.client.R(). - SetHeader("Authorization", s.config.Token). - SetHeader("Content-Type", "application/json"). - SetBody(upReq). - SetSuccessResult(&res).Post(url) - if err != nil || r.IsErrorState() { - return fmt.Errorf("%v%v", r.String(), err) - } - - if res.Code != types.Success { - return errors.New(res.Message) - } - - return nil -} - -type MjVariationReq struct { - Index int32 `json:"index"` - MessageId string `json:"message_id"` - MessageHash string `json:"message_hash"` -} - -func (s *MjService) variation(upReq MjVariationReq) error { - url := fmt.Sprintf("%s/api/mj/variation", s.config.ApiURL) - var res types.BizVo - r, err := s.client.R(). - SetHeader("Authorization", s.config.Token). - SetHeader("Content-Type", "application/json"). - SetBody(upReq). - SetSuccessResult(&res).Post(url) - if err != nil || r.IsErrorState() { - return fmt.Errorf("%v%v", r.String(), err) - } - - if res.Code != types.Success { - return errors.New(res.Message) - } - - return nil -} diff --git a/api/service/sd/client.go b/api/service/sd/client.go new file mode 100644 index 00000000..c2abe021 --- /dev/null +++ b/api/service/sd/client.go @@ -0,0 +1,169 @@ +package sd + +import ( + "chatplus/core/types" + "chatplus/utils" + "fmt" + "github.com/imroc/req/v3" + "io" + "time" +) + +type Client struct { + httpClient *req.Client + config *types.StableDiffusionConfig +} + +func NewSdClient(config *types.AppConfig) *Client { + return &Client{ + config: &config.SdConfig, + httpClient: req.C(), + } +} + +func (c *Client) Txt2Img(params types.SdTaskParams) error { + var data []interface{} + err := utils.JsonDecode(Text2ImgParamTemplate, &data) + if err != nil { + return err + } + data[ParamKeys["task_id"]] = params.TaskId + data[ParamKeys["prompt"]] = params.Prompt + data[ParamKeys["negative_prompt"]] = params.NegativePrompt + data[ParamKeys["steps"]] = params.Steps + data[ParamKeys["sampler"]] = params.Sampler + data[ParamKeys["face_fix"]] = params.FaceFix + data[ParamKeys["cfg_scale"]] = params.CfgScale + data[ParamKeys["seed"]] = params.Seed + data[ParamKeys["height"]] = params.Height + data[ParamKeys["width"]] = params.Width + data[ParamKeys["hd_fix"]] = params.HdFix + data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate + data[ParamKeys["hd_scale"]] = params.HdScale + data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg + data[ParamKeys["hd_sample_num"]] = params.HdSampleNum + task := TaskInfo{ + TaskId: params.TaskId, + Data: data, + EventData: nil, + FnIndex: 494, + SessionHash: "ycaxgzm9ah", + } + + go func() { + c.runTask(task, c.httpClient) + }() + return nil +} + +func (c *Client) runTask(taskInfo TaskInfo, client *req.Client) { + body := map[string]any{ + "data": taskInfo.Data, + "event_data": taskInfo.EventData, + "fn_index": taskInfo.FnIndex, + "session_hash": taskInfo.SessionHash, + } + + var result = make(chan CBReq) + go func() { + var res struct { + Data []interface{} `json:"data"` + IsGenerating bool `json:"is_generating"` + Duration float64 `json:"duration"` + AverageDuration float64 `json:"average_duration"` + } + var cbReq = CBReq{TaskId: taskInfo.TaskId} + response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(c.config.ApiURL + "/run/predict") + if err != nil { + cbReq.Message = "error with send request: " + err.Error() + cbReq.Success = false + result <- cbReq + return + } + + if response.IsErrorState() { + bytes, _ := io.ReadAll(response.Body) + cbReq.Message = "error http status code: " + string(bytes) + cbReq.Success = false + result <- cbReq + return + } + + var images []struct { + Name string `json:"name"` + Data interface{} `json:"data"` + IsFile bool `json:"is_file"` + } + err = utils.ForceCovert(res.Data[0], &images) + if err != nil { + cbReq.Message = "error with decode image:" + err.Error() + cbReq.Success = false + result <- cbReq + return + } + + var info map[string]any + err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info) + if err != nil { + cbReq.Message = err.Error() + cbReq.Success = false + result <- cbReq + return + } + + //for k, v := range info { + // fmt.Println(k, " => ", v) + //} + cbReq.ImageName = images[0].Name + cbReq.Seed = utils.InterfaceToString(info["seed"]) + cbReq.Success = true + cbReq.Progress = 100 + result <- cbReq + close(result) + + }() + + for { + select { + case value := <-result: + if value.Success { + logger.Infof("%s/file=%s", c.config.ApiURL, value.ImageName) + } + return + default: + var progressReq = map[string]any{ + "id_task": taskInfo.TaskId, + "id_live_preview": 1, + } + + var progressRes struct { + Active bool `json:"active"` + Queued bool `json:"queued"` + Completed bool `json:"completed"` + Progress float64 `json:"progress"` + Eta float64 `json:"eta"` + LivePreview string `json:"live_preview"` + IDLivePreview int `json:"id_live_preview"` + TextInfo interface{} `json:"textinfo"` + } + response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(c.config.ApiURL + "/internal/progress") + var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true} + if err != nil { // TODO: 这里可以考虑设置失败重试次数 + logger.Error(err) + return + } + + if response.IsErrorState() { + bytes, _ := io.ReadAll(response.Body) + logger.Error(string(bytes)) + return + } + + cbReq.ImageData = progressRes.LivePreview + cbReq.Progress = int(progressRes.Progress * 100) + fmt.Println("Progress: ", progressRes.Progress) + fmt.Println("Image: ", progressRes.LivePreview) + time.Sleep(time.Second) + } + } +} diff --git a/api/service/sd_service.go b/api/service/sd/sd_service.go similarity index 50% rename from api/service/sd_service.go rename to api/service/sd/sd_service.go index a29b42ed..da1cc980 100644 --- a/api/service/sd_service.go +++ b/api/service/sd/sd_service.go @@ -1,45 +1,42 @@ -package service +package sd import ( "chatplus/core/types" + "chatplus/service/mj" "chatplus/store" "chatplus/store/model" "chatplus/utils" "context" - "errors" - "fmt" "github.com/go-redis/redis/v8" - "github.com/imroc/req/v3" "gorm.io/gorm" "time" ) // SD 绘画服务 -const SdRunningJobKey = "StableDiffusion_Running_Job" +const RunningJobKey = "StableDiffusion_Running_Job" -type SdService struct { - config types.ChatPlusExtConfig - client *req.Client +type Service struct { taskQueue *store.RedisQueue redis *redis.Client db *gorm.DB + Client *Client } -func NewSdService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *SdService { - return &SdService{ - config: appConfig.ExtConfig, - redis: client, +func NewService(redisCli *redis.Client, db *gorm.DB, client *Client) *Service { + return &Service{ + redis: redisCli, db: db, - taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", client), - client: req.C().SetTimeout(30 * time.Second)} + Client: client, + taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli), + } } -func (s *SdService) Run() { +func (s *Service) Run() { logger.Info("Starting StableDiffusion job consumer.") ctx := context.Background() for { - _, err := s.redis.Get(ctx, SdRunningJobKey).Result() + _, err := s.redis.Get(ctx, RunningJobKey).Result() if err == nil { // 队列串行执行 time.Sleep(time.Second * 3) continue @@ -51,7 +48,7 @@ func (s *SdService) Run() { continue } logger.Infof("Consuming Task: %+v", task) - err = s.txt2img(task.Params) + err = s.Client.Txt2Img(task.Params) if err != nil { logger.Error("绘画任务执行失败:", err) if task.RetryCount <= 5 { @@ -65,31 +62,11 @@ func (s *SdService) Run() { // 更新任务的执行状态 s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true) // 锁定任务执行通道,直到任务超时(5分钟) - s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5) + s.redis.Set(ctx, mj.RunningJobKey, utils.JsonEncode(task), time.Minute*5) } } -func (s *SdService) PushTask(task types.SdTask) { +func (s *Service) PushTask(task types.SdTask) { logger.Infof("add a new MidJourney Task: %+v", task) s.taskQueue.RPush(task) } - -func (s *SdService) txt2img(params types.SdParams) error { - logger.Infof("SD 绘画参数:%+v", params) - url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL) - var res types.BizVo - r, err := s.client.R(). - SetHeader("Authorization", s.config.Token). - SetHeader("Content-Type", "application/json"). - SetBody(params). - SetSuccessResult(&res).Post(url) - if err != nil || r.IsErrorState() { - return fmt.Errorf("%v%v", r.String(), err) - } - - if res.Code != types.Success { - return errors.New(res.Message) - } - - return nil -} diff --git a/api/service/sd/types.go b/api/service/sd/types.go new file mode 100644 index 00000000..b0942b89 --- /dev/null +++ b/api/service/sd/types.go @@ -0,0 +1,234 @@ +package sd + +import logger2 "chatplus/logger" + +var logger = logger2.GetLogger() + +type TaskInfo struct { + TaskId string `json:"task_id"` + Data interface{} `json:"data"` + EventData interface{} `json:"event_data"` + FnIndex int `json:"fn_index"` + SessionHash string `json:"session_hash"` +} + +type CBReq struct { + TaskId string + ImageName string + ImageData string + Progress int + Seed string + Success bool + Message string +} + +var ParamKeys = map[string]int{ + "task_id": 0, + "prompt": 1, + "negative_prompt": 2, + "steps": 4, + "sampler": 5, + "face_fix": 6, + "cfg_scale": 10, + "seed": 11, + "height": 17, + "width": 18, + "hd_fix": 19, + "hd_redraw_rate": 20, //高清修复重绘幅度 + "hd_scale": 21, // 高清修复放大倍数 + "hd_scale_alg": 22, // 高清修复放大算法 + "hd_sample_num": 23, // 高清修复采样次数 +} + +const Text2ImgParamTemplate = `[ +"", +"", +"", +[], +30, +"DPM++ SDE Karras", +false, +false, +1, +1, +7.5, +-1, +-1, +0, +0, +0, +false, +512, +512, +true, +0.7, +2, +"Latent", +10, +0, +0, +"Use same sampler", +"", +"", +[], +"None", +false, +"MultiDiffusion", +false, +true, +1024, +1024, +96, +96, +48, +4, +"None", +2, +false, +10, +1, +1, +64, +false, +false, +false, +false, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +0.4, +0.4, +0.2, +0.2, +"", +"", +"Background", +0.2, +-1, +false, +3072, +192, +true, +true, +true, +false, +null, +null, +null, +false, +"", +0.5, +true, +false, +"", +"Lerp", +false, +"🔄", +false, +false, +false, +false, +false, +false, +false, +false, +false, +"positive", +"comma", +0, +false, +false, +"", +"Seed", +"", +[], +"Nothing", +"", +[], +"Nothing", +"", +[], +true, +false, +false, +false, +0, +null, +null, +false, +null, +null, +false, +null, +null, +false, +50 +]` diff --git a/api/service/wx/bot.go b/api/service/wx/bot.go new file mode 100644 index 00000000..a5564999 --- /dev/null +++ b/api/service/wx/bot.go @@ -0,0 +1,87 @@ +package wx + +import ( + logger2 "chatplus/logger" + "chatplus/store/model" + "github.com/eatmoreapple/openwechat" + "github.com/skip2/go-qrcode" + "gorm.io/gorm" +) + +// 微信收款机器人 +var logger = logger2.GetLogger() + +type Bot struct { + bot *openwechat.Bot + token string + db *gorm.DB +} + +func NewWeChatBot(db *gorm.DB) *Bot { + bot := openwechat.DefaultBot(openwechat.Desktop) + return &Bot{ + bot: bot, + db: db, + } +} + +func (b *Bot) Run() error { + logger.Info("Starting WeChat Bot...") + + // set message handler + b.bot.MessageHandler = func(msg *openwechat.Message) { + b.messageHandler(msg) + } + // scan code login callback + b.bot.UUIDCallback = b.qrCodeCallBack + + err := b.bot.Login() + if err != nil { + return err + } + + logger.Info("微信登录成功!") + return nil +} + +// message handler +func (b *Bot) messageHandler(msg *openwechat.Message) { + sender, err := msg.Sender() + if err != nil { + return + } + + // 只处理微信支付的推送消息 + if sender.NickName == "微信支付" || + msg.MsgType == openwechat.MsgTypeApp || + msg.AppMsgType == openwechat.AppMsgTypeUrl { + // 解析支付金额 + message, err := parseTransactionMessage(msg.Content) + if err == nil { + transaction := extractTransaction(message) + logger.Infof("解析到收款信息:%+v", transaction) + var item model.Reward + res := b.db.Where("tx_id = ?", transaction.TransId).First(&item) + if res.Error == nil { + logger.Error("当前交易 ID 己经存在!") + return + } + + res = b.db.Create(&model.Reward{ + TxId: transaction.TransId, + Amount: transaction.Amount, + Remark: transaction.Remark, + Status: false, + }) + if res.Error != nil { + logger.Errorf("交易保存失败: %v", res.Error) + } + } + } +} + +func (b *Bot) qrCodeCallBack(uuid string) { + logger.Info("请使用微信扫描下面二维码登录") + q, _ := qrcode.New("https://login.weixin.qq.com/l/"+uuid, qrcode.Medium) + logger.Info(q.ToString(true)) +} diff --git a/api/service/wx/tranaction.go b/api/service/wx/tranaction.go new file mode 100644 index 00000000..ee06a9e3 --- /dev/null +++ b/api/service/wx/tranaction.go @@ -0,0 +1,68 @@ +package wx + +import ( + "encoding/xml" + "strconv" + "strings" +) + +// Message 转账消息 +type Message struct { + XMLName xml.Name `xml:"msg"` + AppMsg struct { + Des string `xml:"des"` + Url string `xml:"url"` + } `xml:"appmsg"` +} + +// Transaction 解析后的交易信息 +type Transaction struct { + TransId string `json:"trans_id"` // 微信转账交易 ID + Amount float64 `json:"amount"` // 微信转账交易金额 + Remark string `json:"remark"` // 转账备注 +} + +// 解析微信转账消息 +func parseTransactionMessage(xmlData string) (*Message, error) { + var msg Message + if err := xml.Unmarshal([]byte(xmlData), &msg); err != nil { + return nil, err + } + + return &msg, nil +} + +// 导出交易信息 +func extractTransaction(message *Message) Transaction { + var tx = Transaction{} + // 导出交易金额和备注 + lines := strings.Split(message.AppMsg.Des, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if len(line) == 0 { + continue + } + // 解析收款金额 + prefix := "收款金额¥" + if strings.HasPrefix(line, prefix) { + if value, err := strconv.ParseFloat(line[len(prefix):], 64); err == nil { + tx.Amount = value + continue + } + } + // 解析收款备注 + prefix = "付款方备注" + if strings.HasPrefix(line, prefix) { + tx.Remark = line[len(prefix):] + break + } + } + + // 解析交易 ID + index := strings.Index(message.AppMsg.Url, "trans_id=") + if index != -1 { + end := strings.LastIndex(message.AppMsg.Url, "&") + tx.TransId = strings.TrimSpace(message.AppMsg.Url[index+9 : end]) + } + return tx +} diff --git a/api/store/vo/sd_job.go b/api/store/vo/sd_job.go index b91cad69..c4ae2308 100644 --- a/api/store/vo/sd_job.go +++ b/api/store/vo/sd_job.go @@ -6,14 +6,14 @@ import ( ) type SdJob struct { - Id uint `json:"id"` - Type string `json:"type"` - UserId int `json:"user_id"` - TaskId string `json:"task_id"` - ImgURL string `json:"img_url"` - Params types.SdParams `json:"params"` - Progress int `json:"progress"` - Prompt string `json:"prompt"` - CreatedAt time.Time `json:"created_at"` - Started bool `json:"started"` + Id uint `json:"id"` + Type string `json:"type"` + UserId int `json:"user_id"` + TaskId string `json:"task_id"` + ImgURL string `json:"img_url"` + Params types.SdTaskParams `json:"params"` + Progress int `json:"progress"` + Prompt string `json:"prompt"` + CreatedAt time.Time `json:"created_at"` + Started bool `json:"started"` } diff --git a/api/utils/common.go b/api/utils/common.go index c1d7c64a..01fdd32c 100644 --- a/api/utils/common.go +++ b/api/utils/common.go @@ -138,3 +138,15 @@ func IntValue(str string, defaultValue int) int { } return value } + +func ForceCovert(src any, dst interface{}) error { + bytes, err := json.Marshal(src) + if err != nil { + return err + } + err = json.Unmarshal(bytes, dst) + if err != nil { + return err + } + return nil +}