diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0af817fc..abcec219 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,7 +7,8 @@
* 功能优化:在应用列表页面,无需先添加模型到用户工作区,可以直接使用
* 功能新增:MJ 绘图失败的任务不会自动删除,而是会在列表页显示失败详细错误信息
* 功能新增:允许在管理后台设置首页显示的导航菜单
-* 功能新增:增加 Suno 文生音乐页面功能
+* 功能新增:增加 Suno 文生歌曲功能
+* 功能优化:移除多平台模型支持,统一使用 one-api 接口形式,其他平台的模型需要通过 one-api 接口添加
## v4.1.0
* bug修复:修复移动端修改聊天标题不生效的问题
diff --git a/api/core/types/task.go b/api/core/types/task.go
index 129a83b4..552b7666 100644
--- a/api/core/types/task.go
+++ b/api/core/types/task.go
@@ -78,3 +78,16 @@ type DallTask struct {
Power int `json:"power"`
}
+
+type SunoTask struct {
+ Id int `json:"id"`
+ UserId string `json:"user_id"`
+ Type int `json:"type"`
+ TaskId string `json:"task_id"`
+ Title string `json:"title"`
+ ReferenceId string `json:"reference_id"`
+ Prompt string `json:"prompt"`
+ Tags string `json:"tags"`
+ Instrumental bool `json:"instrumental"` // 是否纯音乐
+ ExtendSecs int `json:"extend_secs"` // 延长秒杀
+}
diff --git a/api/handler/admin/chat_model_handler.go b/api/handler/admin/chat_model_handler.go
index 3a1e6328..f4a32444 100644
--- a/api/handler/admin/chat_model_handler.go
+++ b/api/handler/admin/chat_model_handler.go
@@ -49,28 +49,33 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
return
}
- item := model.ChatModel{
- Platform: data.Platform,
- Name: data.Name,
- Value: data.Value,
- Enabled: data.Enabled,
- Open: data.Open,
- MaxTokens: data.MaxTokens,
- MaxContext: data.MaxContext,
- Temperature: data.Temperature,
- KeyId: data.KeyId,
- Power: data.Power}
+ item := model.ChatModel{}
+ // 更新
+ if data.Id > 0 {
+ h.DB.Where("id", data.Id).First(&item)
+ }
+
+ item.Name = data.Name
+ item.Value = data.Value
+ item.Enabled = data.Enabled
+ item.SortNum = data.SortNum
+ item.Open = data.Open
+ item.Platform = data.Platform
+ item.Power = data.Power
+ item.MaxTokens = data.MaxTokens
+ item.MaxContext = data.MaxContext
+ item.Temperature = data.Temperature
+ item.KeyId = data.KeyId
+
var res *gorm.DB
if data.Id > 0 {
- item.Id = data.Id
- item.SortNum = data.SortNum
- res = h.DB.Select("*").Omit("created_at").Updates(&item)
+ res = h.DB.Updates(&item)
} else {
res = h.DB.Create(&item)
}
if res.Error != nil {
logger.Error("error with update database:", res.Error)
- resp.ERROR(c, "更新数据库失败!")
+ resp.ERROR(c, res.Error.Error())
return
}
diff --git a/api/handler/admin/user_handler.go b/api/handler/admin/user_handler.go
index 34747d77..c2386efc 100644
--- a/api/handler/admin/user_handler.go
+++ b/api/handler/admin/user_handler.go
@@ -112,7 +112,7 @@ func (h *UserHandler) Save(c *gin.Context) {
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
if res.Error != nil {
logger.Error("error with update database:", res.Error)
- resp.ERROR(c, "更新数据库失败!")
+ resp.ERROR(c, res.Error.Error())
return
}
// 记录算力日志
@@ -136,6 +136,13 @@ func (h *UserHandler) Save(c *gin.Context) {
})
}
} else {
+ // 检查用户是否已经存在
+ h.DB.Where("username", data.Username).First(&user)
+ if user.Id > 0 {
+ resp.ERROR(c, "用户名已存在")
+ return
+ }
+
salt := utils.RandString(8)
u := model.User{
Username: data.Username,
diff --git a/api/handler/chat_model_handler.go b/api/handler/chat_model_handler.go
index 555de7c8..1b74f348 100644
--- a/api/handler/chat_model_handler.go
+++ b/api/handler/chat_model_handler.go
@@ -31,9 +31,14 @@ func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel
var chatModels = make([]vo.ChatModel, 0)
var res *gorm.DB
+ session := h.DB.Session(&gorm.Session{}).Where("enabled", true)
+ t := c.Query("type")
+ if t != "" {
+ session = session.Where("type", t)
+ }
// 如果用户没有登录,则加载所有开放模型
if !h.IsLogin(c) {
- res = h.DB.Where("enabled", true).Where("open", true).Order("sort_num ASC").Find(&items)
+ res = session.Where("open", true).Order("sort_num ASC").Find(&items)
} else {
user, _ := h.GetLoginUser(c)
var models []int
diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go
new file mode 100644
index 00000000..5419ddea
--- /dev/null
+++ b/api/handler/suno_handler.go
@@ -0,0 +1,63 @@
+package handler
+
+// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+// * Copyright 2023 The Geek-AI Authors. All rights reserved.
+// * Use of this source code is governed by a Apache-2.0 license
+// * that can be found in the LICENSE file.
+// * @Author yangjian102621@163.com
+// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+import (
+ "geekai/core"
+ "github.com/gin-gonic/gin"
+ "gorm.io/gorm"
+)
+
+type SunoHandler struct {
+ BaseHandler
+}
+
+func NewSunoHandler(app *core.AppServer, db *gorm.DB) *SunoHandler {
+ return &SunoHandler{
+ BaseHandler: BaseHandler{
+ App: app,
+ DB: db,
+ },
+ }
+}
+
+// Client WebSocket 客户端,用于通知任务状态变更
+func (h *SunoHandler) Client(c *gin.Context) {
+ //ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
+ //if err != nil {
+ // logger.Error(err)
+ // c.Abort()
+ // return
+ //}
+ //
+ //userId := h.GetInt(c, "user_id", 0)
+ //if userId == 0 {
+ // logger.Info("Invalid user ID")
+ // c.Abort()
+ // return
+ //}
+ //
+ ////client := types.NewWsClient(ws)
+ //logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
+}
+
+func (h *SunoHandler) Create(c *gin.Context) {
+
+}
+
+func (h *SunoHandler) List(c *gin.Context) {
+
+}
+
+func (h *SunoHandler) Remove(c *gin.Context) {
+
+}
+
+func (h *SunoHandler) Publish(c *gin.Context) {
+
+}
diff --git a/api/service/mj/service.go b/api/service/mj/service.go
index cba2b396..2218c851 100644
--- a/api/service/mj/service.go
+++ b/api/service/mj/service.go
@@ -29,6 +29,7 @@ type Service struct {
notifyQueue *store.RedisQueue
db *gorm.DB
running bool
+ retryCount map[uint]int
}
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
@@ -39,6 +40,7 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red
notifyQueue: notifyQueue,
Client: cli,
running: true,
+ retryCount: make(map[uint]int),
}
}
@@ -57,8 +59,13 @@ func (s *Service) Run() {
// 如果配置了多个中转平台的 API KEY
// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
if task.ChannelId != "" && task.ChannelId != s.Name {
+ if s.retryCount[task.Id] > 5 {
+ s.db.Model(model.MidJourneyJob{Id: task.Id}).Delete(&model.MidJourneyJob{})
+ continue
+ }
logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
s.taskQueue.RPush(task)
+ s.retryCount[task.Id]++
time.Sleep(time.Second)
continue
}
diff --git a/api/service/suno/service.go b/api/service/suno/service.go
new file mode 100644
index 00000000..8f28302a
--- /dev/null
+++ b/api/service/suno/service.go
@@ -0,0 +1,66 @@
+package dalle
+
+// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+// * Copyright 2023 The Geek-AI Authors. All rights reserved.
+// * Use of this source code is governed by a Apache-2.0 license
+// * that can be found in the LICENSE file.
+// * @Author yangjian102621@163.com
+// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+import (
+ "geekai/core/types"
+ logger2 "geekai/logger"
+ "geekai/service/oss"
+ "geekai/store"
+ "github.com/go-redis/redis/v8"
+ "time"
+
+ "github.com/imroc/req/v3"
+ "gorm.io/gorm"
+)
+
+var logger = logger2.GetLogger()
+
+type Service struct {
+ httpClient *req.Client
+ db *gorm.DB
+ uploadManager *oss.UploaderManager
+ taskQueue *store.RedisQueue
+ notifyQueue *store.RedisQueue
+ Clients *types.LMap[uint, *types.WsClient] // UserId => Client
+}
+
+func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
+ return &Service{
+ httpClient: req.C().SetTimeout(time.Minute * 3),
+ db: db,
+ taskQueue: store.NewRedisQueue("Suno_Task_Queue", redisCli),
+ notifyQueue: store.NewRedisQueue("Suno_Notify_Queue", redisCli),
+ Clients: types.NewLMap[uint, *types.WsClient](),
+ uploadManager: manager,
+ }
+}
+
+func (s *Service) PushTask(task types.SunoTask) {
+ logger.Infof("add a new Suno task to the task list: %+v", task)
+ s.taskQueue.RPush(task)
+}
+
+func (s *Service) Run() {
+ logger.Info("Starting Suno job consumer...")
+ go func() {
+ for {
+ var task types.SunoTask
+ err := s.taskQueue.LPop(&task)
+ if err != nil {
+ logger.Errorf("taking task with error: %v", err)
+ continue
+ }
+
+ }
+ }()
+}
+
+func (s *Service) Create(task types.SunoTask) {
+
+}
diff --git a/api/store/vo/chat_model.go b/api/store/vo/chat_model.go
index 4fb21051..bc98b626 100644
--- a/api/store/vo/chat_model.go
+++ b/api/store/vo/chat_model.go
@@ -12,6 +12,6 @@ type ChatModel struct {
MaxTokens int `json:"max_tokens"` // 最大响应长度
MaxContext int `json:"max_context"` // 最大上下文长度
Temperature float32 `json:"temperature"` // 模型温度
- KeyId int `json:"key_id"`
+ KeyId int `json:"key_id,omitempty"`
KeyName string `json:"key_name"`
}
diff --git a/web/src/components/FooterBar.vue b/web/src/components/FooterBar.vue
index c30955ca..578f27e9 100644
--- a/web/src/components/FooterBar.vue
+++ b/web/src/components/FooterBar.vue
@@ -1,15 +1,12 @@
-
+
@@ -64,7 +74,7 @@
-
+
@@ -180,6 +190,7 @@ const models = ref([
{label: "v3.0", value: "chirp-v3-0"},
{label: "v3.5", value:"chirp-v3-5"}
])
+const tags = ref([])
const data = ref({
model: "chirp-v3-0",
tags: "",
diff --git a/web/src/views/admin/ApiKey.vue b/web/src/views/admin/ApiKey.vue
index 11e81a2f..08bdb08c 100644
--- a/web/src/views/admin/ApiKey.vue
+++ b/web/src/views/admin/ApiKey.vue
@@ -104,10 +104,10 @@
-
-
- {{
- item.name
+
+
+ {{
+ item.label
}}
@@ -159,13 +159,17 @@ const rules = reactive({
type: [{required: true, message: '请选择用途', trigger: 'change',}],
value: [{required: true, message: '请输入 API KEY 值', trigger: 'change',}]
})
+
const loading = ref(true)
const formRef = ref(null)
const title = ref("")
const platforms = ref([])
const types = ref([
- {name: "聊天", value: "chat"},
- {name: "绘画", value: "img"},
+ {label: "对话", value:"chat"},
+ {label: "Midjourney", value:"mj"},
+ {label: "DALL-E", value:"dall"},
+ {label: "Suno文生歌", value:"suno"},
+ {label: "Luma视频", value:"luma"},
])
diff --git a/web/src/views/admin/ChatModel.vue b/web/src/views/admin/ChatModel.vue
index 4a718615..1c3dc1f9 100644
--- a/web/src/views/admin/ChatModel.vue
+++ b/web/src/views/admin/ChatModel.vue
@@ -46,11 +46,6 @@
-
-
-
-
-
@@ -89,7 +84,7 @@
-
+