mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
feat: migrate the chatgpt-plus-ext project code to this project
This commit is contained in:
parent
d51a724ade
commit
75c5ebbffa
@ -2,7 +2,7 @@ package core
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/service/function"
|
"chatplus/service/fun"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
@ -33,11 +33,10 @@ type AppServer struct {
|
|||||||
ChatSession *types.LMap[string, *types.ChatSession] //map[sessionId]UserId
|
ChatSession *types.LMap[string, *types.ChatSession] //map[sessionId]UserId
|
||||||
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
|
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
|
||||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||||
Functions map[string]function.Function
|
Functions map[string]fun.Function
|
||||||
MjTaskClients *types.LMap[string, *types.WsClient]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer {
|
func NewServer(appConfig *types.AppConfig, functions map[string]fun.Function) *AppServer {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
gin.DefaultWriter = io.Discard
|
gin.DefaultWriter = io.Discard
|
||||||
return &AppServer{
|
return &AppServer{
|
||||||
@ -48,7 +47,6 @@ func NewServer(appConfig *types.AppConfig, functions map[string]function.Functio
|
|||||||
ChatSession: types.NewLMap[string, *types.ChatSession](),
|
ChatSession: types.NewLMap[string, *types.ChatSession](),
|
||||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||||
MjTaskClients: types.NewLMap[string, *types.WsClient](),
|
|
||||||
Functions: functions,
|
Functions: functions,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,6 @@ func NewDefaultConfig() *types.AppConfig {
|
|||||||
MaxAge: 86400,
|
MaxAge: 86400,
|
||||||
},
|
},
|
||||||
ApiConfig: types.ChatPlusApiConfig{},
|
ApiConfig: types.ChatPlusApiConfig{},
|
||||||
ExtConfig: types.ChatPlusExtConfig{Token: utils.RandString(32)},
|
|
||||||
OSS: types.OSSConfig{
|
OSS: types.OSSConfig{
|
||||||
Active: "local",
|
Active: "local",
|
||||||
Local: types.LocalStorageConfig{
|
Local: types.LocalStorageConfig{
|
||||||
@ -34,6 +33,9 @@ func NewDefaultConfig() *types.AppConfig {
|
|||||||
BasePath: "./static/upload",
|
BasePath: "./static/upload",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
MjConfig: types.MidJourneyConfig{Enabled: false},
|
||||||
|
SdConfig: types.StableDiffusionConfig{Enabled: false},
|
||||||
|
WeChatBot: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,9 +17,10 @@ type AppConfig struct {
|
|||||||
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
|
||||||
AesEncryptKey string
|
AesEncryptKey string
|
||||||
SmsConfig AliYunSmsConfig // AliYun send message service config
|
SmsConfig AliYunSmsConfig // AliYun send message service config
|
||||||
ExtConfig ChatPlusExtConfig // ChatPlus extensions callback api config
|
|
||||||
|
|
||||||
OSS OSSConfig // OSS config
|
OSS OSSConfig // OSS config
|
||||||
|
MjConfig MidJourneyConfig // mj 绘画配置
|
||||||
|
WeChatBot bool // 是否启用微信机器人
|
||||||
|
SdConfig StableDiffusionConfig // sd 绘画配置
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatPlusApiConfig struct {
|
type ChatPlusApiConfig struct {
|
||||||
@ -28,9 +29,22 @@ type ChatPlusApiConfig struct {
|
|||||||
Token string
|
Token string
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatPlusExtConfig struct {
|
type MidJourneyConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
UserToken string
|
||||||
|
BotToken string
|
||||||
|
GuildId string // Server ID
|
||||||
|
ChanelId string // Chanel ID
|
||||||
|
}
|
||||||
|
|
||||||
|
type WeChatConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type StableDiffusionConfig struct {
|
||||||
|
Enabled bool
|
||||||
ApiURL string
|
ApiURL string
|
||||||
Token string
|
ApiKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliYunSmsConfig struct {
|
type AliYunSmsConfig struct {
|
||||||
|
@ -33,14 +33,24 @@ type MjTask struct {
|
|||||||
ChatId string `json:"chat_id,omitempty"`
|
ChatId string `json:"chat_id,omitempty"`
|
||||||
RoleId int `json:"role_id,omitempty"`
|
RoleId int `json:"role_id,omitempty"`
|
||||||
Icon string `json:"icon,omitempty"`
|
Icon string `json:"icon,omitempty"`
|
||||||
Index int32 `json:"index,omitempty"`
|
Index int `json:"index,omitempty"`
|
||||||
MessageId string `json:"message_id,omitempty"`
|
MessageId string `json:"message_id,omitempty"`
|
||||||
MessageHash string `json:"message_hash,omitempty"`
|
MessageHash string `json:"message_hash,omitempty"`
|
||||||
RetryCount int `json:"retry_count"`
|
RetryCount int `json:"retry_count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SdParams stable diffusion 绘画参数
|
type SdTask struct {
|
||||||
type SdParams struct {
|
Id int `json:"id"`
|
||||||
|
SessionId string `json:"session_id"`
|
||||||
|
Src TaskSrc `json:"src"`
|
||||||
|
Type TaskType `json:"type"`
|
||||||
|
UserId int `json:"user_id"`
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Params SdTaskParams `json:"params"`
|
||||||
|
RetryCount int `json:"retry_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SdTaskParams struct {
|
||||||
TaskId string `json:"task_id"`
|
TaskId string `json:"task_id"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
NegativePrompt string `json:"negative_prompt"`
|
NegativePrompt string `json:"negative_prompt"`
|
||||||
@ -57,14 +67,3 @@ type SdParams struct {
|
|||||||
HdScaleAlg string `json:"hd_scale_alg"`
|
HdScaleAlg string `json:"hd_scale_alg"`
|
||||||
HdSampleNum int `json:"hd_sample_num"`
|
HdSampleNum int `json:"hd_sample_num"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SdTask struct {
|
|
||||||
Id int `json:"id"`
|
|
||||||
SessionId string `json:"session_id"`
|
|
||||||
Src types.TaskSrc `json:"src"`
|
|
||||||
Type types.TaskType `json:"type"`
|
|
||||||
UserId int `json:"user_id"`
|
|
||||||
Prompt string `json:"prompt,omitempty"`
|
|
||||||
Params types.SdParams `json:"params"`
|
|
||||||
RetryCount int `json:"retry_count"`
|
|
||||||
}
|
|
||||||
|
@ -5,6 +5,9 @@ go 1.19
|
|||||||
require (
|
require (
|
||||||
github.com/BurntSushi/toml v1.1.0
|
github.com/BurntSushi/toml v1.1.0
|
||||||
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
|
github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
|
||||||
|
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
|
||||||
|
github.com/bwmarrin/discordgo v0.27.1
|
||||||
|
github.com/eatmoreapple/openwechat v1.2.1
|
||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-gonic/gin v1.9.1
|
||||||
github.com/go-redis/redis/v8 v8.11.5
|
github.com/go-redis/redis/v8 v8.11.5
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0
|
github.com/golang-jwt/jwt/v5 v5.0.0
|
||||||
@ -14,6 +17,7 @@ require (
|
|||||||
github.com/minio/minio-go/v7 v7.0.62
|
github.com/minio/minio-go/v7 v7.0.62
|
||||||
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480
|
github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480
|
||||||
github.com/qiniu/go-sdk/v7 v7.17.1
|
github.com/qiniu/go-sdk/v7 v7.17.1
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
github.com/syndtr/goleveldb v1.0.0
|
github.com/syndtr/goleveldb v1.0.0
|
||||||
go.uber.org/zap v1.23.0
|
go.uber.org/zap v1.23.0
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||||
@ -21,7 +25,6 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible // indirect
|
|
||||||
github.com/andybalholm/brotli v1.0.4 // indirect
|
github.com/andybalholm/brotli v1.0.4 // indirect
|
||||||
github.com/bytedance/sonic v1.9.1 // indirect
|
github.com/bytedance/sonic v1.9.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||||
|
@ -7,6 +7,8 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9
|
|||||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
||||||
|
github.com/bwmarrin/discordgo v0.27.1 h1:ib9AIc/dom1E/fSIulrBwnez0CToJE113ZGt4HoliGY=
|
||||||
|
github.com/bwmarrin/discordgo v0.27.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
||||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||||
@ -25,6 +27,8 @@ github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0
|
|||||||
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
|
github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU=
|
||||||
|
github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
|
||||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||||
@ -75,6 +79,7 @@ github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9S
|
|||||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
|
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
|
||||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
@ -168,6 +173,8 @@ github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
|||||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
@ -209,6 +216,7 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
|||||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
|
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
||||||
|
@ -150,7 +150,7 @@ func (h *ChatHandler) sendAzureMessage(
|
|||||||
content := data
|
content := data
|
||||||
if functionName == types.FuncMidJourney {
|
if functionName == types.FuncMidJourney {
|
||||||
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
|
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
|
||||||
h.App.MjTaskClients.Put(session.SessionId, ws)
|
h.mjService.ChatClients.Put(session.SessionId, ws)
|
||||||
// update user's img_calls
|
// update user's img_calls
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"chatplus/core"
|
"chatplus/core"
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
|
"chatplus/service/mj"
|
||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
@ -30,10 +31,11 @@ type ChatHandler struct {
|
|||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
leveldb *store.LevelDB
|
leveldb *store.LevelDB
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
|
mjService *mj.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client) *ChatHandler {
|
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client, service *mj.Service) *ChatHandler {
|
||||||
handler := ChatHandler{db: db, leveldb: levelDB, redis: redis}
|
handler := ChatHandler{db: db, leveldb: levelDB, redis: redis, mjService: service}
|
||||||
handler.App = app
|
handler.App = app
|
||||||
return &handler
|
return &handler
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"chatplus/core"
|
"chatplus/core"
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/service"
|
"chatplus/service/mj"
|
||||||
"chatplus/service/oss"
|
"chatplus/service/oss"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
@ -21,28 +21,11 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TaskStatus string
|
|
||||||
|
|
||||||
const (
|
|
||||||
Stopped = TaskStatus("Stopped")
|
|
||||||
Finished = TaskStatus("Finished")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Image struct {
|
|
||||||
URL string `json:"url"`
|
|
||||||
ProxyURL string `json:"proxy_url"`
|
|
||||||
Filename string `json:"filename"`
|
|
||||||
Width int `json:"width"`
|
|
||||||
Height int `json:"height"`
|
|
||||||
Size int `json:"size"`
|
|
||||||
Hash string `json:"hash"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type MidJourneyHandler struct {
|
type MidJourneyHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
mjService *service.MjService
|
mjService *mj.Service
|
||||||
uploaderManager *oss.UploaderManager
|
uploaderManager *oss.UploaderManager
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
clients *types.LMap[string, *types.WsClient]
|
clients *types.LMap[string, *types.WsClient]
|
||||||
@ -53,7 +36,7 @@ func NewMidJourneyHandler(
|
|||||||
client *redis.Client,
|
client *redis.Client,
|
||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
manager *oss.UploaderManager,
|
manager *oss.UploaderManager,
|
||||||
mjService *service.MjService) *MidJourneyHandler {
|
mjService *mj.Service) *MidJourneyHandler {
|
||||||
h := MidJourneyHandler{
|
h := MidJourneyHandler{
|
||||||
redis: client,
|
redis: client,
|
||||||
db: db,
|
db: db,
|
||||||
@ -66,16 +49,6 @@ func NewMidJourneyHandler(
|
|||||||
return &h
|
return &h
|
||||||
}
|
}
|
||||||
|
|
||||||
type mjNotifyData struct {
|
|
||||||
MessageId string `json:"message_id"`
|
|
||||||
ReferenceId string `json:"reference_id"`
|
|
||||||
Image Image `json:"image"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Status TaskStatus `json:"status"`
|
|
||||||
Progress int `json:"progress"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client WebSocket 客户端,用于通知任务状态变更
|
// Client WebSocket 客户端,用于通知任务状态变更
|
||||||
func (h *MidJourneyHandler) Client(c *gin.Context) {
|
func (h *MidJourneyHandler) Client(c *gin.Context) {
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||||
@ -92,189 +65,6 @@ func (h *MidJourneyHandler) Client(c *gin.Context) {
|
|||||||
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
|
||||||
token := c.GetHeader("Authorization")
|
|
||||||
if token != h.App.Config.ExtConfig.Token {
|
|
||||||
resp.NotAuth(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var data mjNotifyData
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Debugf("收到 MidJourney 回调请求:%+v", data)
|
|
||||||
|
|
||||||
h.lock.Lock()
|
|
||||||
defer h.lock.Unlock()
|
|
||||||
|
|
||||||
err, finished := h.notifyHandler(c, data)
|
|
||||||
if err != nil {
|
|
||||||
resp.ERROR(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解除任务锁定
|
|
||||||
if finished && (data.Status == Finished || data.Status == Stopped) {
|
|
||||||
h.redis.Del(c, service.MjRunningJobKey)
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *MidJourneyHandler) notifyHandler(c *gin.Context, data mjNotifyData) (error, bool) {
|
|
||||||
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
|
|
||||||
if err != nil { // 过期任务,丢弃
|
|
||||||
logger.Warn("任务已过期:", err)
|
|
||||||
return nil, true
|
|
||||||
}
|
|
||||||
|
|
||||||
var task types.MjTask
|
|
||||||
err = utils.JsonDecode(taskString, &task)
|
|
||||||
if err != nil { // 非标准任务,丢弃
|
|
||||||
logger.Warn("任务解析失败:", err)
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
var job model.MidJourneyJob
|
|
||||||
res := h.db.Where("message_id = ?", data.MessageId).First(&job)
|
|
||||||
if res.Error == nil && data.Status == Finished {
|
|
||||||
logger.Warn("重复消息:", data.MessageId)
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if task.Src == types.TaskSrcImg { // 绘画任务
|
|
||||||
var job model.MidJourneyJob
|
|
||||||
res := h.db.Where("id = ?", task.Id).First(&job)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Warn("非法任务:", res.Error)
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
job.MessageId = data.MessageId
|
|
||||||
job.ReferenceId = data.ReferenceId
|
|
||||||
job.Progress = data.Progress
|
|
||||||
job.Prompt = data.Prompt
|
|
||||||
job.Hash = data.Image.Hash
|
|
||||||
|
|
||||||
// 任务完成,将最终的图片下载下来
|
|
||||||
if data.Progress == 100 {
|
|
||||||
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download img: ", err.Error())
|
|
||||||
return err, false
|
|
||||||
}
|
|
||||||
job.ImgURL = imgURL
|
|
||||||
} else {
|
|
||||||
// 临时图片直接保存,访问的时候使用代理进行转发
|
|
||||||
job.ImgURL = data.Image.URL
|
|
||||||
}
|
|
||||||
res = h.db.Updates(&job)
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Error("error with update job: ", res.Error)
|
|
||||||
return res.Error, false
|
|
||||||
}
|
|
||||||
|
|
||||||
var jobVo vo.MidJourneyJob
|
|
||||||
err := utils.CopyObject(job, &jobVo)
|
|
||||||
if err == nil {
|
|
||||||
if data.Progress < 100 {
|
|
||||||
image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
|
|
||||||
if err == nil {
|
|
||||||
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 推送任务到前端
|
|
||||||
client := h.clients.Get(task.SessionId)
|
|
||||||
if client != nil {
|
|
||||||
utils.ReplyChunkMessage(client, jobVo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if task.Src == types.TaskSrcChat { // 聊天任务
|
|
||||||
wsClient := h.App.MjTaskClients.Get(task.SessionId)
|
|
||||||
if data.Status == Finished {
|
|
||||||
if wsClient != nil && data.ReferenceId != "" {
|
|
||||||
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
|
|
||||||
utils.ReplyMessage(wsClient, content)
|
|
||||||
}
|
|
||||||
// download image
|
|
||||||
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("error with download image: ", err)
|
|
||||||
if wsClient != nil && data.ReferenceId != "" {
|
|
||||||
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
|
|
||||||
utils.ReplyMessage(wsClient, content)
|
|
||||||
}
|
|
||||||
return err, false
|
|
||||||
}
|
|
||||||
|
|
||||||
tx := h.db.Begin()
|
|
||||||
data.Image.URL = imgURL
|
|
||||||
message := model.HistoryMessage{
|
|
||||||
UserId: uint(task.UserId),
|
|
||||||
ChatId: task.ChatId,
|
|
||||||
RoleId: uint(task.RoleId),
|
|
||||||
Type: types.MjMsg,
|
|
||||||
Icon: task.Icon,
|
|
||||||
Content: utils.JsonEncode(data),
|
|
||||||
Tokens: 0,
|
|
||||||
UseContext: false,
|
|
||||||
}
|
|
||||||
res = tx.Create(&message)
|
|
||||||
if res.Error != nil {
|
|
||||||
return res.Error, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// save the job
|
|
||||||
job.UserId = task.UserId
|
|
||||||
job.Type = task.Type.String()
|
|
||||||
job.MessageId = data.MessageId
|
|
||||||
job.ReferenceId = data.ReferenceId
|
|
||||||
job.Prompt = data.Prompt
|
|
||||||
job.ImgURL = imgURL
|
|
||||||
job.Progress = data.Progress
|
|
||||||
job.Hash = data.Image.Hash
|
|
||||||
job.CreatedAt = time.Now()
|
|
||||||
res = tx.Create(&job)
|
|
||||||
if res.Error != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
return res.Error, false
|
|
||||||
}
|
|
||||||
tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
if wsClient == nil { // 客户端断线,则丢弃
|
|
||||||
logger.Errorf("Client is offline: %+v", data)
|
|
||||||
return nil, true
|
|
||||||
}
|
|
||||||
|
|
||||||
if data.Status == Finished {
|
|
||||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
|
||||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
|
|
||||||
// 本次绘画完毕,移除客户端
|
|
||||||
h.App.MjTaskClients.Delete(task.SessionId)
|
|
||||||
} else {
|
|
||||||
// 使用代理临时转发图片
|
|
||||||
if data.Image.URL != "" {
|
|
||||||
image, err := utils.DownloadImage(data.Image.URL, h.App.Config.ProxyURL)
|
|
||||||
if err == nil {
|
|
||||||
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新用户剩余绘图次数
|
|
||||||
// TODO: 放大图片是否需要消耗绘图次数?
|
|
||||||
if data.Status == Finished {
|
|
||||||
h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
user, err := utils.GetLoginUser(c, h.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -376,7 +166,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) {
|
|||||||
|
|
||||||
type reqVo struct {
|
type reqVo struct {
|
||||||
Src string `json:"src"`
|
Src string `json:"src"`
|
||||||
Index int32 `json:"index"`
|
Index int `json:"index"`
|
||||||
MessageId string `json:"message_id"`
|
MessageId string `json:"message_id"`
|
||||||
MessageHash string `json:"message_hash"`
|
MessageHash string `json:"message_hash"`
|
||||||
SessionId string `json:"session_id"`
|
SessionId string `json:"session_id"`
|
||||||
@ -443,15 +233,16 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
|
|||||||
MessageHash: data.MessageHash,
|
MessageHash: data.MessageHash,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if src == types.TaskSrcChat {
|
||||||
wsClient := h.App.ChatClients.Get(data.SessionId)
|
wsClient := h.App.ChatClients.Get(data.SessionId)
|
||||||
if wsClient != nil {
|
if wsClient != nil {
|
||||||
content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
|
content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
|
||||||
utils.ReplyMessage(wsClient, content)
|
utils.ReplyMessage(wsClient, content)
|
||||||
if h.App.MjTaskClients.Get(data.SessionId) == nil {
|
if h.mjService.ChatClients.Get(data.SessionId) == nil {
|
||||||
h.App.MjTaskClients.Put(data.SessionId, wsClient)
|
h.mjService.ChatClients.Put(data.SessionId, wsClient)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -513,13 +304,15 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
|
|||||||
MessageHash: data.MessageHash,
|
MessageHash: data.MessageHash,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if src == types.TaskSrcChat {
|
||||||
// 从聊天窗口发送的请求,记录客户端信息
|
// 从聊天窗口发送的请求,记录客户端信息
|
||||||
wsClient := h.App.ChatClients.Get(data.SessionId)
|
wsClient := h.mjService.ChatClients.Get(data.SessionId)
|
||||||
if wsClient != nil {
|
if wsClient != nil {
|
||||||
content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
|
content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
|
||||||
utils.ReplyMessage(wsClient, content)
|
utils.ReplyMessage(wsClient, content)
|
||||||
if h.App.MjTaskClients.Get(data.SessionId) == nil {
|
if h.mjService.Clients.Get(data.SessionId) == nil {
|
||||||
h.App.MjTaskClients.Put(data.SessionId, wsClient)
|
h.mjService.Clients.Put(data.SessionId, wsClient)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
|
@ -150,7 +150,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
content := data
|
content := data
|
||||||
if functionName == types.FuncMidJourney {
|
if functionName == types.FuncMidJourney {
|
||||||
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
|
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
|
||||||
h.App.MjTaskClients.Put(session.SessionId, ws)
|
h.mjService.ChatClients.Put(session.SessionId, ws)
|
||||||
// update user's img_calls
|
// update user's img_calls
|
||||||
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||||
}
|
}
|
||||||
|
@ -22,50 +22,6 @@ func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
|
|||||||
return &h
|
return &h
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RewardHandler) Notify(c *gin.Context) {
|
|
||||||
token := c.GetHeader("Authorization")
|
|
||||||
if token != h.App.Config.ExtConfig.Token {
|
|
||||||
resp.NotAuth(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var data struct {
|
|
||||||
TransId string `json:"trans_id"` // 微信转账交易 ID
|
|
||||||
Amount float64 `json:"amount"` // 微信转账交易金额
|
|
||||||
Remark string `json:"remark"` // 转账备注
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if data.Amount <= 0 {
|
|
||||||
resp.ERROR(c, "Amount should not be 0")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infof("收到众筹收款信息: %+v", data)
|
|
||||||
var item model.Reward
|
|
||||||
res := h.db.Where("tx_id = ?", data.TransId).First(&item)
|
|
||||||
if res.Error == nil {
|
|
||||||
resp.ERROR(c, "当前交易 ID 己经存在!")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
res = h.db.Create(&model.Reward{
|
|
||||||
TxId: data.TransId,
|
|
||||||
Amount: data.Amount,
|
|
||||||
Remark: data.Remark,
|
|
||||||
Status: false,
|
|
||||||
})
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.Errorf("交易保存失败: %v", res.Error)
|
|
||||||
resp.ERROR(c, "交易保存失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp.SUCCESS(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify 打赏码核销
|
// Verify 打赏码核销
|
||||||
func (h *RewardHandler) Verify(c *gin.Context) {
|
func (h *RewardHandler) Verify(c *gin.Context) {
|
||||||
var data struct {
|
var data struct {
|
||||||
|
@ -1,315 +1,316 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
//
|
||||||
"chatplus/core"
|
//import (
|
||||||
"chatplus/core/types"
|
// "chatplus/core"
|
||||||
"chatplus/service"
|
// "chatplus/core/types"
|
||||||
"chatplus/service/oss"
|
// "chatplus/service"
|
||||||
"chatplus/store/model"
|
// "chatplus/service/oss"
|
||||||
"chatplus/store/vo"
|
// "chatplus/store/model"
|
||||||
"chatplus/utils"
|
// "chatplus/store/vo"
|
||||||
"chatplus/utils/resp"
|
// "chatplus/utils"
|
||||||
"encoding/base64"
|
// "chatplus/utils/resp"
|
||||||
"fmt"
|
// "encoding/base64"
|
||||||
"github.com/gin-gonic/gin"
|
// "fmt"
|
||||||
"github.com/go-redis/redis/v8"
|
// "github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
// "github.com/go-redis/redis/v8"
|
||||||
"gorm.io/gorm"
|
// "github.com/gorilla/websocket"
|
||||||
"net/http"
|
// "gorm.io/gorm"
|
||||||
"strings"
|
// "net/http"
|
||||||
"sync"
|
// "strings"
|
||||||
"time"
|
// "sync"
|
||||||
)
|
// "time"
|
||||||
|
//)
|
||||||
type SdJobHandler struct {
|
//
|
||||||
BaseHandler
|
//type SdJobHandler struct {
|
||||||
redis *redis.Client
|
// BaseHandler
|
||||||
db *gorm.DB
|
// redis *redis.Client
|
||||||
mjService *service.MjService
|
// db *gorm.DB
|
||||||
uploaderManager *oss.UploaderManager
|
// mjService *service.MjService
|
||||||
lock sync.Mutex
|
// uploaderManager *oss.UploaderManager
|
||||||
clients *types.LMap[string, *types.WsClient]
|
// lock sync.Mutex
|
||||||
}
|
// clients *types.LMap[string, *types.WsClient]
|
||||||
|
//}
|
||||||
func NewSdJobHandler(
|
//
|
||||||
app *core.AppServer,
|
//func NewSdJobHandler(
|
||||||
client *redis.Client,
|
// app *core.AppServer,
|
||||||
db *gorm.DB,
|
// client *redis.Client,
|
||||||
manager *oss.UploaderManager,
|
// db *gorm.DB,
|
||||||
mjService *service.MjService) *MidJourneyHandler {
|
// manager *oss.UploaderManager,
|
||||||
h := MidJourneyHandler{
|
// mjService *service.MjService) *MidJourneyHandler {
|
||||||
redis: client,
|
// h := MidJourneyHandler{
|
||||||
db: db,
|
// redis: client,
|
||||||
uploaderManager: manager,
|
// db: db,
|
||||||
lock: sync.Mutex{},
|
// uploaderManager: manager,
|
||||||
mjService: mjService,
|
// lock: sync.Mutex{},
|
||||||
clients: types.NewLMap[string, *types.WsClient](),
|
// mjService: mjService,
|
||||||
}
|
// clients: types.NewLMap[string, *types.WsClient](),
|
||||||
h.App = app
|
// }
|
||||||
return &h
|
// h.App = app
|
||||||
}
|
// return &h
|
||||||
|
//}
|
||||||
// Client WebSocket 客户端,用于通知任务状态变更
|
//
|
||||||
func (h *SdJobHandler) Client(c *gin.Context) {
|
//// Client WebSocket 客户端,用于通知任务状态变更
|
||||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
//func (h *SdJobHandler) Client(c *gin.Context) {
|
||||||
if err != nil {
|
// ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||||
logger.Error(err)
|
// if err != nil {
|
||||||
return
|
// logger.Error(err)
|
||||||
}
|
// return
|
||||||
|
// }
|
||||||
sessionId := c.Query("session_id")
|
//
|
||||||
client := types.NewWsClient(ws)
|
// sessionId := c.Query("session_id")
|
||||||
// 删除旧的连接
|
// client := types.NewWsClient(ws)
|
||||||
h.clients.Delete(sessionId)
|
// // 删除旧的连接
|
||||||
h.clients.Put(sessionId, client)
|
// h.clients.Delete(sessionId)
|
||||||
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
// h.clients.Put(sessionId, client)
|
||||||
}
|
// logger.Infof("New websocket connected, IP: %s", c.ClientIP())
|
||||||
|
//}
|
||||||
type sdNotifyData struct {
|
//
|
||||||
TaskId string
|
//type sdNotifyData struct {
|
||||||
ImageName string
|
// TaskId string
|
||||||
ImageData string
|
// ImageName string
|
||||||
Progress int
|
// ImageData string
|
||||||
Seed string
|
// Progress int
|
||||||
Success bool
|
// Seed string
|
||||||
Message string
|
// Success bool
|
||||||
}
|
// Message string
|
||||||
|
//}
|
||||||
func (h *SdJobHandler) Notify(c *gin.Context) {
|
//
|
||||||
token := c.GetHeader("Authorization")
|
//func (h *SdJobHandler) Notify(c *gin.Context) {
|
||||||
if token != h.App.Config.ExtConfig.Token {
|
// token := c.GetHeader("Authorization")
|
||||||
resp.NotAuth(c)
|
// if token != h.App.Config.ExtConfig.Token {
|
||||||
return
|
// resp.NotAuth(c)
|
||||||
}
|
// return
|
||||||
var data sdNotifyData
|
// }
|
||||||
if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" {
|
// var data sdNotifyData
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
// if err := c.ShouldBindJSON(&data); err != nil || data.TaskId == "" {
|
||||||
return
|
// resp.ERROR(c, types.InvalidArgs)
|
||||||
}
|
// return
|
||||||
logger.Debugf("收到 MidJourney 回调请求:%+v", data)
|
// }
|
||||||
|
// logger.Debugf("收到 MidJourney 回调请求:%+v", data)
|
||||||
h.lock.Lock()
|
//
|
||||||
defer h.lock.Unlock()
|
// h.lock.Lock()
|
||||||
|
// defer h.lock.Unlock()
|
||||||
err, finished := h.notifyHandler(c, data)
|
//
|
||||||
if err != nil {
|
// err, finished := h.notifyHandler(c, data)
|
||||||
resp.ERROR(c, err.Error())
|
// if err != nil {
|
||||||
return
|
// resp.ERROR(c, err.Error())
|
||||||
}
|
// return
|
||||||
|
// }
|
||||||
// 解除任务锁定
|
//
|
||||||
if finished && (data.Progress == 100) {
|
// // 解除任务锁定
|
||||||
h.redis.Del(c, service.MjRunningJobKey)
|
// if finished && (data.Progress == 100) {
|
||||||
}
|
// h.redis.Del(c, service.MjRunningJobKey)
|
||||||
resp.SUCCESS(c)
|
// }
|
||||||
|
// resp.SUCCESS(c)
|
||||||
}
|
//
|
||||||
|
//}
|
||||||
func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) {
|
//
|
||||||
taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
|
//func (h *SdJobHandler) notifyHandler(c *gin.Context, data sdNotifyData) (error, bool) {
|
||||||
if err != nil { // 过期任务,丢弃
|
// taskString, err := h.redis.Get(c, service.MjRunningJobKey).Result()
|
||||||
logger.Warn("任务已过期:", err)
|
// if err != nil { // 过期任务,丢弃
|
||||||
return nil, true
|
// logger.Warn("任务已过期:", err)
|
||||||
}
|
// return nil, true
|
||||||
|
// }
|
||||||
var task types.SdTask
|
//
|
||||||
err = utils.JsonDecode(taskString, &task)
|
// var task types.SdTask
|
||||||
if err != nil { // 非标准任务,丢弃
|
// err = utils.JsonDecode(taskString, &task)
|
||||||
logger.Warn("任务解析失败:", err)
|
// if err != nil { // 非标准任务,丢弃
|
||||||
return nil, false
|
// logger.Warn("任务解析失败:", err)
|
||||||
}
|
// return nil, false
|
||||||
|
// }
|
||||||
var job model.SdJob
|
//
|
||||||
res := h.db.Where("id = ?", task.Id).First(&job)
|
// var job model.SdJob
|
||||||
if res.Error != nil {
|
// res := h.db.Where("id = ?", task.Id).First(&job)
|
||||||
logger.Warn("非法任务:", res.Error)
|
// if res.Error != nil {
|
||||||
return nil, false
|
// logger.Warn("非法任务:", res.Error)
|
||||||
}
|
// return nil, false
|
||||||
job.Params = utils.JsonEncode(task.Params)
|
// }
|
||||||
job.ReferenceId = data.ImageData
|
// job.Params = utils.JsonEncode(task.Params)
|
||||||
job.Progress = data.Progress
|
// job.ReferenceId = data.ImageData
|
||||||
job.Prompt = data.Prompt
|
// job.Progress = data.Progress
|
||||||
job.Hash = data.Image.Hash
|
// job.Prompt = data.Prompt
|
||||||
|
// job.Hash = data.Image.Hash
|
||||||
// 任务完成,将最终的图片下载下来
|
//
|
||||||
if data.Progress == 100 {
|
// // 任务完成,将最终的图片下载下来
|
||||||
imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
// if data.Progress == 100 {
|
||||||
if err != nil {
|
// imgURL, err := h.uploaderManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||||
logger.Error("error with download img: ", err.Error())
|
// if err != nil {
|
||||||
return err, false
|
// logger.Error("error with download img: ", err.Error())
|
||||||
}
|
// return err, false
|
||||||
job.ImgURL = imgURL
|
// }
|
||||||
} else {
|
// job.ImgURL = imgURL
|
||||||
// 临时图片直接保存,访问的时候使用代理进行转发
|
// } else {
|
||||||
job.ImgURL = data.Image.URL
|
// // 临时图片直接保存,访问的时候使用代理进行转发
|
||||||
}
|
// job.ImgURL = data.Image.URL
|
||||||
res = h.db.Updates(&job)
|
// }
|
||||||
if res.Error != nil {
|
// res = h.db.Updates(&job)
|
||||||
logger.Error("error with update job: ", res.Error)
|
// if res.Error != nil {
|
||||||
return res.Error, false
|
// logger.Error("error with update job: ", res.Error)
|
||||||
}
|
// return res.Error, false
|
||||||
|
// }
|
||||||
var jobVo vo.MidJourneyJob
|
//
|
||||||
err := utils.CopyObject(job, &jobVo)
|
// var jobVo vo.MidJourneyJob
|
||||||
if err == nil {
|
// err := utils.CopyObject(job, &jobVo)
|
||||||
if data.Progress < 100 {
|
// if err == nil {
|
||||||
image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
|
// if data.Progress < 100 {
|
||||||
if err == nil {
|
// image, err := utils.DownloadImage(jobVo.ImgURL, h.App.Config.ProxyURL)
|
||||||
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
// if err == nil {
|
||||||
}
|
// jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||||
}
|
// }
|
||||||
|
// }
|
||||||
// 推送任务到前端
|
//
|
||||||
client := h.clients.Get(task.SessionId)
|
// // 推送任务到前端
|
||||||
if client != nil {
|
// client := h.clients.Get(task.SessionId)
|
||||||
utils.ReplyChunkMessage(client, jobVo)
|
// if client != nil {
|
||||||
}
|
// utils.ReplyChunkMessage(client, jobVo)
|
||||||
}
|
// }
|
||||||
|
// }
|
||||||
// 更新用户剩余绘图次数
|
//
|
||||||
if data.Progress == 100 {
|
// // 更新用户剩余绘图次数
|
||||||
h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
// if data.Progress == 100 {
|
||||||
}
|
// h.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||||
|
// }
|
||||||
return nil, true
|
//
|
||||||
}
|
// return nil, true
|
||||||
|
//}
|
||||||
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
//
|
||||||
user, err := utils.GetLoginUser(c, h.db)
|
//func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
|
||||||
if err != nil {
|
// user, err := utils.GetLoginUser(c, h.db)
|
||||||
resp.NotAuth(c)
|
// if err != nil {
|
||||||
return false
|
// resp.NotAuth(c)
|
||||||
}
|
// return false
|
||||||
|
// }
|
||||||
if user.ImgCalls <= 0 {
|
//
|
||||||
resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
|
// if user.ImgCalls <= 0 {
|
||||||
return false
|
// resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
|
||||||
}
|
// return false
|
||||||
|
// }
|
||||||
return true
|
//
|
||||||
|
// return true
|
||||||
}
|
//
|
||||||
|
//}
|
||||||
// Image 创建一个绘画任务
|
//
|
||||||
func (h *SdJobHandler) Image(c *gin.Context) {
|
//// Image 创建一个绘画任务
|
||||||
var data struct {
|
//func (h *SdJobHandler) Image(c *gin.Context) {
|
||||||
SessionId string `json:"session_id"`
|
// var data struct {
|
||||||
Prompt string `json:"prompt"`
|
// SessionId string `json:"session_id"`
|
||||||
Rate string `json:"rate"`
|
// Prompt string `json:"prompt"`
|
||||||
Model string `json:"model"`
|
// Rate string `json:"rate"`
|
||||||
Chaos int `json:"chaos"`
|
// Model string `json:"model"`
|
||||||
Raw bool `json:"raw"`
|
// Chaos int `json:"chaos"`
|
||||||
Seed int64 `json:"seed"`
|
// Raw bool `json:"raw"`
|
||||||
Stylize int `json:"stylize"`
|
// Seed int64 `json:"seed"`
|
||||||
Img string `json:"img"`
|
// Stylize int `json:"stylize"`
|
||||||
Weight float32 `json:"weight"`
|
// Img string `json:"img"`
|
||||||
}
|
// Weight float32 `json:"weight"`
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
// }
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
// if err := c.ShouldBindJSON(&data); err != nil {
|
||||||
return
|
// resp.ERROR(c, types.InvalidArgs)
|
||||||
}
|
// return
|
||||||
if !h.checkLimits(c) {
|
// }
|
||||||
return
|
// if !h.checkLimits(c) {
|
||||||
}
|
// return
|
||||||
|
// }
|
||||||
var prompt = data.Prompt
|
//
|
||||||
if data.Rate != "" && !strings.Contains(prompt, "--ar") {
|
// var prompt = data.Prompt
|
||||||
prompt += " --ar " + data.Rate
|
// if data.Rate != "" && !strings.Contains(prompt, "--ar") {
|
||||||
}
|
// prompt += " --ar " + data.Rate
|
||||||
if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
|
// }
|
||||||
prompt += fmt.Sprintf(" --seed %d", data.Seed)
|
// if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
|
||||||
}
|
// prompt += fmt.Sprintf(" --seed %d", data.Seed)
|
||||||
if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
|
// }
|
||||||
prompt += fmt.Sprintf(" --s %d", data.Stylize)
|
// if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
|
||||||
}
|
// prompt += fmt.Sprintf(" --s %d", data.Stylize)
|
||||||
if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
|
// }
|
||||||
prompt += fmt.Sprintf(" --c %d", data.Chaos)
|
// if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
|
||||||
}
|
// prompt += fmt.Sprintf(" --c %d", data.Chaos)
|
||||||
if data.Img != "" {
|
// }
|
||||||
prompt = fmt.Sprintf("%s %s", data.Img, prompt)
|
// if data.Img != "" {
|
||||||
if data.Weight > 0 {
|
// prompt = fmt.Sprintf("%s %s", data.Img, prompt)
|
||||||
prompt += fmt.Sprintf(" --iw %f", data.Weight)
|
// if data.Weight > 0 {
|
||||||
}
|
// prompt += fmt.Sprintf(" --iw %f", data.Weight)
|
||||||
}
|
// }
|
||||||
if data.Raw {
|
// }
|
||||||
prompt += " --style raw"
|
// if data.Raw {
|
||||||
}
|
// prompt += " --style raw"
|
||||||
if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
|
// }
|
||||||
prompt += data.Model
|
// if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
|
||||||
}
|
// prompt += data.Model
|
||||||
|
// }
|
||||||
idValue, _ := c.Get(types.LoginUserID)
|
//
|
||||||
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
// idValue, _ := c.Get(types.LoginUserID)
|
||||||
job := model.MidJourneyJob{
|
// userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
|
||||||
Type: service.Image.String(),
|
// job := model.MidJourneyJob{
|
||||||
UserId: userId,
|
// Type: service.Image.String(),
|
||||||
Progress: 0,
|
// UserId: userId,
|
||||||
Prompt: prompt,
|
// Progress: 0,
|
||||||
CreatedAt: time.Now(),
|
// Prompt: prompt,
|
||||||
}
|
// CreatedAt: time.Now(),
|
||||||
if res := h.db.Create(&job); res.Error != nil {
|
// }
|
||||||
resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
// if res := h.db.Create(&job); res.Error != nil {
|
||||||
return
|
// resp.ERROR(c, "添加任务失败:"+res.Error.Error())
|
||||||
}
|
// return
|
||||||
|
// }
|
||||||
h.mjService.PushTask(service.MjTask{
|
//
|
||||||
Id: int(job.Id),
|
// h.mjService.PushTask(service.MjTask{
|
||||||
SessionId: data.SessionId,
|
// Id: int(job.Id),
|
||||||
Src: service.TaskSrcImg,
|
// SessionId: data.SessionId,
|
||||||
Type: service.Image,
|
// Src: service.TaskSrcImg,
|
||||||
Prompt: prompt,
|
// Type: service.Image,
|
||||||
UserId: userId,
|
// Prompt: prompt,
|
||||||
})
|
// UserId: userId,
|
||||||
|
// })
|
||||||
var jobVo vo.MidJourneyJob
|
//
|
||||||
err := utils.CopyObject(job, &jobVo)
|
// var jobVo vo.MidJourneyJob
|
||||||
if err == nil {
|
// err := utils.CopyObject(job, &jobVo)
|
||||||
// 推送任务到前端
|
// if err == nil {
|
||||||
client := h.clients.Get(data.SessionId)
|
// // 推送任务到前端
|
||||||
if client != nil {
|
// client := h.clients.Get(data.SessionId)
|
||||||
utils.ReplyChunkMessage(client, jobVo)
|
// if client != nil {
|
||||||
}
|
// utils.ReplyChunkMessage(client, jobVo)
|
||||||
}
|
// }
|
||||||
resp.SUCCESS(c)
|
// }
|
||||||
}
|
// resp.SUCCESS(c)
|
||||||
|
//}
|
||||||
// JobList 获取 MJ 任务列表
|
//
|
||||||
func (h *SdJobHandler) JobList(c *gin.Context) {
|
//// JobList 获取 MJ 任务列表
|
||||||
status := h.GetInt(c, "status", 0)
|
//func (h *SdJobHandler) JobList(c *gin.Context) {
|
||||||
var items []model.MidJourneyJob
|
// status := h.GetInt(c, "status", 0)
|
||||||
var res *gorm.DB
|
// var items []model.MidJourneyJob
|
||||||
userId, _ := c.Get(types.LoginUserID)
|
// var res *gorm.DB
|
||||||
if status == 1 {
|
// userId, _ := c.Get(types.LoginUserID)
|
||||||
res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items)
|
// if status == 1 {
|
||||||
} else {
|
// res = h.db.Where("user_id = ? AND progress = 100", userId).Order("id DESC").Find(&items)
|
||||||
res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items)
|
// } else {
|
||||||
}
|
// res = h.db.Where("user_id = ? AND progress < 100", userId).Order("id ASC").Find(&items)
|
||||||
if res.Error != nil {
|
// }
|
||||||
resp.ERROR(c, types.NoData)
|
// if res.Error != nil {
|
||||||
return
|
// resp.ERROR(c, types.NoData)
|
||||||
}
|
// return
|
||||||
|
// }
|
||||||
var jobs = make([]vo.MidJourneyJob, 0)
|
//
|
||||||
for _, item := range items {
|
// var jobs = make([]vo.MidJourneyJob, 0)
|
||||||
var job vo.MidJourneyJob
|
// for _, item := range items {
|
||||||
err := utils.CopyObject(item, &job)
|
// var job vo.MidJourneyJob
|
||||||
if err != nil {
|
// err := utils.CopyObject(item, &job)
|
||||||
continue
|
// if err != nil {
|
||||||
}
|
// continue
|
||||||
if item.Progress < 100 {
|
// }
|
||||||
// 30 分钟还没完成的任务直接删除
|
// if item.Progress < 100 {
|
||||||
if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
|
// // 30 分钟还没完成的任务直接删除
|
||||||
h.db.Delete(&item)
|
// if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
|
||||||
continue
|
// h.db.Delete(&item)
|
||||||
}
|
// continue
|
||||||
if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
|
// }
|
||||||
image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
|
// if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
|
||||||
if err == nil {
|
// image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
|
||||||
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
// if err == nil {
|
||||||
}
|
// job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
jobs = append(jobs, job)
|
// }
|
||||||
}
|
// jobs = append(jobs, job)
|
||||||
resp.SUCCESS(c, jobs)
|
// }
|
||||||
}
|
// resp.SUCCESS(c, jobs)
|
||||||
|
//}
|
||||||
|
35
api/main.go
35
api/main.go
@ -7,8 +7,10 @@ import (
|
|||||||
"chatplus/handler/admin"
|
"chatplus/handler/admin"
|
||||||
logger2 "chatplus/logger"
|
logger2 "chatplus/logger"
|
||||||
"chatplus/service"
|
"chatplus/service"
|
||||||
"chatplus/service/function"
|
"chatplus/service/fun"
|
||||||
|
"chatplus/service/mj"
|
||||||
"chatplus/service/oss"
|
"chatplus/service/oss"
|
||||||
|
"chatplus/service/wx"
|
||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
"context"
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
@ -107,7 +109,7 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
|
|
||||||
// 创建函数
|
// 创建函数
|
||||||
fx.Provide(function.NewFunctions),
|
fx.Provide(fun.NewFunctions),
|
||||||
|
|
||||||
// 创建控制器
|
// 创建控制器
|
||||||
fx.Provide(handler.NewChatRoleHandler),
|
fx.Provide(handler.NewChatRoleHandler),
|
||||||
@ -135,13 +137,36 @@ func main() {
|
|||||||
return service.NewCaptchaService(config.ApiConfig)
|
return service.NewCaptchaService(config.ApiConfig)
|
||||||
}),
|
}),
|
||||||
fx.Provide(oss.NewUploaderManager),
|
fx.Provide(oss.NewUploaderManager),
|
||||||
fx.Provide(service.NewMjService),
|
fx.Provide(mj.NewService),
|
||||||
fx.Invoke(func(mjService *service.MjService) {
|
fx.Invoke(func(mjService *mj.Service) {
|
||||||
go func() {
|
go func() {
|
||||||
mjService.Run()
|
mjService.Run()
|
||||||
}()
|
}()
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
// 微信机器人服务
|
||||||
|
fx.Provide(wx.NewWeChatBot),
|
||||||
|
fx.Invoke(func(config *types.AppConfig, bot *wx.Bot) {
|
||||||
|
if config.WeChatBot {
|
||||||
|
err := bot.Run()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("微信登录失败:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
|
||||||
|
// MidJourney 机器人
|
||||||
|
fx.Provide(mj.NewBot),
|
||||||
|
fx.Provide(mj.NewClient),
|
||||||
|
fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) {
|
||||||
|
if config.MjConfig.Enabled {
|
||||||
|
err := bot.Run()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("MidJourney 服务启动失败:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
|
||||||
// 注册路由
|
// 注册路由
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
|
||||||
group := s.Engine.Group("/api/role/")
|
group := s.Engine.Group("/api/role/")
|
||||||
@ -185,12 +210,10 @@ func main() {
|
|||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
|
||||||
group := s.Engine.Group("/api/reward/")
|
group := s.Engine.Group("/api/reward/")
|
||||||
group.POST("notify", h.Notify)
|
|
||||||
group.POST("verify", h.Verify)
|
group.POST("verify", h.Verify)
|
||||||
}),
|
}),
|
||||||
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
|
||||||
group := s.Engine.Group("/api/mj/")
|
group := s.Engine.Group("/api/mj/")
|
||||||
group.POST("notify", h.Notify)
|
|
||||||
group.POST("image", h.Image)
|
group.POST("image", h.Image)
|
||||||
group.POST("upscale", h.Upscale)
|
group.POST("upscale", h.Upscale)
|
||||||
group.POST("variation", h.Variation)
|
group.POST("variation", h.Variation)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
package function
|
package fun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/service"
|
"chatplus/core/types"
|
||||||
|
"chatplus/service/mj"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -9,10 +10,10 @@ import (
|
|||||||
|
|
||||||
type FuncMidJourney struct {
|
type FuncMidJourney struct {
|
||||||
name string
|
name string
|
||||||
service *service.MjService
|
service *mj.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney {
|
func NewMidJourneyFunc(mjService *mj.Service) FuncMidJourney {
|
||||||
return FuncMidJourney{
|
return FuncMidJourney{
|
||||||
name: "MidJourney AI 绘画",
|
name: "MidJourney AI 绘画",
|
||||||
service: mjService}
|
service: mjService}
|
||||||
@ -21,10 +22,10 @@ func NewMidJourneyFunc(mjService *service.MjService) FuncMidJourney {
|
|||||||
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
|
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
|
||||||
logger.Infof("MJ 绘画参数:%+v", params)
|
logger.Infof("MJ 绘画参数:%+v", params)
|
||||||
prompt := utils.InterfaceToString(params["prompt"])
|
prompt := utils.InterfaceToString(params["prompt"])
|
||||||
f.service.PushTask(service.MjTask{
|
f.service.PushTask(types.MjTask{
|
||||||
SessionId: utils.InterfaceToString(params["session_id"]),
|
SessionId: utils.InterfaceToString(params["session_id"]),
|
||||||
Src: service.TaskSrcChat,
|
Src: types.TaskSrcChat,
|
||||||
Type: service.Image,
|
Type: types.TaskImage,
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0),
|
UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0),
|
||||||
RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0),
|
RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0),
|
@ -1,9 +1,9 @@
|
|||||||
package function
|
package fun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
logger2 "chatplus/logger"
|
logger2 "chatplus/logger"
|
||||||
"chatplus/service"
|
"chatplus/service/mj"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Function interface {
|
type Function interface {
|
||||||
@ -29,7 +29,7 @@ type dataItem struct {
|
|||||||
Remark string `json:"remark"`
|
Remark string `json:"remark"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFunctions(config *types.AppConfig, mjService *service.MjService) map[string]Function {
|
func NewFunctions(config *types.AppConfig, mjService *mj.Service) map[string]Function {
|
||||||
return map[string]Function{
|
return map[string]Function{
|
||||||
types.FuncZaoBao: NewZaoBao(config.ApiConfig),
|
types.FuncZaoBao: NewZaoBao(config.ApiConfig),
|
||||||
types.FuncWeibo: NewWeiboHot(config.ApiConfig),
|
types.FuncWeibo: NewWeiboHot(config.ApiConfig),
|
@ -1,4 +1,4 @@
|
|||||||
package function
|
package fun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
@ -1,4 +1,4 @@
|
|||||||
package function
|
package fun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
@ -1,4 +1,4 @@
|
|||||||
package function
|
package fun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
213
api/service/mj/bot.go
Normal file
213
api/service/mj/bot.go
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
package mj
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core/types"
|
||||||
|
logger2 "chatplus/logger"
|
||||||
|
"chatplus/utils"
|
||||||
|
"github.com/bwmarrin/discordgo"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MidJourney 机器人
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
type Bot struct {
|
||||||
|
config *types.MidJourneyConfig
|
||||||
|
bot *discordgo.Session
|
||||||
|
service *Service
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
|
||||||
|
discord, err := discordgo.New("Bot " + config.MjConfig.BotToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ProxyURL != "" {
|
||||||
|
proxy, _ := url.Parse(config.ProxyURL)
|
||||||
|
discord.Client = &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(proxy),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
discord.Dialer = &websocket.Dialer{
|
||||||
|
Proxy: http.ProxyURL(proxy),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Bot{
|
||||||
|
config: &config.MjConfig,
|
||||||
|
bot: discord,
|
||||||
|
service: service,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bot) Run() error {
|
||||||
|
b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
|
||||||
|
b.bot.AddHandler(b.messageCreate)
|
||||||
|
b.bot.AddHandler(b.messageUpdate)
|
||||||
|
|
||||||
|
logger.Info("Starting MidJourney Bot...")
|
||||||
|
err := b.bot.Open()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error opening Discord connection:", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Info("Starting MidJourney Bot successfully!")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Start = TaskStatus("Started")
|
||||||
|
Running = TaskStatus("Running")
|
||||||
|
Stopped = TaskStatus("Stopped")
|
||||||
|
Finished = TaskStatus("Finished")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Image struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
ProxyURL string `json:"proxy_url"`
|
||||||
|
Filename string `json:"filename"`
|
||||||
|
Width int `json:"width"`
|
||||||
|
Height int `json:"height"`
|
||||||
|
Size int `json:"size"`
|
||||||
|
Hash string `json:"hash"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||||
|
// ignore messages for other channels
|
||||||
|
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// ignore messages for self
|
||||||
|
if m.Author.ID == s.State.User.ID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("CREATE: %s", utils.JsonEncode(m))
|
||||||
|
var referenceId = ""
|
||||||
|
if m.ReferencedMessage != nil {
|
||||||
|
referenceId = m.ReferencedMessage.ID
|
||||||
|
}
|
||||||
|
if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
|
||||||
|
// parse content
|
||||||
|
req := CBReq{
|
||||||
|
MessageId: m.ID,
|
||||||
|
ReferenceId: referenceId,
|
||||||
|
Prompt: extractPrompt(m.Content),
|
||||||
|
Content: m.Content,
|
||||||
|
Progress: 0,
|
||||||
|
Status: Start}
|
||||||
|
b.service.Notify(req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
||||||
|
// ignore messages for other channels
|
||||||
|
if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// ignore messages for self
|
||||||
|
if m.Author.ID == s.State.User.ID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
|
||||||
|
|
||||||
|
var referenceId = ""
|
||||||
|
if m.ReferencedMessage != nil {
|
||||||
|
referenceId = m.ReferencedMessage.ID
|
||||||
|
}
|
||||||
|
if strings.Contains(m.Content, "(Stopped)") {
|
||||||
|
req := CBReq{
|
||||||
|
MessageId: m.ID,
|
||||||
|
ReferenceId: referenceId,
|
||||||
|
Prompt: extractPrompt(m.Content),
|
||||||
|
Content: m.Content,
|
||||||
|
Progress: extractProgress(m.Content),
|
||||||
|
Status: Stopped}
|
||||||
|
b.service.Notify(req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
|
||||||
|
progress := extractProgress(content)
|
||||||
|
var status TaskStatus
|
||||||
|
if progress == 100 {
|
||||||
|
status = Finished
|
||||||
|
} else {
|
||||||
|
status = Running
|
||||||
|
}
|
||||||
|
for _, attachment := range attachments {
|
||||||
|
if attachment.Width == 0 || attachment.Height == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
image := Image{
|
||||||
|
URL: attachment.URL,
|
||||||
|
Height: attachment.Height,
|
||||||
|
ProxyURL: attachment.ProxyURL,
|
||||||
|
Width: attachment.Width,
|
||||||
|
Size: attachment.Size,
|
||||||
|
Filename: attachment.Filename,
|
||||||
|
Hash: extractHashFromFilename(attachment.Filename),
|
||||||
|
}
|
||||||
|
req := CBReq{
|
||||||
|
MessageId: messageId,
|
||||||
|
ReferenceId: referenceId,
|
||||||
|
Image: image,
|
||||||
|
Prompt: extractPrompt(content),
|
||||||
|
Content: content,
|
||||||
|
Progress: progress,
|
||||||
|
Status: status,
|
||||||
|
}
|
||||||
|
b.service.Notify(req)
|
||||||
|
break // only get one image
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract prompt from string
|
||||||
|
func extractPrompt(input string) string {
|
||||||
|
pattern := `\*\*(.*?)\*\*`
|
||||||
|
re := regexp.MustCompile(pattern)
|
||||||
|
matches := re.FindStringSubmatch(input)
|
||||||
|
if len(matches) > 1 {
|
||||||
|
return strings.TrimSpace(matches[1])
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractProgress(input string) int {
|
||||||
|
pattern := `\((\d+)\%\)`
|
||||||
|
re := regexp.MustCompile(pattern)
|
||||||
|
matches := re.FindStringSubmatch(input)
|
||||||
|
if len(matches) > 1 {
|
||||||
|
return utils.IntValue(matches[1], 0)
|
||||||
|
}
|
||||||
|
return 100
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractHashFromFilename(filename string) string {
|
||||||
|
if !strings.HasSuffix(filename, ".png") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
index := strings.LastIndex(filename, "_")
|
||||||
|
if index != -1 {
|
||||||
|
return filename[index+1 : len(filename)-4]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
144
api/service/mj/client.go
Normal file
144
api/service/mj/client.go
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
package mj
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core/types"
|
||||||
|
"fmt"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MidJourney client
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
client *req.Client
|
||||||
|
config *types.MidJourneyConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(config *types.AppConfig) *Client {
|
||||||
|
client := req.C().SetTimeout(10 * time.Second)
|
||||||
|
// set proxy URL
|
||||||
|
if config.ProxyURL != "" {
|
||||||
|
client.SetProxyURL(config.ProxyURL)
|
||||||
|
}
|
||||||
|
return &Client{client: client, config: &config.MjConfig}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Imagine(prompt string) error {
|
||||||
|
interactionsReq := &InteractionsRequest{
|
||||||
|
Type: 2,
|
||||||
|
ApplicationID: ApplicationID,
|
||||||
|
GuildID: c.config.GuildId,
|
||||||
|
ChannelID: c.config.ChanelId,
|
||||||
|
SessionID: SessionID,
|
||||||
|
Data: map[string]any{
|
||||||
|
"version": "1118961510123847772",
|
||||||
|
"id": "938956540159881230",
|
||||||
|
"name": "imagine",
|
||||||
|
"type": "1",
|
||||||
|
"options": []map[string]any{
|
||||||
|
{
|
||||||
|
"type": 3,
|
||||||
|
"name": "prompt",
|
||||||
|
"value": prompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"application_command": map[string]any{
|
||||||
|
"id": "938956540159881230",
|
||||||
|
"application_id": ApplicationID,
|
||||||
|
"version": "1118961510123847772",
|
||||||
|
"default_permission": true,
|
||||||
|
"default_member_permissions": nil,
|
||||||
|
"type": 1,
|
||||||
|
"nsfw": false,
|
||||||
|
"name": "imagine",
|
||||||
|
"description": "Create images with Midjourney",
|
||||||
|
"dm_permission": true,
|
||||||
|
"options": []map[string]any{
|
||||||
|
{
|
||||||
|
"type": 3,
|
||||||
|
"name": "prompt",
|
||||||
|
"description": "The prompt to imagine",
|
||||||
|
"required": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"attachments": []any{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
url := "https://discord.com/api/v9/interactions"
|
||||||
|
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
|
||||||
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetBody(interactionsReq).
|
||||||
|
Post(url)
|
||||||
|
|
||||||
|
if err != nil || r.IsErrorState() {
|
||||||
|
return fmt.Errorf("error with http request: %w%v", err, r.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upscale 放大指定的图片
|
||||||
|
func (c *Client) Upscale(index int, messageId string, hash string) error {
|
||||||
|
flags := 0
|
||||||
|
interactionsReq := &InteractionsRequest{
|
||||||
|
Type: 3,
|
||||||
|
ApplicationID: ApplicationID,
|
||||||
|
GuildID: c.config.GuildId,
|
||||||
|
ChannelID: c.config.ChanelId,
|
||||||
|
MessageFlags: &flags,
|
||||||
|
MessageID: &messageId,
|
||||||
|
SessionID: SessionID,
|
||||||
|
Data: map[string]any{
|
||||||
|
"component_type": 2,
|
||||||
|
"custom_id": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
|
||||||
|
},
|
||||||
|
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||||
|
}
|
||||||
|
|
||||||
|
url := "https://discord.com/api/v9/interactions"
|
||||||
|
var res InteractionsResult
|
||||||
|
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
|
||||||
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetBody(interactionsReq).
|
||||||
|
SetErrorResult(&res).
|
||||||
|
Post(url)
|
||||||
|
if err != nil || r.IsErrorState() {
|
||||||
|
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
|
||||||
|
func (c *Client) Variation(index int, messageId string, hash string) error {
|
||||||
|
flags := 0
|
||||||
|
interactionsReq := &InteractionsRequest{
|
||||||
|
Type: 3,
|
||||||
|
ApplicationID: ApplicationID,
|
||||||
|
GuildID: c.config.GuildId,
|
||||||
|
ChannelID: c.config.ChanelId,
|
||||||
|
MessageFlags: &flags,
|
||||||
|
MessageID: &messageId,
|
||||||
|
SessionID: SessionID,
|
||||||
|
Data: map[string]any{
|
||||||
|
"component_type": 2,
|
||||||
|
"custom_id": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
|
||||||
|
},
|
||||||
|
Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||||
|
}
|
||||||
|
|
||||||
|
url := "https://discord.com/api/v9/interactions"
|
||||||
|
var res InteractionsResult
|
||||||
|
r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
|
||||||
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetBody(interactionsReq).
|
||||||
|
SetErrorResult(&res).
|
||||||
|
Post(url)
|
||||||
|
if err != nil || r.IsErrorState() {
|
||||||
|
return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
249
api/service/mj/service.go
Normal file
249
api/service/mj/service.go
Normal file
@ -0,0 +1,249 @@
|
|||||||
|
package mj
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core/types"
|
||||||
|
"chatplus/service/oss"
|
||||||
|
"chatplus/store"
|
||||||
|
"chatplus/store/model"
|
||||||
|
"chatplus/store/vo"
|
||||||
|
"chatplus/utils"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MJ 绘画服务
|
||||||
|
|
||||||
|
const RunningJobKey = "MidJourney_Running_Job"
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
client *Client
|
||||||
|
taskQueue *store.RedisQueue
|
||||||
|
redis *redis.Client
|
||||||
|
db *gorm.DB
|
||||||
|
uploadManager *oss.UploaderManager
|
||||||
|
Clients *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息
|
||||||
|
ChatClients *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息
|
||||||
|
proxyURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
|
||||||
|
return &Service{
|
||||||
|
redis: redisCli,
|
||||||
|
db: db,
|
||||||
|
taskQueue: store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
|
||||||
|
client: client,
|
||||||
|
uploadManager: manager,
|
||||||
|
Clients: types.NewLMap[string, *types.WsClient](),
|
||||||
|
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||||
|
proxyURL: config.ProxyURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Run() {
|
||||||
|
logger.Info("Starting MidJourney job consumer.")
|
||||||
|
ctx := context.Background()
|
||||||
|
for {
|
||||||
|
_, err := s.redis.Get(ctx, RunningJobKey).Result()
|
||||||
|
if err == nil { // 队列串行执行
|
||||||
|
time.Sleep(time.Second * 3)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var task types.MjTask
|
||||||
|
err = s.taskQueue.LPop(&task)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("taking task with error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Infof("Consuming Task: %+v", task)
|
||||||
|
switch task.Type {
|
||||||
|
case types.TaskImage:
|
||||||
|
err = s.client.Imagine(task.Prompt)
|
||||||
|
break
|
||||||
|
case types.TaskUpscale:
|
||||||
|
err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash)
|
||||||
|
|
||||||
|
break
|
||||||
|
case types.TaskVariation:
|
||||||
|
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("绘画任务执行失败:", err)
|
||||||
|
if task.RetryCount <= 5 {
|
||||||
|
s.taskQueue.RPush(task)
|
||||||
|
}
|
||||||
|
task.RetryCount += 1
|
||||||
|
time.Sleep(time.Second * 3)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新任务的执行状态
|
||||||
|
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||||||
|
// 锁定任务执行通道,直到任务超时(5分钟)
|
||||||
|
s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) PushTask(task types.MjTask) {
|
||||||
|
logger.Infof("add a new MidJourney Task: %+v", task)
|
||||||
|
s.taskQueue.RPush(task)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Notify(data CBReq) {
|
||||||
|
taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result()
|
||||||
|
if err != nil { // 过期任务,丢弃
|
||||||
|
logger.Warn("任务已过期:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var task types.MjTask
|
||||||
|
err = utils.JsonDecode(taskString, &task)
|
||||||
|
if err != nil { // 非标准任务,丢弃
|
||||||
|
logger.Warn("任务解析失败:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var job model.MidJourneyJob
|
||||||
|
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
|
||||||
|
if res.Error == nil && data.Status == Finished {
|
||||||
|
logger.Warn("重复消息:", data.MessageId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.Src == types.TaskSrcImg { // 绘画任务
|
||||||
|
var job model.MidJourneyJob
|
||||||
|
res := s.db.Where("id = ?", task.Id).First(&job)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Warn("非法任务:", res.Error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
job.MessageId = data.MessageId
|
||||||
|
job.ReferenceId = data.ReferenceId
|
||||||
|
job.Progress = data.Progress
|
||||||
|
job.Prompt = data.Prompt
|
||||||
|
job.Hash = data.Image.Hash
|
||||||
|
|
||||||
|
// 任务完成,将最终的图片下载下来
|
||||||
|
if data.Progress == 100 {
|
||||||
|
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download img: ", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
job.ImgURL = imgURL
|
||||||
|
} else {
|
||||||
|
// 临时图片直接保存,访问的时候使用代理进行转发
|
||||||
|
job.ImgURL = data.Image.URL
|
||||||
|
}
|
||||||
|
res = s.db.Updates(&job)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update job: ", res.Error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var jobVo vo.MidJourneyJob
|
||||||
|
err := utils.CopyObject(job, &jobVo)
|
||||||
|
if err == nil {
|
||||||
|
if data.Progress < 100 {
|
||||||
|
image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL)
|
||||||
|
if err == nil {
|
||||||
|
jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送任务到前端
|
||||||
|
client := s.Clients.Get(task.SessionId)
|
||||||
|
if client != nil {
|
||||||
|
utils.ReplyChunkMessage(client, jobVo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if task.Src == types.TaskSrcChat { // 聊天任务
|
||||||
|
wsClient := s.ChatClients.Get(task.SessionId)
|
||||||
|
if data.Status == Finished {
|
||||||
|
if wsClient != nil && data.ReferenceId != "" {
|
||||||
|
content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
|
||||||
|
utils.ReplyMessage(wsClient, content)
|
||||||
|
}
|
||||||
|
// download image
|
||||||
|
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with download image: ", err)
|
||||||
|
if wsClient != nil && data.ReferenceId != "" {
|
||||||
|
content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
|
||||||
|
utils.ReplyMessage(wsClient, content)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := s.db.Begin()
|
||||||
|
data.Image.URL = imgURL
|
||||||
|
message := model.HistoryMessage{
|
||||||
|
UserId: uint(task.UserId),
|
||||||
|
ChatId: task.ChatId,
|
||||||
|
RoleId: uint(task.RoleId),
|
||||||
|
Type: types.MjMsg,
|
||||||
|
Icon: task.Icon,
|
||||||
|
Content: utils.JsonEncode(data),
|
||||||
|
Tokens: 0,
|
||||||
|
UseContext: false,
|
||||||
|
}
|
||||||
|
res = tx.Create(&message)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database: ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// save the job
|
||||||
|
job.UserId = task.UserId
|
||||||
|
job.Type = task.Type.String()
|
||||||
|
job.MessageId = data.MessageId
|
||||||
|
job.ReferenceId = data.ReferenceId
|
||||||
|
job.Prompt = data.Prompt
|
||||||
|
job.ImgURL = imgURL
|
||||||
|
job.Progress = data.Progress
|
||||||
|
job.Hash = data.Image.Hash
|
||||||
|
job.CreatedAt = time.Now()
|
||||||
|
res = tx.Create(&job)
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Error("error with update database: ", err)
|
||||||
|
tx.Rollback()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
if wsClient == nil { // 客户端断线,则丢弃
|
||||||
|
logger.Errorf("Client is offline: %+v", data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.Status == Finished {
|
||||||
|
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||||
|
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
|
||||||
|
// 本次绘画完毕,移除客户端
|
||||||
|
s.ChatClients.Delete(task.SessionId)
|
||||||
|
} else {
|
||||||
|
// 使用代理临时转发图片
|
||||||
|
if data.Image.URL != "" {
|
||||||
|
image, err := utils.DownloadImage(data.Image.URL, s.proxyURL)
|
||||||
|
if err == nil {
|
||||||
|
data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新用户剩余绘图次数
|
||||||
|
// TODO: 放大图片是否需要消耗绘图次数?
|
||||||
|
if data.Status == Finished {
|
||||||
|
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
|
||||||
|
// 解除任务锁定
|
||||||
|
s.redis.Del(context.Background(), RunningJobKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
34
api/service/mj/types.go
Normal file
34
api/service/mj/types.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package mj
|
||||||
|
|
||||||
|
const (
|
||||||
|
ApplicationID string = "936929561302675456"
|
||||||
|
SessionID string = "ea8816d857ba9ae2f74c59ae1a953afe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InteractionsRequest struct {
|
||||||
|
Type int `json:"type"`
|
||||||
|
ApplicationID string `json:"application_id"`
|
||||||
|
MessageFlags *int `json:"message_flags,omitempty"`
|
||||||
|
MessageID *string `json:"message_id,omitempty"`
|
||||||
|
GuildID string `json:"guild_id"`
|
||||||
|
ChannelID string `json:"channel_id"`
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
Data map[string]any `json:"data"`
|
||||||
|
Nonce string `json:"nonce,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type InteractionsResult struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string
|
||||||
|
Error map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
type CBReq struct {
|
||||||
|
MessageId string `json:"message_id"`
|
||||||
|
ReferenceId string `json:"reference_id"`
|
||||||
|
Image Image `json:"image"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Status TaskStatus `json:"status"`
|
||||||
|
Progress int `json:"progress"`
|
||||||
|
}
|
@ -1,166 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"chatplus/core/types"
|
|
||||||
logger2 "chatplus/logger"
|
|
||||||
"chatplus/store"
|
|
||||||
"chatplus/store/model"
|
|
||||||
"chatplus/utils"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"github.com/imroc/req/v3"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var logger = logger2.GetLogger()
|
|
||||||
|
|
||||||
// MJ 绘画服务
|
|
||||||
|
|
||||||
const MjRunningJobKey = "MidJourney_Running_Job"
|
|
||||||
|
|
||||||
type MjService struct {
|
|
||||||
config types.ChatPlusExtConfig
|
|
||||||
client *req.Client
|
|
||||||
taskQueue *store.RedisQueue
|
|
||||||
redis *redis.Client
|
|
||||||
db *gorm.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMjService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *MjService {
|
|
||||||
return &MjService{
|
|
||||||
config: appConfig.ExtConfig,
|
|
||||||
redis: client,
|
|
||||||
db: db,
|
|
||||||
taskQueue: store.NewRedisQueue("midjourney_task_queue", client),
|
|
||||||
client: req.C().SetTimeout(30 * time.Second)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *MjService) Run() {
|
|
||||||
logger.Info("Starting MidJourney job consumer.")
|
|
||||||
ctx := context.Background()
|
|
||||||
for {
|
|
||||||
_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
|
|
||||||
if err == nil { // 队列串行执行
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
var task types.MjTask
|
|
||||||
err = s.taskQueue.LPop(&task)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("taking task with error: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
logger.Infof("Consuming Task: %+v", task)
|
|
||||||
switch task.Type {
|
|
||||||
case types.TaskImage:
|
|
||||||
err = s.image(task.Prompt)
|
|
||||||
break
|
|
||||||
case types.TaskUpscale:
|
|
||||||
err = s.upscale(MjUpscaleReq{
|
|
||||||
Index: task.Index,
|
|
||||||
MessageId: task.MessageId,
|
|
||||||
MessageHash: task.MessageHash,
|
|
||||||
})
|
|
||||||
break
|
|
||||||
case types.TaskVariation:
|
|
||||||
err = s.variation(MjVariationReq{
|
|
||||||
Index: task.Index,
|
|
||||||
MessageId: task.MessageId,
|
|
||||||
MessageHash: task.MessageHash,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("绘画任务执行失败:", err)
|
|
||||||
if task.RetryCount <= 5 {
|
|
||||||
s.taskQueue.RPush(task)
|
|
||||||
}
|
|
||||||
task.RetryCount += 1
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新任务的执行状态
|
|
||||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
|
||||||
// 锁定任务执行通道,直到任务超时(5分钟)
|
|
||||||
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *MjService) PushTask(task types.MjTask) {
|
|
||||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
|
||||||
s.taskQueue.RPush(task)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *MjService) image(prompt string) error {
|
|
||||||
logger.Infof("MJ 绘画参数:%+v", prompt)
|
|
||||||
body := map[string]string{"prompt": prompt}
|
|
||||||
url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
|
|
||||||
var res types.BizVo
|
|
||||||
r, err := s.client.R().
|
|
||||||
SetHeader("Authorization", s.config.Token).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(body).
|
|
||||||
SetSuccessResult(&res).Post(url)
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("%v%v", r.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.Code != types.Success {
|
|
||||||
return errors.New(res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type MjUpscaleReq struct {
|
|
||||||
Index int32 `json:"index"`
|
|
||||||
MessageId string `json:"message_id"`
|
|
||||||
MessageHash string `json:"message_hash"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *MjService) upscale(upReq MjUpscaleReq) error {
|
|
||||||
url := fmt.Sprintf("%s/api/mj/upscale", s.config.ApiURL)
|
|
||||||
var res types.BizVo
|
|
||||||
r, err := s.client.R().
|
|
||||||
SetHeader("Authorization", s.config.Token).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(upReq).
|
|
||||||
SetSuccessResult(&res).Post(url)
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("%v%v", r.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.Code != types.Success {
|
|
||||||
return errors.New(res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type MjVariationReq struct {
|
|
||||||
Index int32 `json:"index"`
|
|
||||||
MessageId string `json:"message_id"`
|
|
||||||
MessageHash string `json:"message_hash"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *MjService) variation(upReq MjVariationReq) error {
|
|
||||||
url := fmt.Sprintf("%s/api/mj/variation", s.config.ApiURL)
|
|
||||||
var res types.BizVo
|
|
||||||
r, err := s.client.R().
|
|
||||||
SetHeader("Authorization", s.config.Token).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(upReq).
|
|
||||||
SetSuccessResult(&res).Post(url)
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("%v%v", r.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.Code != types.Success {
|
|
||||||
return errors.New(res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
169
api/service/sd/client.go
Normal file
169
api/service/sd/client.go
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
package sd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chatplus/core/types"
|
||||||
|
"chatplus/utils"
|
||||||
|
"fmt"
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
httpClient *req.Client
|
||||||
|
config *types.StableDiffusionConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSdClient(config *types.AppConfig) *Client {
|
||||||
|
return &Client{
|
||||||
|
config: &config.SdConfig,
|
||||||
|
httpClient: req.C(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Txt2Img(params types.SdTaskParams) error {
|
||||||
|
var data []interface{}
|
||||||
|
err := utils.JsonDecode(Text2ImgParamTemplate, &data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data[ParamKeys["task_id"]] = params.TaskId
|
||||||
|
data[ParamKeys["prompt"]] = params.Prompt
|
||||||
|
data[ParamKeys["negative_prompt"]] = params.NegativePrompt
|
||||||
|
data[ParamKeys["steps"]] = params.Steps
|
||||||
|
data[ParamKeys["sampler"]] = params.Sampler
|
||||||
|
data[ParamKeys["face_fix"]] = params.FaceFix
|
||||||
|
data[ParamKeys["cfg_scale"]] = params.CfgScale
|
||||||
|
data[ParamKeys["seed"]] = params.Seed
|
||||||
|
data[ParamKeys["height"]] = params.Height
|
||||||
|
data[ParamKeys["width"]] = params.Width
|
||||||
|
data[ParamKeys["hd_fix"]] = params.HdFix
|
||||||
|
data[ParamKeys["hd_redraw_rate"]] = params.HdRedrawRate
|
||||||
|
data[ParamKeys["hd_scale"]] = params.HdScale
|
||||||
|
data[ParamKeys["hd_scale_alg"]] = params.HdScaleAlg
|
||||||
|
data[ParamKeys["hd_sample_num"]] = params.HdSampleNum
|
||||||
|
task := TaskInfo{
|
||||||
|
TaskId: params.TaskId,
|
||||||
|
Data: data,
|
||||||
|
EventData: nil,
|
||||||
|
FnIndex: 494,
|
||||||
|
SessionHash: "ycaxgzm9ah",
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
c.runTask(task, c.httpClient)
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) runTask(taskInfo TaskInfo, client *req.Client) {
|
||||||
|
body := map[string]any{
|
||||||
|
"data": taskInfo.Data,
|
||||||
|
"event_data": taskInfo.EventData,
|
||||||
|
"fn_index": taskInfo.FnIndex,
|
||||||
|
"session_hash": taskInfo.SessionHash,
|
||||||
|
}
|
||||||
|
|
||||||
|
var result = make(chan CBReq)
|
||||||
|
go func() {
|
||||||
|
var res struct {
|
||||||
|
Data []interface{} `json:"data"`
|
||||||
|
IsGenerating bool `json:"is_generating"`
|
||||||
|
Duration float64 `json:"duration"`
|
||||||
|
AverageDuration float64 `json:"average_duration"`
|
||||||
|
}
|
||||||
|
var cbReq = CBReq{TaskId: taskInfo.TaskId}
|
||||||
|
response, err := client.R().SetBody(body).SetSuccessResult(&res).Post(c.config.ApiURL + "/run/predict")
|
||||||
|
if err != nil {
|
||||||
|
cbReq.Message = "error with send request: " + err.Error()
|
||||||
|
cbReq.Success = false
|
||||||
|
result <- cbReq
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.IsErrorState() {
|
||||||
|
bytes, _ := io.ReadAll(response.Body)
|
||||||
|
cbReq.Message = "error http status code: " + string(bytes)
|
||||||
|
cbReq.Success = false
|
||||||
|
result <- cbReq
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var images []struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Data interface{} `json:"data"`
|
||||||
|
IsFile bool `json:"is_file"`
|
||||||
|
}
|
||||||
|
err = utils.ForceCovert(res.Data[0], &images)
|
||||||
|
if err != nil {
|
||||||
|
cbReq.Message = "error with decode image:" + err.Error()
|
||||||
|
cbReq.Success = false
|
||||||
|
result <- cbReq
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var info map[string]any
|
||||||
|
err = utils.JsonDecode(utils.InterfaceToString(res.Data[1]), &info)
|
||||||
|
if err != nil {
|
||||||
|
cbReq.Message = err.Error()
|
||||||
|
cbReq.Success = false
|
||||||
|
result <- cbReq
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//for k, v := range info {
|
||||||
|
// fmt.Println(k, " => ", v)
|
||||||
|
//}
|
||||||
|
cbReq.ImageName = images[0].Name
|
||||||
|
cbReq.Seed = utils.InterfaceToString(info["seed"])
|
||||||
|
cbReq.Success = true
|
||||||
|
cbReq.Progress = 100
|
||||||
|
result <- cbReq
|
||||||
|
close(result)
|
||||||
|
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case value := <-result:
|
||||||
|
if value.Success {
|
||||||
|
logger.Infof("%s/file=%s", c.config.ApiURL, value.ImageName)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
var progressReq = map[string]any{
|
||||||
|
"id_task": taskInfo.TaskId,
|
||||||
|
"id_live_preview": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
var progressRes struct {
|
||||||
|
Active bool `json:"active"`
|
||||||
|
Queued bool `json:"queued"`
|
||||||
|
Completed bool `json:"completed"`
|
||||||
|
Progress float64 `json:"progress"`
|
||||||
|
Eta float64 `json:"eta"`
|
||||||
|
LivePreview string `json:"live_preview"`
|
||||||
|
IDLivePreview int `json:"id_live_preview"`
|
||||||
|
TextInfo interface{} `json:"textinfo"`
|
||||||
|
}
|
||||||
|
response, err := client.R().SetBody(progressReq).SetSuccessResult(&progressRes).Post(c.config.ApiURL + "/internal/progress")
|
||||||
|
var cbReq = CBReq{TaskId: taskInfo.TaskId, Success: true}
|
||||||
|
if err != nil { // TODO: 这里可以考虑设置失败重试次数
|
||||||
|
logger.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.IsErrorState() {
|
||||||
|
bytes, _ := io.ReadAll(response.Body)
|
||||||
|
logger.Error(string(bytes))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cbReq.ImageData = progressRes.LivePreview
|
||||||
|
cbReq.Progress = int(progressRes.Progress * 100)
|
||||||
|
fmt.Println("Progress: ", progressRes.Progress)
|
||||||
|
fmt.Println("Image: ", progressRes.LivePreview)
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,45 +1,42 @@
|
|||||||
package service
|
package sd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
|
"chatplus/service/mj"
|
||||||
"chatplus/store"
|
"chatplus/store"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/imroc/req/v3"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SD 绘画服务
|
// SD 绘画服务
|
||||||
|
|
||||||
const SdRunningJobKey = "StableDiffusion_Running_Job"
|
const RunningJobKey = "StableDiffusion_Running_Job"
|
||||||
|
|
||||||
type SdService struct {
|
type Service struct {
|
||||||
config types.ChatPlusExtConfig
|
|
||||||
client *req.Client
|
|
||||||
taskQueue *store.RedisQueue
|
taskQueue *store.RedisQueue
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
Client *Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSdService(appConfig *types.AppConfig, client *redis.Client, db *gorm.DB) *SdService {
|
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client) *Service {
|
||||||
return &SdService{
|
return &Service{
|
||||||
config: appConfig.ExtConfig,
|
redis: redisCli,
|
||||||
redis: client,
|
|
||||||
db: db,
|
db: db,
|
||||||
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", client),
|
Client: client,
|
||||||
client: req.C().SetTimeout(30 * time.Second)}
|
taskQueue: store.NewRedisQueue("stable_diffusion_task_queue", redisCli),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SdService) Run() {
|
func (s *Service) Run() {
|
||||||
logger.Info("Starting StableDiffusion job consumer.")
|
logger.Info("Starting StableDiffusion job consumer.")
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
for {
|
for {
|
||||||
_, err := s.redis.Get(ctx, SdRunningJobKey).Result()
|
_, err := s.redis.Get(ctx, RunningJobKey).Result()
|
||||||
if err == nil { // 队列串行执行
|
if err == nil { // 队列串行执行
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
continue
|
continue
|
||||||
@ -51,7 +48,7 @@ func (s *SdService) Run() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
logger.Infof("Consuming Task: %+v", task)
|
logger.Infof("Consuming Task: %+v", task)
|
||||||
err = s.txt2img(task.Params)
|
err = s.Client.Txt2Img(task.Params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("绘画任务执行失败:", err)
|
logger.Error("绘画任务执行失败:", err)
|
||||||
if task.RetryCount <= 5 {
|
if task.RetryCount <= 5 {
|
||||||
@ -65,31 +62,11 @@ func (s *SdService) Run() {
|
|||||||
// 更新任务的执行状态
|
// 更新任务的执行状态
|
||||||
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
|
||||||
// 锁定任务执行通道,直到任务超时(5分钟)
|
// 锁定任务执行通道,直到任务超时(5分钟)
|
||||||
s.redis.Set(ctx, MjRunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
s.redis.Set(ctx, mj.RunningJobKey, utils.JsonEncode(task), time.Minute*5)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SdService) PushTask(task types.SdTask) {
|
func (s *Service) PushTask(task types.SdTask) {
|
||||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
logger.Infof("add a new MidJourney Task: %+v", task)
|
||||||
s.taskQueue.RPush(task)
|
s.taskQueue.RPush(task)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SdService) txt2img(params types.SdParams) error {
|
|
||||||
logger.Infof("SD 绘画参数:%+v", params)
|
|
||||||
url := fmt.Sprintf("%s/api/mj/image", s.config.ApiURL)
|
|
||||||
var res types.BizVo
|
|
||||||
r, err := s.client.R().
|
|
||||||
SetHeader("Authorization", s.config.Token).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(params).
|
|
||||||
SetSuccessResult(&res).Post(url)
|
|
||||||
if err != nil || r.IsErrorState() {
|
|
||||||
return fmt.Errorf("%v%v", r.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.Code != types.Success {
|
|
||||||
return errors.New(res.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
234
api/service/sd/types.go
Normal file
234
api/service/sd/types.go
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
package sd
|
||||||
|
|
||||||
|
import logger2 "chatplus/logger"
|
||||||
|
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
type TaskInfo struct {
|
||||||
|
TaskId string `json:"task_id"`
|
||||||
|
Data interface{} `json:"data"`
|
||||||
|
EventData interface{} `json:"event_data"`
|
||||||
|
FnIndex int `json:"fn_index"`
|
||||||
|
SessionHash string `json:"session_hash"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CBReq struct {
|
||||||
|
TaskId string
|
||||||
|
ImageName string
|
||||||
|
ImageData string
|
||||||
|
Progress int
|
||||||
|
Seed string
|
||||||
|
Success bool
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
var ParamKeys = map[string]int{
|
||||||
|
"task_id": 0,
|
||||||
|
"prompt": 1,
|
||||||
|
"negative_prompt": 2,
|
||||||
|
"steps": 4,
|
||||||
|
"sampler": 5,
|
||||||
|
"face_fix": 6,
|
||||||
|
"cfg_scale": 10,
|
||||||
|
"seed": 11,
|
||||||
|
"height": 17,
|
||||||
|
"width": 18,
|
||||||
|
"hd_fix": 19,
|
||||||
|
"hd_redraw_rate": 20, //高清修复重绘幅度
|
||||||
|
"hd_scale": 21, // 高清修复放大倍数
|
||||||
|
"hd_scale_alg": 22, // 高清修复放大算法
|
||||||
|
"hd_sample_num": 23, // 高清修复采样次数
|
||||||
|
}
|
||||||
|
|
||||||
|
const Text2ImgParamTemplate = `[
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
[],
|
||||||
|
30,
|
||||||
|
"DPM++ SDE Karras",
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
7.5,
|
||||||
|
-1,
|
||||||
|
-1,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
false,
|
||||||
|
512,
|
||||||
|
512,
|
||||||
|
true,
|
||||||
|
0.7,
|
||||||
|
2,
|
||||||
|
"Latent",
|
||||||
|
10,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
"Use same sampler",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
[],
|
||||||
|
"None",
|
||||||
|
false,
|
||||||
|
"MultiDiffusion",
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
96,
|
||||||
|
96,
|
||||||
|
48,
|
||||||
|
4,
|
||||||
|
"None",
|
||||||
|
2,
|
||||||
|
false,
|
||||||
|
10,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
64,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
0.4,
|
||||||
|
0.4,
|
||||||
|
0.2,
|
||||||
|
0.2,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"Background",
|
||||||
|
0.2,
|
||||||
|
-1,
|
||||||
|
false,
|
||||||
|
0.4,
|
||||||
|
0.4,
|
||||||
|
0.2,
|
||||||
|
0.2,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"Background",
|
||||||
|
0.2,
|
||||||
|
-1,
|
||||||
|
false,
|
||||||
|
0.4,
|
||||||
|
0.4,
|
||||||
|
0.2,
|
||||||
|
0.2,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"Background",
|
||||||
|
0.2,
|
||||||
|
-1,
|
||||||
|
false,
|
||||||
|
0.4,
|
||||||
|
0.4,
|
||||||
|
0.2,
|
||||||
|
0.2,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"Background",
|
||||||
|
0.2,
|
||||||
|
-1,
|
||||||
|
false,
|
||||||
|
0.4,
|
||||||
|
0.4,
|
||||||
|
0.2,
|
||||||
|
0.2,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"Background",
|
||||||
|
0.2,
|
||||||
|
-1,
|
||||||
|
false,
|
||||||
|
0.4,
|
||||||
|
0.4,
|
||||||
|
0.2,
|
||||||
|
0.2,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"Background",
|
||||||
|
0.2,
|
||||||
|
-1,
|
||||||
|
false,
|
||||||
|
0.4,
|
||||||
|
0.4,
|
||||||
|
0.2,
|
||||||
|
0.2,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"Background",
|
||||||
|
0.2,
|
||||||
|
-1,
|
||||||
|
false,
|
||||||
|
0.4,
|
||||||
|
0.4,
|
||||||
|
0.2,
|
||||||
|
0.2,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"Background",
|
||||||
|
0.2,
|
||||||
|
-1,
|
||||||
|
false,
|
||||||
|
3072,
|
||||||
|
192,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
false,
|
||||||
|
"",
|
||||||
|
0.5,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
"",
|
||||||
|
"Lerp",
|
||||||
|
false,
|
||||||
|
"🔄",
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
"positive",
|
||||||
|
"comma",
|
||||||
|
0,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
"",
|
||||||
|
"Seed",
|
||||||
|
"",
|
||||||
|
[],
|
||||||
|
"Nothing",
|
||||||
|
"",
|
||||||
|
[],
|
||||||
|
"Nothing",
|
||||||
|
"",
|
||||||
|
[],
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
0,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
false,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
false,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
false,
|
||||||
|
50
|
||||||
|
]`
|
87
api/service/wx/bot.go
Normal file
87
api/service/wx/bot.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
package wx
|
||||||
|
|
||||||
|
import (
|
||||||
|
logger2 "chatplus/logger"
|
||||||
|
"chatplus/store/model"
|
||||||
|
"github.com/eatmoreapple/openwechat"
|
||||||
|
"github.com/skip2/go-qrcode"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 微信收款机器人
|
||||||
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
|
type Bot struct {
|
||||||
|
bot *openwechat.Bot
|
||||||
|
token string
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWeChatBot(db *gorm.DB) *Bot {
|
||||||
|
bot := openwechat.DefaultBot(openwechat.Desktop)
|
||||||
|
return &Bot{
|
||||||
|
bot: bot,
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bot) Run() error {
|
||||||
|
logger.Info("Starting WeChat Bot...")
|
||||||
|
|
||||||
|
// set message handler
|
||||||
|
b.bot.MessageHandler = func(msg *openwechat.Message) {
|
||||||
|
b.messageHandler(msg)
|
||||||
|
}
|
||||||
|
// scan code login callback
|
||||||
|
b.bot.UUIDCallback = b.qrCodeCallBack
|
||||||
|
|
||||||
|
err := b.bot.Login()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("微信登录成功!")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// message handler
|
||||||
|
func (b *Bot) messageHandler(msg *openwechat.Message) {
|
||||||
|
sender, err := msg.Sender()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只处理微信支付的推送消息
|
||||||
|
if sender.NickName == "微信支付" ||
|
||||||
|
msg.MsgType == openwechat.MsgTypeApp ||
|
||||||
|
msg.AppMsgType == openwechat.AppMsgTypeUrl {
|
||||||
|
// 解析支付金额
|
||||||
|
message, err := parseTransactionMessage(msg.Content)
|
||||||
|
if err == nil {
|
||||||
|
transaction := extractTransaction(message)
|
||||||
|
logger.Infof("解析到收款信息:%+v", transaction)
|
||||||
|
var item model.Reward
|
||||||
|
res := b.db.Where("tx_id = ?", transaction.TransId).First(&item)
|
||||||
|
if res.Error == nil {
|
||||||
|
logger.Error("当前交易 ID 己经存在!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res = b.db.Create(&model.Reward{
|
||||||
|
TxId: transaction.TransId,
|
||||||
|
Amount: transaction.Amount,
|
||||||
|
Remark: transaction.Remark,
|
||||||
|
Status: false,
|
||||||
|
})
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.Errorf("交易保存失败: %v", res.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bot) qrCodeCallBack(uuid string) {
|
||||||
|
logger.Info("请使用微信扫描下面二维码登录")
|
||||||
|
q, _ := qrcode.New("https://login.weixin.qq.com/l/"+uuid, qrcode.Medium)
|
||||||
|
logger.Info(q.ToString(true))
|
||||||
|
}
|
68
api/service/wx/tranaction.go
Normal file
68
api/service/wx/tranaction.go
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
package wx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message 转账消息
|
||||||
|
type Message struct {
|
||||||
|
XMLName xml.Name `xml:"msg"`
|
||||||
|
AppMsg struct {
|
||||||
|
Des string `xml:"des"`
|
||||||
|
Url string `xml:"url"`
|
||||||
|
} `xml:"appmsg"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transaction 解析后的交易信息
|
||||||
|
type Transaction struct {
|
||||||
|
TransId string `json:"trans_id"` // 微信转账交易 ID
|
||||||
|
Amount float64 `json:"amount"` // 微信转账交易金额
|
||||||
|
Remark string `json:"remark"` // 转账备注
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析微信转账消息
|
||||||
|
func parseTransactionMessage(xmlData string) (*Message, error) {
|
||||||
|
var msg Message
|
||||||
|
if err := xml.Unmarshal([]byte(xmlData), &msg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 导出交易信息
|
||||||
|
func extractTransaction(message *Message) Transaction {
|
||||||
|
var tx = Transaction{}
|
||||||
|
// 导出交易金额和备注
|
||||||
|
lines := strings.Split(message.AppMsg.Des, "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 解析收款金额
|
||||||
|
prefix := "收款金额¥"
|
||||||
|
if strings.HasPrefix(line, prefix) {
|
||||||
|
if value, err := strconv.ParseFloat(line[len(prefix):], 64); err == nil {
|
||||||
|
tx.Amount = value
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 解析收款备注
|
||||||
|
prefix = "付款方备注"
|
||||||
|
if strings.HasPrefix(line, prefix) {
|
||||||
|
tx.Remark = line[len(prefix):]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析交易 ID
|
||||||
|
index := strings.Index(message.AppMsg.Url, "trans_id=")
|
||||||
|
if index != -1 {
|
||||||
|
end := strings.LastIndex(message.AppMsg.Url, "&")
|
||||||
|
tx.TransId = strings.TrimSpace(message.AppMsg.Url[index+9 : end])
|
||||||
|
}
|
||||||
|
return tx
|
||||||
|
}
|
@ -11,7 +11,7 @@ type SdJob struct {
|
|||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
TaskId string `json:"task_id"`
|
TaskId string `json:"task_id"`
|
||||||
ImgURL string `json:"img_url"`
|
ImgURL string `json:"img_url"`
|
||||||
Params types.SdParams `json:"params"`
|
Params types.SdTaskParams `json:"params"`
|
||||||
Progress int `json:"progress"`
|
Progress int `json:"progress"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
@ -138,3 +138,15 @@ func IntValue(str string, defaultValue int) int {
|
|||||||
}
|
}
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ForceCovert(src any, dst interface{}) error {
|
||||||
|
bytes, err := json.Marshal(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(bytes, dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user