mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 02:03:42 +08:00
add function to generate lyrics
This commit is contained in:
@@ -83,7 +83,7 @@ func errorHandler(c *gin.Context) {
|
||||
if r := recover(); r != nil {
|
||||
logger.Errorf("Handler Panic: %v", r)
|
||||
debug.PrintStack()
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
|
||||
c.JSON(http.StatusBadRequest, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
@@ -225,6 +225,8 @@ func needLogin(c *gin.Context) bool {
|
||||
c.Request.URL.Path == "/api/payment/doPay" ||
|
||||
c.Request.URL.Path == "/api/payment/payWays" ||
|
||||
c.Request.URL.Path == "/api/suno/client" ||
|
||||
c.Request.URL.Path == "/api/suno/Detail" ||
|
||||
c.Request.URL.Path == "/api/suno/play" ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/config/") ||
|
||||
|
||||
@@ -113,7 +113,7 @@ func (h *SunoHandler) Create(c *gin.Context) {
|
||||
RefTaskId: data.RefTaskId,
|
||||
RefSongId: data.RefSongId,
|
||||
ExtendSecs: data.ExtendSecs,
|
||||
Prompt: data.Prompt,
|
||||
Prompt: job.Prompt,
|
||||
Tags: data.Tags,
|
||||
Model: data.Model,
|
||||
Instrumental: data.Instrumental,
|
||||
@@ -265,13 +265,13 @@ func (h *SunoHandler) Update(c *gin.Context) {
|
||||
|
||||
// Detail 歌曲详情
|
||||
func (h *SunoHandler) Detail(c *gin.Context) {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
if id <= 0 {
|
||||
songId := c.Query("song_id")
|
||||
if songId == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
var item model.SunoJob
|
||||
if err := h.DB.Where("id", id).First(&item).Error; err != nil {
|
||||
if err := h.DB.Where("song_id", songId).First(&item).Error; err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -296,3 +296,50 @@ func (h *SunoHandler) Detail(c *gin.Context) {
|
||||
|
||||
resp.SUCCESS(c, itemVo)
|
||||
}
|
||||
|
||||
// Play 增加歌曲播放次数
|
||||
func (h *SunoHandler) Play(c *gin.Context) {
|
||||
songId := c.Query("song_id")
|
||||
if songId == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
h.DB.Model(&model.SunoJob{}).Where("song_id", songId).UpdateColumn("play_times", gorm.Expr("play_times + ?", 1))
|
||||
}
|
||||
|
||||
const genLyricTemplate = `
|
||||
你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。
|
||||
请以【%s】为主题创作一首歌曲,歌曲时间不要太短,3分钟左右,不要输出任何解释性的内容。
|
||||
输出格式如下:
|
||||
歌曲名称
|
||||
第一节:
|
||||
{{歌词内容}}
|
||||
副歌:
|
||||
{{歌词内容}}
|
||||
|
||||
第二节:
|
||||
{{歌词内容}}
|
||||
副歌:
|
||||
{{歌词内容}}
|
||||
|
||||
尾声:
|
||||
{{歌词内容}}
|
||||
`
|
||||
|
||||
// Lyric 生成歌词
|
||||
func (h *SunoHandler) Lyric(c *gin.Context) {
|
||||
var data struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp.SUCCESS(c, content)
|
||||
}
|
||||
|
||||
@@ -494,6 +494,8 @@ func main() {
|
||||
group.GET("publish", h.Publish)
|
||||
group.POST("update", h.Update)
|
||||
group.GET("detail", h.Detail)
|
||||
group.GET("play", h.Play)
|
||||
group.POST("lyric", h.Lyric)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
||||
go func() {
|
||||
|
||||
@@ -110,12 +110,11 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
prompt := task.Prompt
|
||||
// translate prompt
|
||||
if utils.HasChinese(prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt))
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, prompt), "gpt-4o-mini")
|
||||
if err == nil {
|
||||
prompt = content
|
||||
logger.Debugf("重写后提示词:%s", prompt)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
var user model.User
|
||||
|
||||
@@ -72,7 +72,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt), "gpt-4o-mini")
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
@@ -81,7 +81,7 @@ func (s *Service) Run() {
|
||||
}
|
||||
// translate negative prompt
|
||||
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt))
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt), "gpt-4o-mini")
|
||||
if err == nil {
|
||||
task.NegPrompt = content
|
||||
} else {
|
||||
|
||||
@@ -63,7 +63,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Params.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt))
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt), "gpt-4o-mini")
|
||||
if err == nil {
|
||||
task.Params.Prompt = content
|
||||
} else {
|
||||
@@ -73,7 +73,7 @@ func (s *Service) Run() {
|
||||
|
||||
// translate negative prompt
|
||||
if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt))
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt), "gpt-4o-mini")
|
||||
if err == nil {
|
||||
task.Params.NegPrompt = content
|
||||
} else {
|
||||
|
||||
@@ -8,12 +8,14 @@ package utils
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/store/model"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -43,16 +45,7 @@ type apiRes struct {
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type apiErrRes struct {
|
||||
Error struct {
|
||||
Code interface{} `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Param interface{} `json:"param"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func OpenAIRequest(db *gorm.DB, prompt string) (string, error) {
|
||||
func OpenAIRequest(db *gorm.DB, prompt string, modelName string) (string, error) {
|
||||
var apiKey model.ApiKey
|
||||
res := db.Where("type", "chat").Where("enabled", true).First(&apiKey)
|
||||
if res.Error != nil {
|
||||
@@ -66,24 +59,27 @@ func OpenAIRequest(db *gorm.DB, prompt string) (string, error) {
|
||||
}
|
||||
|
||||
var response apiRes
|
||||
var errRes apiErrRes
|
||||
client := req.C()
|
||||
if len(apiKey.ProxyURL) > 5 {
|
||||
client.SetProxyURL(apiKey.ApiURL)
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
|
||||
r, err := client.R().SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
SetBody(types.ApiRequest{
|
||||
Model: "gpt-3.5-turbo",
|
||||
Model: modelName,
|
||||
Temperature: 0.9,
|
||||
MaxTokens: 1024,
|
||||
Stream: false,
|
||||
Messages: messages,
|
||||
}).
|
||||
SetErrorResult(&errRes).
|
||||
SetSuccessResult(&response).Post(apiKey.ApiURL)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message)
|
||||
}).Post(apiURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("请求 OpenAI API失败:%v", err)
|
||||
}
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
// 更新 API KEY 的最后使用时间
|
||||
|
||||
Reference in New Issue
Block a user