mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-09 02:33:42 +08:00
feat: chat chrawing function is refactored
This commit is contained in:
@@ -48,6 +48,7 @@ func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
|
||||
f.service.PushTask(service.MjTask{
|
||||
Id: utils.InterfaceToString(params["session_id"]),
|
||||
Src: service.TaskSrcChat,
|
||||
Type: service.Image,
|
||||
Prompt: prompt,
|
||||
UserId: utils.IntValue(utils.InterfaceToString(params["user_id"]), 0),
|
||||
RoleId: utils.IntValue(utils.InterfaceToString(params["role_id"]), 0),
|
||||
|
||||
@@ -3,6 +3,7 @@ package function
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/service"
|
||||
)
|
||||
|
||||
type Function interface {
|
||||
@@ -28,11 +29,11 @@ type dataItem struct {
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
|
||||
func NewFunctions(config *types.AppConfig) map[string]Function {
|
||||
func NewFunctions(config *types.AppConfig, mjService *service.MjService) map[string]Function {
|
||||
return map[string]Function{
|
||||
types.FuncZaoBao: NewZaoBao(config.ApiConfig),
|
||||
types.FuncWeibo: NewWeiboHot(config.ApiConfig),
|
||||
types.FuncHeadLine: NewHeadLines(config.ApiConfig),
|
||||
types.FuncMidJourney: NewMidJourneyFunc(config.ExtConfig),
|
||||
types.FuncMidJourney: NewMidJourneyFunc(mjService),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,19 +56,21 @@ type MjService struct {
|
||||
redis *redis.Client
|
||||
}
|
||||
|
||||
func NewMjService(config types.ChatPlusExtConfig, client *redis.Client) *MjService {
|
||||
func NewMjService(appConfig *types.AppConfig, client *redis.Client) *MjService {
|
||||
return &MjService{
|
||||
config: config,
|
||||
config: appConfig.ExtConfig,
|
||||
redis: client,
|
||||
taskQueue: store.NewRedisQueue("midjourney_task_queue", client),
|
||||
client: req.C().SetTimeout(30 * time.Second)}
|
||||
}
|
||||
|
||||
func (s *MjService) Run() {
|
||||
logger.Info("Starting MidJourney job consumer.")
|
||||
ctx := context.Background()
|
||||
for {
|
||||
_, err := s.redis.Get(ctx, MjRunningJobKey).Result()
|
||||
if err == nil { // a task is running, waiting for finish
|
||||
t, err := s.redis.Get(ctx, MjRunningJobKey).Result()
|
||||
if err == nil {
|
||||
logger.Infof("An task is not finished: %s", t)
|
||||
time.Sleep(time.Second * 3)
|
||||
continue
|
||||
}
|
||||
@@ -78,7 +80,7 @@ func (s *MjService) Run() {
|
||||
logger.Errorf("taking task with error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Infof("Consuming Task: %+v", task)
|
||||
switch task.Type {
|
||||
case Image:
|
||||
err = s.image(task.Prompt)
|
||||
@@ -98,11 +100,11 @@ func (s *MjService) Run() {
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
logger.Error("绘画任务执行失败:", err)
|
||||
if task.RetryCount > 5 {
|
||||
continue
|
||||
}
|
||||
task.RetryCount += 1
|
||||
time.Sleep(time.Second)
|
||||
s.taskQueue.RPush(task)
|
||||
// TODO: 执行失败通知聊天客户端
|
||||
continue
|
||||
@@ -114,6 +116,7 @@ func (s *MjService) Run() {
|
||||
}
|
||||
|
||||
func (s *MjService) PushTask(task MjTask) {
|
||||
logger.Infof("add a new MidJourney Task: %+v", task)
|
||||
s.taskQueue.RPush(task)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user