merge v4.1.1 and fixed conflicts

This commit is contained in:
RockYang
2024-11-13 18:40:04 +08:00
110 changed files with 4820 additions and 1977 deletions

View File

@@ -55,12 +55,6 @@ func (h *ManagerHandler) Login(c *gin.Context) {
return
}
//// add captcha
//if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) {
// resp.ERROR(c, "验证码错误!")
// return
//}
var manager model.AdminUser
res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
if res.Error != nil {

View File

@@ -31,7 +31,6 @@ func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
func (h *ApiKeyHandler) Save(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Platform string `json:"platform"`
Name string `json:"name"`
Type string `json:"type"`
Value string `json:"value"`
@@ -48,7 +47,6 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
if data.Id > 0 {
h.DB.Find(&apiKey, data.Id)
}
apiKey.Platform = data.Platform
apiKey.Value = data.Value
apiKey.Type = data.Type
apiKey.ApiURL = data.ApiURL

View File

@@ -49,28 +49,32 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
return
}
item := model.ChatModel{
Platform: data.Platform,
Name: data.Name,
Value: data.Value,
Enabled: data.Enabled,
Open: data.Open,
MaxTokens: data.MaxTokens,
MaxContext: data.MaxContext,
Temperature: data.Temperature,
KeyId: data.KeyId,
Power: data.Power}
item := model.ChatModel{}
// 更新
if data.Id > 0 {
h.DB.Where("id", data.Id).First(&item)
}
item.Name = data.Name
item.Value = data.Value
item.Enabled = data.Enabled
item.SortNum = data.SortNum
item.Open = data.Open
item.Power = data.Power
item.MaxTokens = data.MaxTokens
item.MaxContext = data.MaxContext
item.Temperature = data.Temperature
item.KeyId = data.KeyId
var res *gorm.DB
if data.Id > 0 {
item.Id = data.Id
item.SortNum = data.SortNum
res = h.DB.Select("*").Omit("created_at").Updates(&item)
res = h.DB.Save(&item)
} else {
res = h.DB.Create(&item)
}
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
resp.ERROR(c, res.Error.Error())
return
}
@@ -89,12 +93,12 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
func (h *ChatModelHandler) List(c *gin.Context) {
session := h.DB.Session(&gorm.Session{})
enable := h.GetBool(c, "enable")
platform := h.GetTrim(c, "platform")
name := h.GetTrim(c, "name")
if enable {
session = session.Where("enabled", enable)
}
if platform != "" {
session = session.Where("platform", platform)
if name != "" {
session = session.Where("name LIKE ?", name+"%")
}
var items []model.ChatModel
var cms = make([]vo.ChatModel, 0)

View File

@@ -50,6 +50,7 @@ func (h *ConfigHandler) Update(c *gin.Context) {
Content string `json:"content,omitempty"`
Updated bool `json:"updated,omitempty"`
} `json:"config"`
ConfigBak types.SystemConfig `json:"config_bak,omitempty"`
}
if err := c.ShouldBindJSON(&data); err != nil {
@@ -57,6 +58,12 @@ func (h *ConfigHandler) Update(c *gin.Context) {
return
}
// ONLY authorized user can change the copyright
if (data.Key == "system" && data.Config.Copyright != data.ConfigBak.Copyright) && !h.licenseService.GetLicense().Configs.DeCopy {
resp.ERROR(c, "您无权修改版权信息,请先联系作者获取授权")
return
}
value := utils.JsonEncode(&data.Config)
config := model.Config{Key: data.Key, Config: value}
res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
@@ -143,10 +150,9 @@ func (h *ConfigHandler) GetLicense(c *gin.Context) {
// GetAppConfig 获取内置配置
func (h *ConfigHandler) GetAppConfig(c *gin.Context) {
resp.SUCCESS(c, gin.H{
"mj_plus": h.App.Config.MjPlusConfigs,
"mj_proxy": h.App.Config.MjProxyConfigs,
"sd": h.App.Config.SdConfigs,
"platforms": Platforms,
"mj_plus": h.App.Config.MjPlusConfigs,
"mj_proxy": h.App.Config.MjProxyConfigs,
"sd": h.App.Config.SdConfigs,
})
}

View File

@@ -1,12 +0,0 @@
package admin
import "geekai/core/types"
var Platforms = []types.Platform{
types.OpenAI,
types.QWen,
types.XunFei,
types.ChatGLM,
types.Baidu,
types.Azure,
}

View File

@@ -112,7 +112,7 @@ func (h *UserHandler) Save(c *gin.Context) {
res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
if res.Error != nil {
logger.Error("error with update database", res.Error)
resp.ERROR(c, "更新数据库失败!")
resp.ERROR(c, res.Error.Error())
return
}
// 记录算力日志
@@ -136,6 +136,13 @@ func (h *UserHandler) Save(c *gin.Context) {
})
}
} else {
// 检查用户是否已经存在
h.DB.Where("username", data.Username).First(&user)
if user.Id > 0 {
resp.ERROR(c, "用户名已存在")
return
}
salt := utils.RandString(8)
u := model.User{
Username: data.Username,

View File

@@ -31,9 +31,14 @@ func (h *ChatModelHandler) List(c *gin.Context) {
var items []model.ChatModel
var chatModels = make([]vo.ChatModel, 0)
var res *gorm.DB
session := h.DB.Session(&gorm.Session{}).Where("enabled", true)
t := c.Query("type")
if t != "" {
session = session.Where("type", t)
}
// 如果用户没有登录,则加载所有开放模型
if !h.IsLogin(c) {
res = h.DB.Where("enabled", true).Where("open", true).Order("sort_num ASC").Find(&items)
res = session.Where("open", true).Order("sort_num ASC").Find(&items)
} else {
user, _ := h.GetLoginUser(c)
var models []int

View File

@@ -29,45 +29,32 @@ func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
// List 获取用户聊天应用列表
func (h *ChatRoleHandler) List(c *gin.Context) {
all := h.GetBool(c, "all")
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var roles []model.ChatRole
var roleVos = make([]vo.ChatRole, 0)
query := h.DB.Where("enable", true)
if userId > 0 {
var user model.User
h.DB.First(&user, userId)
var roleKeys []string
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil {
resp.ERROR(c, "角色解析失败!")
return
}
query = query.Where("marker IN ?", roleKeys)
}
if id > 0 {
query = query.Or("id", id)
}
res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
if res.Error != nil {
resp.SUCCESS(c, roleVos)
return
}
// 获取所有角色
if userId == 0 || all {
// 转成 vo
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
var v vo.ChatRole
err := utils.CopyObject(r, &v)
if err == nil {
v.Id = r.Id
roleVos = append(roleVos, v)
}
}
resp.SUCCESS(c, roleVos)
return
}
var user model.User
h.DB.First(&user, userId)
var roleKeys []string
err := utils.JsonDecode(user.ChatRoles, &roleKeys)
if err != nil {
resp.ERROR(c, "角色解析失败!")
resp.ERROR(c, res.Error.Error())
return
}
var roleVos = make([]vo.ChatRole, 0)
for _, r := range roles {
if !utils.Contains(roleKeys, r.Key) {
continue
}
var v vo.ChatRole
err := utils.CopyObject(r, &v)
if err == nil {

View File

@@ -1,111 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"io"
"strings"
"time"
)
// 微软 Azure 模型消息发送实现
func (h *ChatHandler) sendAzureMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
return errors.New(line)
}
if len(responseBody.Choices) == 0 {
continue
}
// 初始化 role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
continue
} else if responseBody.Choices[0].FinishReason != "" {
break // 输出完成或者输出中断了
} else {
content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content))
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil
}

View File

@@ -1,185 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"io"
"net/http"
"strings"
"time"
)
type baiduResp struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
IsTruncated bool `json:"is_truncated"`
Result string `json:"result"`
NeedClearHistory bool `json:"need_clear_history"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// 百度文心一言消息发送实现
func (h *ChatHandler) sendBaiduMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
logger.Error(err)
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var content string
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") {
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
var resp baiduResp
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if len(contents) == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(resp.Result),
})
contents = append(contents, resp.Result)
if resp.IsTruncated {
utils.ReplyMessage(ws, "AI 输出异常中断")
break
}
if resp.IsEnd {
break
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil
}
func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) {
ctx := context.Background()
tokenString, err := h.redis.Get(ctx, apiKey).Result()
if err == nil {
return tokenString, nil
}
expr := time.Hour * 24 * 20 // access_token 有效期
key := strings.Split(apiKey, "|")
if len(key) != 2 {
return "", fmt.Errorf("invalid api key: %s", apiKey)
}
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
client := &http.Client{}
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return "", err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("error with send request: %w", err)
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return "", fmt.Errorf("error with read response: %w", err)
}
var r map[string]interface{}
err = json.Unmarshal(body, &r)
if err != nil {
return "", fmt.Errorf("error with parse response: %w", err)
}
if r["error"] != nil {
return "", fmt.Errorf("error with api response: %s", r["error_description"])
}
tokenString = fmt.Sprintf("%s", r["access_token"])
h.redis.Set(ctx, apiKey, tokenString, expr)
return tokenString, nil
}

View File

@@ -116,8 +116,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
MaxTokens: chatModel.MaxTokens,
MaxContext: chatModel.MaxContext,
Temperature: chatModel.Temperature,
KeyId: chatModel.KeyId,
Platform: chatModel.Platform}
KeyId: chatModel.KeyId}
logger.Infof("New websocket connected, IP: %s", c.ClientIP())
go func() {
@@ -208,21 +207,12 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
Model: session.Model.Value,
Stream: true,
}
switch session.Model.Platform {
case types.Azure.Value, types.ChatGLM.Value, types.Baidu.Value, types.XunFei.Value:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
break
case types.OpenAI.Value:
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res := h.DB.Where("enabled", true).Find(&items)
if res.Error != nil {
break
}
req.Temperature = session.Model.Temperature
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res = h.DB.Where("enabled", true).Find(&items)
if res.Error == nil {
var tools = make([]types.Tool, 0)
for _, v := range items {
var parameters map[string]interface{}
@@ -248,15 +238,6 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
req.Tools = tools
req.ToolChoice = "auto"
}
case types.QWen.Value:
req.Parameters = map[string]interface{}{
"max_tokens": session.Model.MaxTokens,
"temperature": session.Model.Temperature,
}
break
default:
return fmt.Errorf("不支持的平台:%s", session.Model.Platform)
}
// 加载聊天上下文
@@ -344,65 +325,37 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
logger.Debug("最终Prompt", fullPrompt)
if session.Model.Platform == types.QWen.Value {
req.Input = make(map[string]interface{})
reqMgs = append(reqMgs, types.Message{
Role: "user",
Content: fullPrompt,
})
req.Input["messages"] = reqMgs
} else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model
imgURLs := utils.ExtractImgURLs(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
}
// extract images from prompt
imgURLs := utils.ExtractImgURLs(prompt)
logger.Debugf("detected IMG: %+v", imgURLs)
var content interface{}
if len(imgURLs) > 0 {
data := make([]interface{}, 0)
for _, v := range imgURLs {
text = strings.Replace(text, v, "", 1)
data = append(data, gin.H{
"type": "text",
"text": strings.TrimSpace(text),
"type": "image_url",
"image_url": gin.H{
"url": v,
},
})
content = data
} else {
content = fullPrompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
data = append(data, gin.H{
"type": "text",
"text": strings.TrimSpace(text),
})
content = data
} else {
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": fullPrompt,
})
content = fullPrompt
}
req.Messages = append(reqMgs, map[string]interface{}{
"role": "user",
"content": content,
})
logger.Debugf("%+v", req.Messages)
switch session.Model.Platform {
case types.Azure.Value:
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.OpenAI.Value:
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.ChatGLM.Value:
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.Baidu.Value:
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.XunFei.Value:
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.QWen.Value:
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
return nil
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
// Tokens 统计 token 数量
@@ -478,55 +431,20 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
}
// use the last unused key
if apiKey.Id == 0 {
h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
h.DB.Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
}
if apiKey.Id == 0 {
return nil, errors.New("no available key, please import key")
}
// ONLY allow apiURL in blank list
if session.Model.Platform == types.OpenAI.Value {
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
err := h.licenseService.IsValidApiURL(apiKey.ApiURL)
if err != nil {
return nil, err
}
var apiURL string
switch session.Model.Platform {
case types.Azure.Value:
md := strings.Replace(req.Model, ".", "", 1)
apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
break
case types.ChatGLM.Value:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
req.Messages = nil
break
case types.Baidu.Value:
apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
break
case types.QWen.Value:
apiURL = apiKey.ApiURL
req.Messages = nil
break
default:
apiURL = apiKey.ApiURL
}
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if session.Model.Platform == types.Baidu.Value {
token, err := h.getBaiduToken(apiKey.Value)
if err != nil {
return nil, err
}
logger.Info("百度文心 Access_Token", token)
apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
}
logger.Debugf(utils.JsonEncode(req))
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 创建 HttpClient 请求对象
var client *http.Client
requestBody, err := json.Marshal(req)
@@ -550,28 +468,10 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, sessi
} else {
client = http.DefaultClient
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
switch session.Model.Platform {
case types.Azure.Value:
request.Header.Set("api-key", apiKey.Value)
break
case types.ChatGLM.Value:
token, err := h.getChatGLMToken(apiKey.Value)
if err != nil {
return nil, err
}
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
break
case types.Baidu.Value:
request.RequestURI = ""
case types.OpenAI.Value:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
break
case types.QWen.Value:
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
request.Header.Set("X-DashScope-SSE", "enable")
break
}
logger.Debugf("Sending %s request, Channel:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiKey.ApiURL, apiURL, apiKey.ProxyURL, req.Model)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
// 更新API KEY 最后使用时间
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
return client.Do(request)
}
@@ -708,7 +608,7 @@ func (h *ChatHandler) extractImgUrl(text string) string {
continue
}
newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false)
newImgURL, err := h.uploadManager.GetUploadHandler().PutUrlFile(imageURL, false)
if err != nil {
logger.Error("error with download image: ", err)
continue

View File

@@ -1,142 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/golang-jwt/jwt/v5"
"io"
"strings"
"time"
)
// 清华大学 ChatGML 消息发送实现
func (h *ChatHandler) sendChatGLMMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var event, content string
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") {
continue
}
if strings.HasPrefix(line, "event:") {
event = line[6:]
continue
}
if strings.HasPrefix(line, "data:") {
content = line[5:]
}
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
switch event {
case "add":
if len(contents) == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(content),
})
contents = append(contents, content)
case "finish":
break
case "error":
utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content))
break
case "interrupted":
utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**")
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil
}
func (h *ChatHandler) getChatGLMToken(apiKey string) (string, error) {
ctx := context.Background()
tokenString, err := h.redis.Get(ctx, apiKey).Result()
if err == nil {
return tokenString, nil
}
expr := time.Hour * 2
key := strings.Split(apiKey, ".")
if len(key) != 2 {
return "", fmt.Errorf("invalid api key: %s", apiKey)
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"api_key": key[0],
"timestamp": time.Now().Unix(),
"exp": time.Now().Add(expr).Add(time.Second * 10).Unix(),
})
token.Header["alg"] = "HS256"
token.Header["sign_type"] = "SIGN"
delete(token.Header, "typ")
// Sign and get the complete encoded token as a string using the secret
tokenString, err = token.SignedString([]byte(key[1]))
h.redis.Set(ctx, apiKey, tokenString, expr)
return tokenString, err
}

View File

@@ -65,7 +65,6 @@ func (h *ChatHandler) sendOpenAiMessage(
if !strings.Contains(line, "data:") || len(line) < 30 {
continue
}
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // 数据解析出错
@@ -74,7 +73,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
continue
}
if responseBody.Choices[0].Delta.Content == nil {
if responseBody.Choices[0].Delta.Content == nil && responseBody.Choices[0].Delta.ToolCalls == nil {
continue
}

View File

@@ -1,150 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"bufio"
"context"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/syndtr/goleveldb/leveldb/errors"
"io"
"strings"
"time"
)
type qWenResp struct {
Output struct {
FinishReason string `json:"finish_reason"`
Text string `json:"text"`
} `json:"output,omitempty"`
Usage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage,omitempty"`
RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// 通义千问消息发送实现
func (h *ChatHandler) sendQWenMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = model.ApiKey{}
response, err := h.doRequest(ctx, req, session, &apiKey)
logger.Info("HTTP请求完成耗时", time.Now().Sub(start))
if err != nil {
if strings.Contains(err.Error(), "context canceled") {
return fmt.Errorf("用户取消了请求:%s", prompt)
} else if strings.Contains(err.Error(), "no available key") {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
return err
} else {
defer response.Body.Close()
}
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
scanner := bufio.NewScanner(response.Body)
var content, lastText, newText string
var outPutStart = false
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") ||
strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
continue
}
if !strings.HasPrefix(line, "data:") {
continue
}
content = line[5:]
var resp qWenResp
if len(contents) == 0 { // 发送消息头
if !outPutStart {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
outPutStart = true
continue
} else {
// 处理代码换行
content = "\n"
}
} else {
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", content)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if resp.Message != "" {
utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
break
}
}
//通过比较 lastText上一次的文本和 currentText当前的文本
//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。
//每次循环结束后lastText 会更新为当前的完整文本,以便于下一次循环进行比较。
currentText := resp.Output.Text
if currentText != lastText {
// 提取新增文本
newText = strings.Replace(currentText, lastText, "", 1)
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(newText),
})
lastText = currentText // 更新 lastText
}
contents = append(contents, newText)
if resp.Output.FinishReason == "stop" {
break
}
} //end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
} else {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("请求大模型 API 失败:%s", body)
}
return nil
}

View File

@@ -1,255 +0,0 @@
package chatimpl
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
)
type xunFeiResp struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
} `json:"text"`
} `json:"choices"`
Usage struct {
Text struct {
QuestionTokens int `json:"question_tokens"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
var Model2URL = map[string]string{
"general": "v1.1",
"generalv2": "v2.1",
"generalv3": "v3.1",
"generalv3.5": "v3.5",
}
// 科大讯飞消息发送实现
func (h *ChatHandler) sendXunFeiMessage(
chatCtx []types.Message,
req types.ApiRequest,
userVo vo.User,
ctx context.Context,
session *types.ChatSession,
role model.ChatRole,
prompt string,
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey
var res *gorm.DB
// use the bind key
if session.Model.KeyId > 0 {
res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
}
if res.Error != nil {
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
key := strings.Split(apiKey.Value, "|")
if len(key) != 3 {
utils.ReplyMessage(ws, "非法的 API KEY")
return nil
}
apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil)
if err != nil {
logger.Error(readResp(resp) + err.Error())
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
return nil
} else if resp.StatusCode != 101 {
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
return nil
}
data := buildRequest(key[0], req)
fmt.Printf("%+v", data)
fmt.Println(apiURL)
err = conn.WriteJSON(data)
if err != nil {
utils.ReplyMessage(ws, "发送消息失败:"+err.Error())
return nil
}
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var content string
for {
_, msg, err := conn.ReadMessage()
if err != nil {
logger.Error("error with read message:", err)
utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err))
break
}
// 解析数据
var result xunFeiResp
err = json.Unmarshal(msg, &result)
if err != nil {
logger.Error("error with parsing JSON:", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
return nil
}
if result.Header.Code != 0 {
utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message))
return nil
}
content = result.Payload.Choices.Text[0].Content
// 处理代码换行
if len(content) == 0 {
content = "\n"
}
contents = append(contents, content)
// 第一个结果
if result.Payload.Choices.Status == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(content),
})
if result.Payload.Choices.Status == 2 { // 最终结果
_ = conn.Close() // 关闭连接
break
}
select {
case <-ctx.Done():
utils.ReplyMessage(ws, "**用户取消了生成指令!**")
return nil
default:
continue
}
}
// 消息发送成功
if len(contents) > 0 {
h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt)
}
return nil
}
// 构建 websocket 请求实体
func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
return map[string]interface{}{
"header": map[string]interface{}{
"app_id": appid,
},
"parameter": map[string]interface{}{
"chat": map[string]interface{}{
"domain": req.Model,
"temperature": req.Temperature,
"top_k": int64(6),
"max_tokens": int64(req.MaxTokens),
"auditing": "default",
},
},
"payload": map[string]interface{}{
"message": map[string]interface{}{
"text": req.Messages,
},
},
}
}
// 创建鉴权 URL
func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) {
ul, err := url.Parse(hostURL)
if err != nil {
return "", err
}
date := time.Now().UTC().Format(time.RFC1123)
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
//拼接签名字符串
signStr := strings.Join(signString, "\n")
sha := hmacWithSha256(signStr, apiSecret)
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
"hmac-sha256", "host date request-line", sha)
//将请求参数使用base64编码
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
v := url.Values{}
v.Add("host", ul.Host)
v.Add("date", date)
v.Add("authorization", authorization)
//将编码后的字符串url encode后添加到url后面
return hostURL + "?" + v.Encode(), nil
}
// 使用 sha256 签名
func hmacWithSha256(data, key string) string {
mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(data))
encodeData := mac.Sum(nil)
return base64.StdEncoding.EncodeToString(encodeData)
}
// 读取响应
func readResp(resp *http.Response) string {
if resp == nil {
return ""
}
b, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
}

View File

@@ -212,21 +212,21 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
}
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
session := h.DB.Session(&gorm.Session{})
// if the chat model bind a KEY, use it directly
var res *gorm.DB
if chatModel.KeyId > 0 {
res = h.DB.Where("id", chatModel.KeyId).Find(apiKey)
}
// use the last unused key
if apiKey.Id == 0 {
res = h.DB.Where("platform", types.OpenAI.Value).
Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC").First(apiKey)
session = session.Where("id", chatModel.KeyId)
} else { // use the last unused key
session = session.Where("type", "chat").
Where("enabled", true).Order("last_used_at ASC")
}
res := session.First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
apiURL := apiKey.ApiURL
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
// 更新 API KEY 的最后使用时间
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())

View File

@@ -27,9 +27,15 @@ func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
// List 数据列表
func (h *MenuHandler) List(c *gin.Context) {
index := h.GetBool(c, "index")
var items []model.Menu
var list = make([]vo.Menu, 0)
res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
session := h.DB.Session(&gorm.Session{})
session = session.Where("enabled", true)
if index {
session = session.Where("id IN ?", h.App.SysConfig.IndexNavs)
}
res := session.Order("sort_num ASC").Find(&items)
if res.Error == nil {
for _, item := range items {
var product vo.Menu

View File

@@ -406,7 +406,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
session = session.Where("progress >= ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
}
@@ -456,15 +456,44 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
resp.ERROR(c, "记录不存在")
return
}
// remove job recode
res := h.DB.Delete(&job)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
tx := h.DB.Begin()
if err := tx.Delete(&job).Error; err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
// refund power
err := tx.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
var user model.User
h.DB.Where("id = ?", job.UserId).First(&user)
err = tx.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "mid-journey",
Remark: fmt.Sprintf("绘画任务失败退回算力。任务ID%s", job.TaskId),
CreatedAt: time.Now(),
}).Error
if err != nil {
tx.Rollback()
resp.ERROR(c, err.Error())
return
}
tx.Commit()
// remove image
err := h.uploader.GetUploadHandler().Delete(job.ImgURL)
err = h.uploader.GetUploadHandler().Delete(job.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}

345
api/handler/suno_handler.go Normal file
View File

@@ -0,0 +1,345 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service/oss"
"geekai/service/suno"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"time"
)
type SunoHandler struct {
BaseHandler
service *suno.Service
uploader *oss.UploaderManager
}
func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, uploader *oss.UploaderManager) *SunoHandler {
return &SunoHandler{
BaseHandler: BaseHandler{
App: app,
DB: db,
},
service: service,
uploader: uploader,
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *SunoHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SunoHandler) Create(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
Instrumental bool `json:"instrumental"`
Lyrics string `json:"lyrics"`
Model string `json:"model"`
Tags string `json:"tags"`
Title string `json:"title"`
Type int `json:"type"`
RefTaskId string `json:"ref_task_id"` // 续写的任务id
ExtendSecs int `json:"extend_secs"` // 续写秒数
RefSongId string `json:"ref_song_id"` // 续写的歌曲id
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// 插入数据库
job := model.SunoJob{
UserId: int(h.GetLoginUserId(c)),
Prompt: data.Prompt,
Instrumental: data.Instrumental,
ModelName: data.Model,
Tags: data.Tags,
Title: data.Title,
Type: data.Type,
RefSongId: data.RefSongId,
RefTaskId: data.RefTaskId,
ExtendSecs: data.ExtendSecs,
Power: h.App.SysConfig.SunoPower,
}
if data.Lyrics != "" {
job.Prompt = data.Lyrics
}
tx := h.DB.Create(&job)
if tx.Error != nil {
resp.ERROR(c, tx.Error.Error())
return
}
// 创建任务
h.service.PushTask(types.SunoTask{
Id: job.Id,
UserId: job.UserId,
Type: job.Type,
Title: job.Title,
RefTaskId: data.RefTaskId,
RefSongId: data.RefSongId,
ExtendSecs: data.ExtendSecs,
Prompt: job.Prompt,
Tags: data.Tags,
Model: data.Model,
Instrumental: data.Instrumental,
})
// update user's power
tx = h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
user, _ := h.GetLoginUser(c)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: job.Power,
Balance: user.Power - job.Power,
Mark: types.PowerSub,
Model: job.ModelName,
Remark: fmt.Sprintf("Suno 文生歌曲,%s", job.ModelName),
CreatedAt: time.Now(),
})
}
client := h.service.Clients.Get(uint(job.UserId))
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
func (h *SunoHandler) List(c *gin.Context) {
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
session := h.DB.Session(&gorm.Session{}).Where("user_id", userId)
// 统计总数
var total int64
session.Model(&model.SunoJob{}).Count(&total)
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var list []model.SunoJob
err := session.Order("id desc").Find(&list).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 初始化续写关系
songIds := make([]string, 0)
for _, v := range list {
if v.RefTaskId != "" {
songIds = append(songIds, v.RefSongId)
}
}
var tasks []model.SunoJob
h.DB.Where("song_id IN ?", songIds).Find(&tasks)
songMap := make(map[string]model.SunoJob)
for _, t := range tasks {
songMap[t.SongId] = t
}
// 转换为 VO
items := make([]vo.SunoJob, 0)
for _, v := range list {
var item vo.SunoJob
err = utils.CopyObject(v, &item)
if err != nil {
continue
}
item.CreatedAt = v.CreatedAt.Unix()
if s, ok := songMap[v.RefSongId]; ok {
item.RefSong = map[string]interface{}{
"id": s.Id,
"title": s.Title,
"cover": s.CoverURL,
"audio": s.AudioURL,
}
}
items = append(items, item)
}
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
func (h *SunoHandler) Remove(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
var job model.SunoJob
err := h.DB.Where("id = ?", id).Where("user_id", userId).First(&job).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 删除任务
h.DB.Delete(&job)
// 删除文件
_ = h.uploader.GetUploadHandler().Delete(job.CoverURL)
_ = h.uploader.GetUploadHandler().Delete(job.AudioURL)
}
func (h *SunoHandler) Publish(c *gin.Context) {
id := h.GetInt(c, "id", 0)
userId := h.GetLoginUserId(c)
publish := h.GetBool(c, "publish")
err := h.DB.Model(&model.SunoJob{}).Where("id", id).Where("user_id", userId).UpdateColumn("publish", publish).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
func (h *SunoHandler) Update(c *gin.Context) {
var data struct {
Id int `json:"id"`
Title string `json:"title"`
Cover string `json:"cover"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id == 0 || data.Title == "" || data.Cover == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
userId := h.GetLoginUserId(c)
var item model.SunoJob
if err := h.DB.Where("id", data.Id).Where("user_id", userId).First(&item).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
item.Title = data.Title
item.CoverURL = data.Cover
if err := h.DB.Updates(&item).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}
// Detail 歌曲详情
func (h *SunoHandler) Detail(c *gin.Context) {
songId := c.Query("song_id")
if songId == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
var item model.SunoJob
if err := h.DB.Where("song_id", songId).First(&item).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
// 读取用户信息
var user model.User
if err := h.DB.Where("id", item.UserId).First(&user).Error; err != nil {
resp.ERROR(c, err.Error())
return
}
var itemVo vo.SunoJob
if err := utils.CopyObject(item, &itemVo); err != nil {
resp.ERROR(c, err.Error())
return
}
itemVo.CreatedAt = item.CreatedAt.Unix()
itemVo.User = map[string]interface{}{
"nickname": user.Nickname,
"avatar": user.Avatar,
}
resp.SUCCESS(c, itemVo)
}
// Play 增加歌曲播放次数
func (h *SunoHandler) Play(c *gin.Context) {
songId := c.Query("song_id")
if songId == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
h.DB.Model(&model.SunoJob{}).Where("song_id", songId).UpdateColumn("play_times", gorm.Expr("play_times + ?", 1))
}
const genLyricTemplate = `
你是一位才华横溢的作曲家,拥有丰富的情感和细腻的笔触,你对文字有着独特的感悟力,能将各种情感和意境巧妙地融入歌词中。
请以【%s】为主题创作一首歌曲歌曲时间不要太短3分钟左右不要输出任何解释性的内容。
输出格式如下:
歌曲名称
第一节:
{{歌词内容}}
副歌:
{{歌词内容}}
第二节:
{{歌词内容}}
副歌:
{{歌词内容}}
尾声:
{{歌词内容}}
`
// Lyric 生成歌词
func (h *SunoHandler) Lyric(c *gin.Context) {
var data struct {
Prompt string `json:"prompt"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
content, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(genLyricTemplate, data.Prompt), "gpt-4o-mini")
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, content)
}