mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 17:26:38 +08:00
feat: add midjourney message receive handler
This commit is contained in:
parent
2165ba3406
commit
373370fde5
@ -34,7 +34,7 @@ type AppServer struct {
|
|||||||
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
|
ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合
|
||||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||||
Functions map[string]function.Function
|
Functions map[string]function.Function
|
||||||
MjTasks *types.LMap[string, types.MjTask]
|
MjTaskClients *types.LMap[string, *types.WsClient]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer {
|
func NewServer(appConfig *types.AppConfig, functions map[string]function.Function) *AppServer {
|
||||||
@ -48,7 +48,7 @@ func NewServer(appConfig *types.AppConfig, functions map[string]function.Functio
|
|||||||
ChatSession: types.NewLMap[string, types.ChatSession](),
|
ChatSession: types.NewLMap[string, types.ChatSession](),
|
||||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||||
MjTasks: types.NewLMap[string, types.MjTask](),
|
MjTaskClients: types.NewLMap[string, *types.WsClient](),
|
||||||
Functions: functions,
|
Functions: functions,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,8 +12,8 @@ type BizVo struct {
|
|||||||
|
|
||||||
// WsMessage Websocket message
|
// WsMessage Websocket message
|
||||||
type WsMessage struct {
|
type WsMessage struct {
|
||||||
Type WsMsgType `json:"type"` // 消息类别,start, end
|
Type WsMsgType `json:"type"` // 消息类别,start, end, img
|
||||||
Content string `json:"content"`
|
Content interface{} `json:"content"`
|
||||||
}
|
}
|
||||||
type WsMsgType string
|
type WsMsgType string
|
||||||
|
|
||||||
@ -21,6 +21,7 @@ const (
|
|||||||
WsStart = WsMsgType("start")
|
WsStart = WsMsgType("start")
|
||||||
WsMiddle = WsMsgType("middle")
|
WsMiddle = WsMsgType("middle")
|
||||||
WsEnd = WsMsgType("end")
|
WsEnd = WsMsgType("end")
|
||||||
|
WsImg = WsMsgType("img")
|
||||||
)
|
)
|
||||||
|
|
||||||
type BizCode int
|
type BizCode int
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"chatplus/core"
|
"chatplus/core"
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
|
"chatplus/store"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
@ -26,14 +27,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
|
const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
|
||||||
|
const TaskStorePrefix = "/tasks/"
|
||||||
|
|
||||||
type ChatHandler struct {
|
type ChatHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
leveldb *store.LevelDB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB) *ChatHandler {
|
||||||
handler := ChatHandler{db: db}
|
handler := ChatHandler{db: db, leveldb: levelDB}
|
||||||
handler.App = app
|
handler.App = app
|
||||||
return &handler
|
return &handler
|
||||||
}
|
}
|
||||||
@ -133,7 +136,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
|
// 将消息发送给 ChatGPT 并获取结果,通过 WebSocket 推送到客户端
|
||||||
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws types.Client) error {
|
func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
|
||||||
promptCreatedAt := time.Now() // 记录提问时间
|
promptCreatedAt := time.Now() // 记录提问时间
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
@ -340,13 +343,18 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
|||||||
if functionName == types.FuncMidJourney {
|
if functionName == types.FuncMidJourney {
|
||||||
key := utils.Sha256(data)
|
key := utils.Sha256(data)
|
||||||
// add task for MidJourney
|
// add task for MidJourney
|
||||||
h.App.MjTasks.Put(key, types.MjTask{
|
h.App.MjTaskClients.Put(key, ws)
|
||||||
|
task := types.MjTask{
|
||||||
UserId: userVo.Id,
|
UserId: userVo.Id,
|
||||||
RoleId: role.Id,
|
RoleId: role.Id,
|
||||||
Icon: role.Icon,
|
Icon: role.Icon,
|
||||||
Client: ws,
|
Client: ws,
|
||||||
ChatId: session.ChatId,
|
ChatId: session.ChatId,
|
||||||
})
|
}
|
||||||
|
err := h.leveldb.Put(TaskStorePrefix+key, task)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with store MidJourney task: ", err)
|
||||||
|
}
|
||||||
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
|
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ type Image struct {
|
|||||||
Width int `json:"width"`
|
Width int `json:"width"`
|
||||||
Height int `json:"height"`
|
Height int `json:"height"`
|
||||||
Size int `json:"size"`
|
Size int `json:"size"`
|
||||||
|
Hash string `json:"hash"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MidJourneyHandler struct {
|
type MidJourneyHandler struct {
|
||||||
@ -44,18 +45,30 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var data struct {
|
var data struct {
|
||||||
Image Image `json:"image"`
|
Type string `json:"type"`
|
||||||
Content string `json:"content"`
|
MessageId string `json:"message_id"`
|
||||||
Status TaskStatus `json:"status"`
|
Image Image `json:"image"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Status TaskStatus `json:"status"`
|
||||||
|
Key string `json:"key"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
|
||||||
resp.ERROR(c, types.InvalidArgs)
|
resp.ERROR(c, types.InvalidArgs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
key := utils.Sha256(data.Prompt)
|
||||||
|
data.Key = key
|
||||||
|
// TODO: 如果绘画任务完成了则将该消息保存到当前会话的聊天历史记录
|
||||||
|
|
||||||
sessionId := "u7blnft9zqisyrwidjb22j6b78iqc30lv9jtud3k9o"
|
wsClient := h.App.MjTaskClients.Get(key)
|
||||||
wsClient := h.App.ChatClients.Get(sessionId)
|
if wsClient == nil { // 客户端断线,则丢弃
|
||||||
utils.ReplyMessage(wsClient, "")
|
resp.SUCCESS(c)
|
||||||
logger.Infof("Data: %+v", data)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送消息到客户端
|
||||||
|
// TODO: 增加绘画消息类型
|
||||||
|
utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsImg, Content: data})
|
||||||
resp.ERROR(c, "Error with CallBack")
|
resp.ERROR(c, "Error with CallBack")
|
||||||
}
|
}
|
||||||
|
@ -14,13 +14,13 @@ const CodeStorePrefix = "/verify/codes/"
|
|||||||
|
|
||||||
type SmsHandler struct {
|
type SmsHandler struct {
|
||||||
BaseHandler
|
BaseHandler
|
||||||
db *store.LevelDB
|
leveldb *store.LevelDB
|
||||||
sms *service.AliYunSmsService
|
sms *service.AliYunSmsService
|
||||||
captcha *service.CaptchaService
|
captcha *service.CaptchaService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSmsHandler(app *core.AppServer, db *store.LevelDB, sms *service.AliYunSmsService, captcha *service.CaptchaService) *SmsHandler {
|
func NewSmsHandler(app *core.AppServer, db *store.LevelDB, sms *service.AliYunSmsService, captcha *service.CaptchaService) *SmsHandler {
|
||||||
handler := &SmsHandler{db: db, sms: sms, captcha: captcha}
|
handler := &SmsHandler{leveldb: db, sms: sms, captcha: captcha}
|
||||||
handler.App = app
|
handler.App = app
|
||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
@ -50,7 +50,7 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 存储验证码,等待后面注册验证
|
// 存储验证码,等待后面注册验证
|
||||||
err = h.db.Put(CodeStorePrefix+data.Mobile, code)
|
err = h.leveldb.Put(CodeStorePrefix+data.Mobile, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
resp.ERROR(c, "验证码保存失败")
|
resp.ERROR(c, "验证码保存失败")
|
||||||
return
|
return
|
||||||
|
@ -22,11 +22,11 @@ type UserHandler struct {
|
|||||||
BaseHandler
|
BaseHandler
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
searcher *xdb.Searcher
|
searcher *xdb.Searcher
|
||||||
levelDB *store.LevelDB
|
leveldb *store.LevelDB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserHandler(app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher, levelDB *store.LevelDB) *UserHandler {
|
func NewUserHandler(app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher, levelDB *store.LevelDB) *UserHandler {
|
||||||
handler := &UserHandler{db: db, searcher: searcher, levelDB: levelDB}
|
handler := &UserHandler{db: db, searcher: searcher, leveldb: levelDB}
|
||||||
handler.App = app
|
handler.App = app
|
||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
@ -60,7 +60,7 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
key := CodeStorePrefix + data.Mobile
|
key := CodeStorePrefix + data.Mobile
|
||||||
if h.App.SysConfig.EnabledMsgService {
|
if h.App.SysConfig.EnabledMsgService {
|
||||||
var code int
|
var code int
|
||||||
err := h.levelDB.Get(key, &code)
|
err := h.leveldb.Get(key, &code)
|
||||||
if err != nil || code != data.Code {
|
if err != nil || code != data.Code {
|
||||||
logger.Info(code)
|
logger.Info(code)
|
||||||
resp.ERROR(c, "短信验证码错误")
|
resp.ERROR(c, "短信验证码错误")
|
||||||
@ -118,7 +118,7 @@ func (h *UserHandler) Register(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if h.App.SysConfig.EnabledMsgService {
|
if h.App.SysConfig.EnabledMsgService {
|
||||||
_ = h.levelDB.Delete(key) // 注册成功,删除短信验证码
|
_ = h.leveldb.Delete(key) // 注册成功,删除短信验证码
|
||||||
}
|
}
|
||||||
resp.SUCCESS(c, user)
|
resp.SUCCESS(c, user)
|
||||||
}
|
}
|
||||||
@ -366,7 +366,7 @@ func (h *UserHandler) BindMobile(c *gin.Context) {
|
|||||||
// 检查验证码
|
// 检查验证码
|
||||||
key := CodeStorePrefix + data.Mobile
|
key := CodeStorePrefix + data.Mobile
|
||||||
var code int
|
var code int
|
||||||
err := h.levelDB.Get(key, &code)
|
err := h.leveldb.Get(key, &code)
|
||||||
if err != nil || code != data.Code {
|
if err != nil || code != data.Code {
|
||||||
resp.ERROR(c, "短信验证码错误")
|
resp.ERROR(c, "短信验证码错误")
|
||||||
return
|
return
|
||||||
@ -384,6 +384,6 @@ func (h *UserHandler) BindMobile(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = h.levelDB.Delete(key) // 删除短信验证码
|
_ = h.leveldb.Delete(key) // 删除短信验证码
|
||||||
resp.SUCCESS(c)
|
resp.SUCCESS(c)
|
||||||
}
|
}
|
||||||
|
@ -36,12 +36,13 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
|
|||||||
delete(params, "ar")
|
delete(params, "ar")
|
||||||
}
|
}
|
||||||
prompt = prompt + " --niji 5"
|
prompt = prompt + " --niji 5"
|
||||||
|
url := fmt.Sprintf("%s/api/mj/image", f.config.ApiURL)
|
||||||
var res types.BizVo
|
var res types.BizVo
|
||||||
r, err := f.client.R().
|
r, err := f.client.R().
|
||||||
SetHeader("Authorization", f.config.Token).
|
SetHeader("Authorization", f.config.Token).
|
||||||
SetHeader("Content-Type", "application/json").
|
SetHeader("Content-Type", "application/json").
|
||||||
SetBody(params).
|
SetBody(params).
|
||||||
SetSuccessResult(&res).Post(f.config.ApiURL)
|
SetSuccessResult(&res).Post(url)
|
||||||
if err != nil || r.IsErrorState() {
|
if err != nil || r.IsErrorState() {
|
||||||
return "", fmt.Errorf("%v%v", r.String(), err)
|
return "", fmt.Errorf("%v%v", r.String(), err)
|
||||||
}
|
}
|
||||||
|
@ -9,20 +9,20 @@ import (
|
|||||||
var logger = logger2.GetLogger()
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
// ReplyChunkMessage 回复客户片段端消息
|
// ReplyChunkMessage 回复客户片段端消息
|
||||||
func ReplyChunkMessage(client types.Client, message types.WsMessage) {
|
func ReplyChunkMessage(client *types.WsClient, message types.WsMessage) {
|
||||||
msg, err := json.Marshal(message)
|
msg, err := json.Marshal(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error for decoding json data: %v", err.Error())
|
logger.Errorf("Error for decoding json data: %v", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = client.(*types.WsClient).Send(msg)
|
err = client.Send(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error for reply message: %v", err.Error())
|
logger.Errorf("Error for reply message: %v", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplyMessage 回复客户端一条完整的消息
|
// ReplyMessage 回复客户端一条完整的消息
|
||||||
func ReplyMessage(ws types.Client, message string) {
|
func ReplyMessage(ws *types.WsClient, message interface{}) {
|
||||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
|
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message})
|
||||||
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
|
ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd})
|
||||||
|
Loading…
Reference in New Issue
Block a user