From 09f44e6d9bc63b334cfc835178029007cefc5c1a Mon Sep 17 00:00:00 2001 From: RockYang Date: Mon, 22 Jul 2024 17:54:09 +0800 Subject: [PATCH] optimize foot copyright snaps --- CHANGELOG.md | 3 +- api/core/types/task.go | 13 +++++ api/handler/admin/chat_model_handler.go | 35 +++++++----- api/handler/admin/user_handler.go | 9 ++- api/handler/chat_model_handler.go | 7 ++- api/handler/suno_handler.go | 63 +++++++++++++++++++++ api/service/mj/service.go | 7 +++ api/service/suno/service.go | 66 ++++++++++++++++++++++ api/store/vo/chat_model.go | 2 +- web/src/components/FooterBar.vue | 17 +++--- web/src/components/ui/BlackInput.vue | 75 ++++++++++++++++--------- web/src/views/Suno.vue | 15 ++++- web/src/views/admin/ApiKey.vue | 16 ++++-- web/src/views/admin/ChatModel.vue | 7 +-- 14 files changed, 266 insertions(+), 69 deletions(-) create mode 100644 api/handler/suno_handler.go create mode 100644 api/service/suno/service.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 0af817fc..abcec219 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,8 @@ * 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用 * 功能新增:MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息 * 功能新增:允许在管理后台设置首页显示的导航菜单 -* 功能新增:增加 Suno 文生音乐页面功能 +* 功能新增:增加 Suno 文生歌曲功能 +* 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加 ## v4.1.0 * bug修复:修复移动端修改聊天标题不生效的问题 diff --git a/api/core/types/task.go b/api/core/types/task.go index 129a83b4..552b7666 100644 --- a/api/core/types/task.go +++ b/api/core/types/task.go @@ -78,3 +78,16 @@ type DallTask struct { Power int `json:"power"` } + +type SunoTask struct { + Id int `json:"id"` + UserId string `json:"user_id"` + Type int `json:"type"` + TaskId string `json:"task_id"` + Title string `json:"title"` + ReferenceId string `json:"reference_id"` + Prompt string `json:"prompt"` + Tags string `json:"tags"` + Instrumental bool `json:"instrumental"` // 是否纯音乐 + ExtendSecs int `json:"extend_secs"` // 延长秒杀 +} diff --git a/api/handler/admin/chat_model_handler.go b/api/handler/admin/chat_model_handler.go index 3a1e6328..f4a32444 100644 --- a/api/handler/admin/chat_model_handler.go +++ b/api/handler/admin/chat_model_handler.go @@ -49,28 +49,33 @@ func (h *ChatModelHandler) Save(c *gin.Context) { return } - item := model.ChatModel{ - Platform: data.Platform, - Name: data.Name, - Value: data.Value, - Enabled: data.Enabled, - Open: data.Open, - MaxTokens: data.MaxTokens, - MaxContext: data.MaxContext, - Temperature: data.Temperature, - KeyId: data.KeyId, - Power: data.Power} + item := model.ChatModel{} + // 更新 + if data.Id > 0 { + h.DB.Where("id", data.Id).First(&item) + } + + item.Name = data.Name + item.Value = data.Value + item.Enabled = data.Enabled + item.SortNum = data.SortNum + item.Open = data.Open + item.Platform = data.Platform + item.Power = data.Power + item.MaxTokens = data.MaxTokens + item.MaxContext = data.MaxContext + item.Temperature = data.Temperature + item.KeyId = data.KeyId + var res *gorm.DB if data.Id > 0 { - item.Id = data.Id - item.SortNum = data.SortNum - res = h.DB.Select("*").Omit("created_at").Updates(&item) + res = h.DB.Updates(&item) } else { res = h.DB.Create(&item) } if res.Error != nil { logger.Error("error with update database:", res.Error) - resp.ERROR(c, "更新数据库失败!") + resp.ERROR(c, res.Error.Error()) return } diff --git a/api/handler/admin/user_handler.go b/api/handler/admin/user_handler.go index 34747d77..c2386efc 100644 --- a/api/handler/admin/user_handler.go +++ b/api/handler/admin/user_handler.go @@ -112,7 +112,7 @@ func (h *UserHandler) Save(c *gin.Context) { res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user) if res.Error != nil { logger.Error("error with update database:", res.Error) - resp.ERROR(c, "更新数据库失败!") + resp.ERROR(c, res.Error.Error()) return } // 记录算力日志 @@ -136,6 +136,13 @@ func (h *UserHandler) Save(c *gin.Context) { }) } } else { + // 检查用户是否已经存在 + h.DB.Where("username", data.Username).First(&user) + if user.Id > 0 { + resp.ERROR(c, "用户名已存在") + return + } + salt := utils.RandString(8) u := model.User{ Username: data.Username, diff --git a/api/handler/chat_model_handler.go b/api/handler/chat_model_handler.go index 555de7c8..1b74f348 100644 --- a/api/handler/chat_model_handler.go +++ b/api/handler/chat_model_handler.go @@ -31,9 +31,14 @@ func (h *ChatModelHandler) List(c *gin.Context) { var items []model.ChatModel var chatModels = make([]vo.ChatModel, 0) var res *gorm.DB + session := h.DB.Session(&gorm.Session{}).Where("enabled", true) + t := c.Query("type") + if t != "" { + session = session.Where("type", t) + } // 如果用户没有登录,则加载所有开放模型 if !h.IsLogin(c) { - res = h.DB.Where("enabled", true).Where("open", true).Order("sort_num ASC").Find(&items) + res = session.Where("open", true).Order("sort_num ASC").Find(&items) } else { user, _ := h.GetLoginUser(c) var models []int diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go new file mode 100644 index 00000000..5419ddea --- /dev/null +++ b/api/handler/suno_handler.go @@ -0,0 +1,63 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type SunoHandler struct { + BaseHandler +} + +func NewSunoHandler(app *core.AppServer, db *gorm.DB) *SunoHandler { + return &SunoHandler{ + BaseHandler: BaseHandler{ + App: app, + DB: db, + }, + } +} + +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *SunoHandler) Client(c *gin.Context) { + //ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + //if err != nil { + // logger.Error(err) + // c.Abort() + // return + //} + // + //userId := h.GetInt(c, "user_id", 0) + //if userId == 0 { + // logger.Info("Invalid user ID") + // c.Abort() + // return + //} + // + ////client := types.NewWsClient(ws) + //logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) +} + +func (h *SunoHandler) Create(c *gin.Context) { + +} + +func (h *SunoHandler) List(c *gin.Context) { + +} + +func (h *SunoHandler) Remove(c *gin.Context) { + +} + +func (h *SunoHandler) Publish(c *gin.Context) { + +} diff --git a/api/service/mj/service.go b/api/service/mj/service.go index cba2b396..2218c851 100644 --- a/api/service/mj/service.go +++ b/api/service/mj/service.go @@ -29,6 +29,7 @@ type Service struct { notifyQueue *store.RedisQueue db *gorm.DB running bool + retryCount map[uint]int } func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service { @@ -39,6 +40,7 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red notifyQueue: notifyQueue, Client: cli, running: true, + retryCount: make(map[uint]int), } } @@ -57,8 +59,13 @@ func (s *Service) Run() { // 如果配置了多个中转平台的 API KEY // U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表 if task.ChannelId != "" && task.ChannelId != s.Name { + if s.retryCount[task.Id] > 5 { + s.db.Model(model.MidJourneyJob{Id: task.Id}).Delete(&model.MidJourneyJob{}) + continue + } logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId) s.taskQueue.RPush(task) + s.retryCount[task.Id]++ time.Sleep(time.Second) continue } diff --git a/api/service/suno/service.go b/api/service/suno/service.go new file mode 100644 index 00000000..8f28302a --- /dev/null +++ b/api/service/suno/service.go @@ -0,0 +1,66 @@ +package dalle + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core/types" + logger2 "geekai/logger" + "geekai/service/oss" + "geekai/store" + "github.com/go-redis/redis/v8" + "time" + + "github.com/imroc/req/v3" + "gorm.io/gorm" +) + +var logger = logger2.GetLogger() + +type Service struct { + httpClient *req.Client + db *gorm.DB + uploadManager *oss.UploaderManager + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + Clients *types.LMap[uint, *types.WsClient] // UserId => Client +} + +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { + return &Service{ + httpClient: req.C().SetTimeout(time.Minute * 3), + db: db, + taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli), + notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli), + Clients: types.NewLMap[uint, *types.WsClient](), + uploadManager: manager, + } +} + +func (s *Service) PushTask(task types.SunoTask) { + logger.Infof("add a new Suno task to the task list: %+v", task) + s.taskQueue.RPush(task) +} + +func (s *Service) Run() { + logger.Info("Starting Suno job consumer...") + go func() { + for { + var task types.SunoTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + + } + }() +} + +func (s *Service) Create(task types.SunoTask) { + +} diff --git a/api/store/vo/chat_model.go b/api/store/vo/chat_model.go index 4fb21051..bc98b626 100644 --- a/api/store/vo/chat_model.go +++ b/api/store/vo/chat_model.go @@ -12,6 +12,6 @@ type ChatModel struct { MaxTokens int `json:"max_tokens"` // 最大响应长度 MaxContext int `json:"max_context"` // 最大上下文长度 Temperature float32 `json:"temperature"` // 模型温度 - KeyId int `json:"key_id"` + KeyId int `json:"key_id,omitempty"` KeyName string `json:"key_name"` } diff --git a/web/src/components/FooterBar.vue b/web/src/components/FooterBar.vue index c30955ca..578f27e9 100644 --- a/web/src/components/FooterBar.vue +++ b/web/src/components/FooterBar.vue @@ -1,15 +1,12 @@ +
+ {{ tag }} +
- +
@@ -64,7 +74,7 @@
- +
@@ -180,6 +190,7 @@ const models = ref([ {label: "v3.0", value: "chirp-v3-0"}, {label: "v3.5", value:"chirp-v3-5"} ]) +const tags = ref([]) const data = ref({ model: "chirp-v3-0", tags: "", diff --git a/web/src/views/admin/ApiKey.vue b/web/src/views/admin/ApiKey.vue index 11e81a2f..08bdb08c 100644 --- a/web/src/views/admin/ApiKey.vue +++ b/web/src/views/admin/ApiKey.vue @@ -104,10 +104,10 @@ - - - {{ - item.name + + + {{ + item.label }} @@ -159,13 +159,17 @@ const rules = reactive({ type: [{required: true, message: '请选择用途', trigger: 'change',}], value: [{required: true, message: '请输入 API KEY 值', trigger: 'change',}] }) + const loading = ref(true) const formRef = ref(null) const title = ref("") const platforms = ref([]) const types = ref([ - {name: "聊天", value: "chat"}, - {name: "绘画", value: "img"}, + {label: "对话", value:"chat"}, + {label: "Midjourney", value:"mj"}, + {label: "DALL-E", value:"dall"}, + {label: "Suno文生歌", value:"suno"}, + {label: "Luma视频", value:"luma"}, ]) diff --git a/web/src/views/admin/ChatModel.vue b/web/src/views/admin/ChatModel.vue index 4a718615..1c3dc1f9 100644 --- a/web/src/views/admin/ChatModel.vue +++ b/web/src/views/admin/ChatModel.vue @@ -46,11 +46,6 @@ - - - - -