dalle image page is ready

This commit is contained in:
RockYang
2024-04-21 20:23:47 +08:00
parent ab8240613e
commit 60cf380f96
21 changed files with 899 additions and 603 deletions

View File

@@ -8,6 +8,7 @@ import (
"chatplus/utils/resp"
"context"
"fmt"
"github.com/chai2010/webp"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v5"
@@ -16,7 +17,6 @@ import (
"image"
"image/jpeg"
"io"
"log"
"net/http"
"os"
"runtime/debug"
@@ -215,6 +215,8 @@ func needLogin(c *gin.Context) bool {
c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/imgWall" ||
c.Request.URL.Path == "/api/sd/client" ||
c.Request.URL.Path == "/api/dall/imgWall" ||
c.Request.URL.Path == "/api/dall/client" ||
c.Request.URL.Path == "/api/config/get" ||
c.Request.URL.Path == "/api/product/list" ||
c.Request.URL.Path == "/api/menu/list" ||
@@ -328,6 +330,10 @@ func staticResourceMiddleware() gin.HandlerFunc {
// 解码图片
img, _, err := image.Decode(file)
// for .webp image
if err != nil {
img, err = webp.Decode(file)
}
if err != nil {
c.String(http.StatusInternalServerError, "Error decoding image")
return
@@ -344,7 +350,9 @@ func staticResourceMiddleware() gin.HandlerFunc {
var buffer bytes.Buffer
err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
if err != nil {
log.Fatal(err)
logger.Error(err)
c.String(http.StatusInternalServerError, err.Error())
return
}
// 设置图片缓存有效期为一年 (365天)

View File

@@ -59,3 +59,16 @@ type SdTaskParams struct {
HdScaleAlg string `json:"hd_scale_alg"` // 放大算法
HdSteps int `json:"hd_steps"` // 高清修复迭代步数
}
// DallTask DALL-E task
type DallTask struct {
JobId uint `json:"job_id"`
UserId uint `json:"user_id"`
Prompt string `json:"prompt"`
N int `json:"n"`
Quality string `json:"quality"`
Size string `json:"size"`
Style string `json:"style"`
Power int `json:"power"`
}

View File

@@ -32,9 +32,10 @@ require (
)
require (
github.com/chai2010/webp v1.1.1 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 // indirect
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 // indirect
)
require (

View File

@@ -12,6 +12,8 @@ github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@@ -241,6 +243,8 @@ golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERs
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=

View File

@@ -104,8 +104,10 @@ func (h *ChatHandler) sendOpenAiMessage(
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil {
toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)})
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
contents = append(contents, callMsg)
}
continue
}

View File

@@ -0,0 +1,260 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service/dalle"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
type DallJobHandler struct {
BaseHandler
redis *redis.Client
service *dalle.Service
uploader *oss.UploaderManager
}
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler {
return &DallJobHandler{
service: service,
uploader: manager,
BaseHandler: BaseHandler{
App: app,
DB: db,
},
}
}
// Client WebSocket 客户端,用于通知任务状态变更
func (h *DallJobHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
logger.Info("Invalid user ID")
c.Abort()
return
}
client := types.NewWsClient(ws)
h.service.Clients.Put(uint(userId), client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
client.Close()
h.service.Clients.Delete(uint(userId))
return
}
var message types.WsMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
// 心跳消息
if message.Type == "heartbeat" {
logger.Debug("收到 DallE 心跳消息:", message.Content)
continue
}
}
}()
}
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return false
}
if user.Power < h.App.SysConfig.SdPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}
return true
}
// Image 创建一个绘画任务
func (h *DallJobHandler) Image(c *gin.Context) {
if !h.preCheck(c) {
return
}
var data types.DallTask
if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
resp.ERROR(c, types.InvalidArgs)
return
}
idValue, _ := c.Get(types.LoginUserID)
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
job := model.DallJob{
UserId: uint(userId),
Prompt: data.Prompt,
Power: h.App.SysConfig.DallPower,
}
res := h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "error with save job: "+res.Error.Error())
return
}
h.service.PushTask(types.DallTask{
JobId: job.Id,
UserId: uint(userId),
Prompt: data.Prompt,
Quality: data.Quality,
Size: data.Size,
Style: data.Style,
Power: job.Power,
})
client := h.service.Clients.Get(job.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
resp.SUCCESS(c)
}
// ImgWall 照片墙
func (h *DallJobHandler) ImgWall(c *gin.Context) {
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
err, jobs := h.getData(true, 0, page, pageSize, true)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取 SD 任务列表
func (h *DallJobHandler) JobList(c *gin.Context) {
status := h.GetBool(c, "status")
userId := h.GetLoginUserId(c)
page := h.GetInt(c, "page", 0)
pageSize := h.GetInt(c, "page_size", 0)
publish := h.GetBool(c, "publish")
err, jobs := h.getData(status, userId, page, pageSize, publish)
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c, jobs)
}
// JobList 获取任务列表
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) {
session := h.DB.Session(&gorm.Session{})
if finish {
session = session.Where("progress = ?", 100).Order("id DESC")
} else {
session = session.Where("progress < ?", 100).Order("id ASC")
}
if userId > 0 {
session = session.Where("user_id = ?", userId)
}
if publish {
session = session.Where("publish", publish)
}
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
session = session.Offset(offset).Limit(pageSize)
}
var items []model.DallJob
res := session.Find(&items)
if res.Error != nil {
return res.Error, nil
}
var jobs = make([]vo.DallJob, 0)
for _, item := range items {
// delete failed or timeout tasks
if (item.Progress < 100 && time.Now().Sub(item.CreatedAt) > time.Minute*5) || item.Progress == -1 {
h.DB.Delete(&item)
}
var job vo.DallJob
err := utils.CopyObject(item, &job)
if err != nil {
continue
}
jobs = append(jobs, job)
}
return nil, jobs
}
// Remove remove task image
func (h *DallJobHandler) Remove(c *gin.Context) {
var data struct {
Id uint `json:"id"`
UserId uint `json:"user_id"`
ImgURL string `json:"img_url"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
// remove job recode
res := h.DB.Delete(&model.DallJob{Id: data.Id})
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
}
// remove image
err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
if err != nil {
logger.Error("remove image failed: ", err)
}
resp.SUCCESS(c)
}
// Publish 发布/取消发布图片到画廊显示
func (h *DallJobHandler) Publish(c *gin.Context) {
var data struct {
Id uint `json:"id"`
Action bool `json:"action"` // 发布动作true => 发布false => 取消分享
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true)
if res.Error != nil {
resp.ERROR(c, "更新数据库失败")
return
}
resp.SUCCESS(c)
}

View File

@@ -3,27 +3,35 @@ package handler
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/service/dalle"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"errors"
"fmt"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/imroc/req/v3"
"gorm.io/gorm"
"strings"
"time"
)
type FunctionHandler struct {
BaseHandler
config types.ChatPlusApiConfig
uploadManager *oss.UploaderManager
dallService *dalle.Service
}
func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppConfig, manager *oss.UploaderManager) *FunctionHandler {
func NewFunctionHandler(
server *core.AppServer,
db *gorm.DB,
config *types.AppConfig,
manager *oss.UploaderManager,
dallService *dalle.Service) *FunctionHandler {
return &FunctionHandler{
BaseHandler: BaseHandler{
App: server,
@@ -31,6 +39,7 @@ func NewFunctionHandler(server *core.AppServer, db *gorm.DB, config *types.AppCo
},
config: config.ApiConfig,
uploadManager: manager,
dallService: dallService,
}
}
@@ -151,30 +160,6 @@ func (h *FunctionHandler) ZaoBao(c *gin.Context) {
resp.SUCCESS(c, strings.Join(builder, "\n\n"))
}
type imgReq struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
}
type imgRes struct {
Created int64 `json:"created"`
Data []struct {
RevisedPrompt string `json:"revised_prompt"`
Url string `json:"url"`
} `json:"data"`
}
type ErrRes struct {
Error struct {
Code interface{} `json:"code"`
Message string `json:"message"`
Param interface{} `json:"param"`
Type string `json:"type"`
} `json:"error"`
}
// Dall3 DallE3 AI 绘图
func (h *FunctionHandler) Dall3(c *gin.Context) {
if err := h.checkAuth(c); err != nil {
@@ -190,85 +175,40 @@ func (h *FunctionHandler) Dall3(c *gin.Context) {
logger.Debugf("绘画参数:%+v", params)
var user model.User
tx := h.DB.Where("id = ?", params["user_id"]).First(&user)
if tx.Error != nil {
res := h.DB.Where("id = ?", params["user_id"]).First(&user)
if res.Error != nil {
resp.ERROR(c, "当前用户不存在!")
return
}
if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return
}
// create dall task
prompt := utils.InterfaceToString(params["prompt"])
// get image generation API KEY
var apiKey model.ApiKey
tx = h.DB.Where("platform = ?", types.OpenAI).Where("type = ?", "img").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
resp.ERROR(c, "获取绘图 API KEY 失败: "+tx.Error.Error())
job := model.DallJob{
UserId: user.Id,
Prompt: prompt,
Power: h.App.SysConfig.DallPower,
}
res = h.DB.Create(&job)
if res.Error != nil {
resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error())
return
}
// translate prompt
const translatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
pt, err := utils.OpenAIRequest(h.DB, fmt.Sprintf(translatePromptTemplate, params["prompt"]))
if err == nil {
logger.Debugf("翻译绘画提示词,原文:%s译文%s", prompt, pt)
prompt = pt
}
var res imgRes
var errRes ErrRes
var request *req.Request
if len(apiKey.ProxyURL) > 5 {
request = req.C().SetProxyURL(apiKey.ProxyURL).R()
} else {
request = req.C().R()
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := request.SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: "1024x1024",
}).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL)
if r.IsErrorState() {
resp.ERROR(c, "请求 OpenAI API 失败: "+errRes.Error.Message)
return
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
logger.Debugf("%+v", res)
// 存储图片
imgURL, err := h.uploadManager.GetUploadHandler().PutImg(res.Data[0].Url, false)
content, err := h.dallService.Image(types.DallTask{
JobId: job.Id,
UserId: user.Id,
Prompt: job.Prompt,
N: 1,
Quality: "standard",
Size: "1024x1024",
Style: "vivid",
Power: job.Power,
}, true)
if err != nil {
resp.ERROR(c, "下载图片失败: "+err.Error())
resp.ERROR(c, "任务执行失败:"+err.Error())
return
}
content := fmt.Sprintf("下面是根据您的描述创作的图片,它描绘了 【%s】 的场景。 \n\n![](%s)\n", prompt, imgURL)
// 更新用户算力
tx = h.DB.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", h.App.SysConfig.DallPower))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
h.DB.Where("id", user.Id).First(&u)
h.DB.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: h.App.SysConfig.DallPower,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(prompt, 10)),
CreatedAt: time.Now(),
})
}
resp.SUCCESS(c, content)
}

View File

@@ -65,7 +65,7 @@ func (h *SdJobHandler) Client(c *gin.Context) {
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
}
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
@@ -88,7 +88,7 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
// Image 创建一个绘画任务
func (h *SdJobHandler) Image(c *gin.Context) {
if !h.checkLimits(c) {
if !h.preCheck(c) {
return
}
@@ -298,7 +298,7 @@ func (h *SdJobHandler) Remove(c *gin.Context) {
client := h.pool.Clients.Get(data.UserId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
_ = client.Send([]byte(sd.Finished))
}
resp.SUCCESS(c)

View File

@@ -8,6 +8,7 @@ import (
"chatplus/handler/chatimpl"
logger2 "chatplus/logger"
"chatplus/service"
"chatplus/service/dalle"
"chatplus/service/mj"
"chatplus/service/oss"
"chatplus/service/payment"
@@ -153,6 +154,12 @@ func main() {
}),
fx.Provide(oss.NewUploaderManager),
fx.Provide(mj.NewService),
fx.Provide(dalle.NewService),
fx.Invoke(func(service *dalle.Service) {
service.Run()
service.CheckTaskNotify()
service.DownloadImages()
}),
// 邮件服务
fx.Provide(service.NewSmtpService),
@@ -441,6 +448,16 @@ func main() {
group := s.Engine.Group("/api/markMap/")
group.Any("client", h.Client)
}),
fx.Provide(handler.NewDallJobHandler),
fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
group := s.Engine.Group("/api/dall")
group.Any("client", h.Client)
group.POST("image", h.Image)
group.GET("jobs", h.JobList)
group.GET("imgWall", h.ImgWall)
group.POST("remove", h.Remove)
group.POST("publish", h.Publish)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
go func() {
err := s.Run(db)

View File

@@ -0,0 +1,259 @@
package dalle
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/service"
"chatplus/service/oss"
"chatplus/service/sd"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"time"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
// DALL-E 绘画服务
type Service struct {
httpClient *req.Client
db *gorm.DB
uploadManager *oss.UploaderManager
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
}
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
return &Service{
httpClient: req.C().SetTimeout(time.Minute * 3),
db: db,
taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli),
Clients: types.NewLMap[uint, *types.WsClient](),
uploadManager: manager,
}
}
// PushTask push a new mj task in to task queue
func (s *Service) PushTask(task types.DallTask) {
logger.Debugf("add a new MidJourney task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
func (s *Service) Run() {
go func() {
for {
var task types.DallTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
_, err = s.Image(task, false)
if err != nil {
logger.Errorf("error with image task: %v", err)
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": -1,
"err_msg": err.Error(),
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed})
}
}
}()
}
type imgReq struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
Quality string `json:"quality"`
Style string `json:"style"`
}
type imgRes struct {
Created int64 `json:"created"`
Data []struct {
RevisedPrompt string `json:"revised_prompt"`
Url string `json:"url"`
} `json:"data"`
}
type ErrRes struct {
Error struct {
Code interface{} `json:"code"`
Message string `json:"message"`
Param interface{} `json:"param"`
Type string `json:"type"`
} `json:"error"`
}
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
logger.Debugf("绘画参数:%+v", task)
prompt := task.Prompt
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
if err != nil {
return "", fmt.Errorf("error with translate prompt: %v", err)
}
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
var user model.User
s.db.Where("id", task.UserId).First(&user)
if user.Power < task.Power {
return "", errors.New("insufficient of power")
}
// get image generation API KEY
var apiKey model.ApiKey
tx := s.db.Where("platform", types.OpenAI).
Where("type", "img").
Where("enabled", true).
Order("last_used_at ASC").First(&apiKey)
if tx.Error != nil {
return "", fmt.Errorf("no available IMG api key: %v", tx.Error)
}
var res imgRes
var errRes ErrRes
if len(apiKey.ProxyURL) > 5 {
s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
}
logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(imgReq{
Model: "dall-e-3",
Prompt: prompt,
N: 1,
Size: "1024x1024",
Style: task.Style,
Quality: task.Quality,
}).
SetErrorResult(&errRes).
SetSuccessResult(&res).Post(apiKey.ApiURL)
if err != nil {
return "", fmt.Errorf("error with send request: %v", err)
}
if r.IsErrorState() {
return "", fmt.Errorf("error with send request: %v", errRes.Error)
}
// update the api key last use time
s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// update task progress
s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
"progress": 100,
"org_url": res.Data[0].Url,
"prompt": prompt,
})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished})
var content string
if sync {
imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", prompt, imgURL)
}
// 更新用户算力
tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
// 记录算力变化日志
if tx.Error == nil && tx.RowsAffected > 0 {
var u model.User
s.db.Where("id", user.Id).First(&u)
s.db.Create(&model.PowerLog{
UserId: user.Id,
Username: user.Username,
Type: types.PowerConsume,
Amount: task.Power,
Balance: u.Power,
Mark: types.PowerSub,
Model: "dall-e-3",
Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
CreatedAt: time.Now(),
})
}
return content, nil
}
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running DALL-E task notify checking ...")
for {
var message sd.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
client := s.Clients.Get(uint(message.UserId))
if client == nil {
continue
}
err = client.Send([]byte(message.Message))
if err != nil {
continue
}
}
}()
}
func (s *Service) DownloadImages() {
go func() {
var items []model.DallJob
for {
res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
if res.Error != nil {
continue
}
// download images
for _, v := range items {
if v.OrgURL == "" {
continue
}
logger.Infof("try to download image: %s", v.OrgURL)
imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL)
if err != nil {
logger.Error("error with download image: %s, error: %v", imgURL, err)
continue
}
}
time.Sleep(time.Second * 5)
}
}()
}
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
// sava image
imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false)
if err != nil {
return "", err
}
// update img_url
res := s.db.Model(&model.DallJob{Id: jobId}).UpdateColumn("img_url", imgURL)
if res.Error != nil {
return "", err
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed})
return imgURL, nil
}

View File

@@ -1,4 +1,4 @@
package service
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other elements. Please output directly in English without any explanation, within 150 words. The text to be rewritten is: [%s]"
const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]"
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"

View File

@@ -4,13 +4,13 @@ import "time"
type DallJob struct {
Id uint `gorm:"primarykey;column:id"`
UserId int
TaskId string
UserId uint
Prompt string
ImgURL string
Publish bool
Power int
Progress int
ErrMsg string
ImgURL string
OrgURL string
Publish bool
Power int
Progress int
ErrMsg string
CreatedAt time.Time
}

View File

@@ -1,14 +1,14 @@
package vo
type DallJob struct {
Id uint `json:"id"`
UserId int `json:"user_id"`
TaskId string `json:"task_id"`
Id uint `json:"id"`
UserId int `json:"user_id"`
Prompt string `json:"prompt"`
ImgURL string `json:"img_url"`
Publish bool `json:"publish"`
Power int `json:"power"`
Progress int `json:"progress"`
ErrMsg string `json:"err_msg"`
CreatedAt int64 `json:"created_at"`
ImgURL string `json:"img_url"`
OrgURL string `json:"org_url"`
Publish bool `json:"publish"`
Power int `json:"power"`
Progress int `json:"progress"`
ErrMsg string `json:"err_msg"`
CreatedAt int64 `json:"created_at"`
}