feat: add midjourney message receive handler

This commit is contained in:
RockYang
2023-08-14 07:09:52 +08:00
parent 2165ba3406
commit 373370fde5
8 changed files with 54 additions and 31 deletions

View File

@@ -5,6 +5,7 @@ import (
"bytes"
"chatplus/core"
"chatplus/core/types"
"chatplus/store"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
@@ -26,14 +27,16 @@ import (
)
const ErrorMsg = "抱歉AI 助手开小差了,请稍后再试。"
const TaskStorePrefix = "/tasks/"
type ChatHandler struct {
BaseHandler
db *gorm.DB
db *gorm.DB
leveldb *store.LevelDB
}
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
handler := ChatHandler{db: db}
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB) *ChatHandler {
handler := ChatHandler{db: db, leveldb: levelDB}
handler.App = app
return &handler
}
@@ -133,7 +136,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
}
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws types.Client) error {
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var user model.User
@@ -340,13 +343,18 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if functionName == types.FuncMidJourney {
key := utils.Sha256(data)
// add task for MidJourney
h.App.MjTasks.Put(key, types.MjTask{
h.App.MjTaskClients.Put(key, ws)
task := types.MjTask{
UserId: userVo.Id,
RoleId: role.Id,
Icon: role.Icon,
Client: ws,
ChatId: session.ChatId,
})
}
err := h.leveldb.Put(TaskStorePrefix+key, task)
if err != nil {
logger.Error("error with store MidJourney task: ", err)
}
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
}

View File

@@ -24,6 +24,7 @@ type Image struct {
Width int `json:"width"`
Height int `json:"height"`
Size int `json:"size"`
Hash string `json:"hash"`
}
type MidJourneyHandler struct {
@@ -44,18 +45,30 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
}
var data struct {
Image Image `json:"image"`
Content string `json:"content"`
Status TaskStatus `json:"status"`
Type string `json:"type"`
MessageId string `json:"message_id"`
Image Image `json:"image"`
Content string `json:"content"`
Prompt string `json:"prompt"`
Status TaskStatus `json:"status"`
Key string `json:"key"`
}
if err := c.ShouldBindJSON(&data); err != nil {
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
key := utils.Sha256(data.Prompt)
data.Key = key
// TODO: 如果绘画任务完成了则将该消息保存到当前会话的聊天历史记录
sessionId := "u7blnft9zqisyrwidjb22j6b78iqc30lv9jtud3k9o"
wsClient := h.App.ChatClients.Get(sessionId)
utils.ReplyMessage(wsClient, "![](https://cdn.discordapp.com/attachments/1138713254718361633/1139482452579070053/lal603743923_A_Chinese_girl_walking_barefoot_on_the_beach_weari_df8b6dc0-3b13-478c-8dbb-983015d21661.png)")
logger.Infof("Data: %+v", data)
wsClient := h.App.MjTaskClients.Get(key)
if wsClient == nil { // 客户端断线,则丢弃
resp.SUCCESS(c)
return
}
// 推送消息到客户端
// TODO: 增加绘画消息类型
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsImg, Content: data})
resp.ERROR(c, "Error with CallBack")
}

View File

@@ -14,13 +14,13 @@ const CodeStorePrefix = "/verify/codes/"
type SmsHandler struct {
BaseHandler
db *store.LevelDB
leveldb *store.LevelDB
sms *service.AliYunSmsService
captcha *service.CaptchaService
}
func NewSmsHandler(app *core.AppServer, db *store.LevelDB, sms *service.AliYunSmsService, captcha *service.CaptchaService) *SmsHandler {
handler := &SmsHandler{db: db, sms: sms, captcha: captcha}
handler := &SmsHandler{leveldb: db, sms: sms, captcha: captcha}
handler.App = app
return handler
}
@@ -50,7 +50,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
}
// 存储验证码,等待后面注册验证
err = h.db.Put(CodeStorePrefix+data.Mobile, code)
err = h.leveldb.Put(CodeStorePrefix+data.Mobile, code)
if err != nil {
resp.ERROR(c, "验证码保存失败")
return

View File

@@ -22,11 +22,11 @@ type UserHandler struct {
BaseHandler
db *gorm.DB
searcher *xdb.Searcher
levelDB *store.LevelDB
leveldb *store.LevelDB
}
func NewUserHandler(app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher, levelDB *store.LevelDB) *UserHandler {
handler := &UserHandler{db: db, searcher: searcher, levelDB: levelDB}
handler := &UserHandler{db: db, searcher: searcher, leveldb: levelDB}
handler.App = app
return handler
}
@@ -60,7 +60,7 @@ func (h *UserHandler) Register(c *gin.Context) {
key := CodeStorePrefix + data.Mobile
if h.App.SysConfig.EnabledMsgService {
var code int
err := h.levelDB.Get(key, &code)
err := h.leveldb.Get(key, &code)
if err != nil || code != data.Code {
logger.Info(code)
resp.ERROR(c, "短信验证码错误")
@@ -118,7 +118,7 @@ func (h *UserHandler) Register(c *gin.Context) {
}
if h.App.SysConfig.EnabledMsgService {
_ = h.levelDB.Delete(key) // 注册成功,删除短信验证码
_ = h.leveldb.Delete(key) // 注册成功,删除短信验证码
}
resp.SUCCESS(c, user)
}
@@ -366,7 +366,7 @@ func (h *UserHandler) BindMobile(c *gin.Context) {
// 检查验证码
key := CodeStorePrefix + data.Mobile
var code int
err := h.levelDB.Get(key, &code)
err := h.leveldb.Get(key, &code)
if err != nil || code != data.Code {
resp.ERROR(c, "短信验证码错误")
return
@@ -384,6 +384,6 @@ func (h *UserHandler) BindMobile(c *gin.Context) {
return
}
_ = h.levelDB.Delete(key) // 删除短信验证码
_ = h.leveldb.Delete(key) // 删除短信验证码
resp.SUCCESS(c)
}