merge v4.0.6

RockYang
2024-06-03 14:22:08 +08:00
53 changed files with 4872 additions and 1163 deletions

View File

@@ -175,4 +175,6 @@ type SystemConfig struct {
EnableContext bool `json:"enable_context,omitempty"`
ContextDeep int `json:"context_deep,omitempty"`
SdNegPrompt string `json:"sd_neg_prompt"` // SD default negative prompt
}

View File

@@ -55,9 +55,10 @@ type SdTaskParams struct {
NegPrompt string `json:"neg_prompt"` // negative prompt
Steps int `json:"steps"` // iteration steps, default 20
Sampler string `json:"sampler"` // sampler
FaceFix bool `json:"face_fix"` // face restoration
CfgScale float32 `json:"cfg_scale"` // CFG guidance scale, default 7
Seed int64 `json:"seed"` // random seed
Scheduler string `json:"scheduler"`
FaceFix bool `json:"face_fix"` // face restoration
CfgScale float32 `json:"cfg_scale"` // CFG guidance scale, default 7
Seed int64 `json:"seed"` // random seed
Height int `json:"height"`
Width int `json:"width"`
HdFix bool `json:"hd_fix"` // enable hi-res fix

View File

@@ -15,6 +15,7 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)

View File

@@ -144,18 +144,13 @@ func (h *ChatRoleHandler) Set(c *gin.Context) {
}
func (h *ChatRoleHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.DB.Where("id = ?", data.Id).Delete(&model.ChatRole{})
res := h.DB.Where("id", id).Delete(&model.ChatRole{})
if res.Error != nil {
resp.ERROR(c, "删除失败!")
return

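The Remove handler above now reads the role id from the query string instead of binding a JSON body (matching the route change from POST to GET in main.go further down). Below is a minimal, hedged sketch of the same flow with plain gin and gorm; strconv parsing stands in for the project's h.GetInt helper, and the trimmed ChatRole type, route path, and JSON response shape are assumptions of the sketch, not the project's actual code.

```go
package main

import (
	"net/http"
	"strconv"

	"github.com/gin-gonic/gin"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// ChatRole is a trimmed stand-in for model.ChatRole.
type ChatRole struct {
	Id   uint `gorm:"primarykey"`
	Name string
}

// removeChatRole sketches the refactored Remove flow: read ?id= from the
// query string, validate it, then delete by primary key.
func removeChatRole(db *gorm.DB) gin.HandlerFunc {
	return func(c *gin.Context) {
		id, err := strconv.Atoi(c.DefaultQuery("id", "0")) // stand-in for h.GetInt(c, "id", 0)
		if err != nil || id <= 0 {
			c.JSON(http.StatusOK, gin.H{"code": 1, "message": "invalid args"})
			return
		}
		// Where("id", id) is GORM shorthand for Where("id = ?", id)
		if res := db.Where("id", id).Delete(&ChatRole{}); res.Error != nil {
			c.JSON(http.StatusOK, gin.H{"code": 1, "message": "delete failed"})
			return
		}
		c.JSON(http.StatusOK, gin.H{"code": 0})
	}
}

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&ChatRole{})

	r := gin.Default()
	r.GET("/api/admin/role/remove", removeChatRole(db)) // illustrative route path
	_ = r.Run(":8080")
}
```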
View File

@@ -13,6 +13,7 @@ import (
"geekai/store/vo"
"geekai/utils"
"geekai/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -32,7 +33,7 @@ func (h *ChatModelHandler) List(c *gin.Context) {
var res *gorm.DB
// if the user is not logged in, load all open models
if !h.IsLogin(c) {
res = h.DB.Where("enabled = ?", true).Where("open =?", true).Order("sort_num ASC").Find(&items)
res = h.DB.Where("enabled", true).Where("open", true).Order("sort_num ASC").Find(&items)
} else {
user, _ := h.GetLoginUser(c)
var models []int
@@ -43,7 +44,7 @@ func (h *ChatModelHandler) List(c *gin.Context) {
}
// query the models the user is allowed to access, plus all open models
res = h.DB.Where("enabled = ?", true).Where(
h.DB.Where("id IN ?", models).Or("open =?", true),
h.DB.Where("id IN ?", models).Or("open", true),
).Order("sort_num ASC").Find(&items)
}

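The List query above leans on two GORM features: the shorthand condition Where("enabled", true) (equivalent to Where("enabled = ?", true)) and passing a nested *gorm.DB as a group condition so the OR stays inside parentheses. A runnable sketch of the resulting query follows; the trimmed ChatModel type and the in-memory SQLite database are assumptions of the example.

```go
package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// ChatModel is a trimmed stand-in for model.ChatModel.
type ChatModel struct {
	Id      uint
	Enabled bool
	Open    bool
	SortNum int
}

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&ChatModel{})

	models := []int{1, 2, 3} // ids the logged-in user is allowed to use
	var items []ChatModel
	// roughly: WHERE enabled = true AND (id IN (1,2,3) OR open = true) ORDER BY sort_num ASC
	tx := db.Where("enabled", true).Where(
		db.Where("id IN ?", models).Or("open", true),
	).Order("sort_num ASC").Find(&items)
	fmt.Println(tx.Error, len(items))
}
```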
View File

@@ -9,13 +9,14 @@ package chatimpl
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"strings"
@@ -44,14 +45,9 @@ func (h *ChatHandler) sendAzureMessage(
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()
@@ -73,10 +69,7 @@ func (h *ChatHandler) sendAzureMessage(
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // data parsing error
logger.Error(err, line)
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
break
return errors.New(line)
}
if len(responseBody.Choices) == 0 {
@@ -203,11 +196,10 @@ func (h *ChatHandler) sendAzureMessage(
if strings.Contains(res.Error.Message, "maximum context length") {
logger.Error(res.Error.Message)
utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
h.App.ChatContexts.Delete(session.ChatId)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
utils.ReplyMessage(ws, "请求 Azure API 失败:"+res.Error.Message)
return fmt.Errorf("请求 Azure API 失败:%v", res.Error)
}
}

View File

@@ -9,13 +9,14 @@ package chatimpl
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"net/http"
@@ -61,14 +62,8 @@ func (h *ChatHandler) sendBaiduMessage(
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()

View File

@@ -173,7 +173,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
if err != nil {
logger.Error(err)
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
utils.ReplyMessage(client, err.Error())
} else {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
logger.Infof("回答完毕: %v", message.Content)
@@ -195,8 +195,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
var user model.User
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil {
utils.ReplyMessage(ws, "未授权用户,您正在进行非法操作!")
return res.Error
return errors.New("未授权用户,您正在进行非法操作!")
}
var userVo vo.User
err := utils.CopyObject(user, &userVo)
@@ -206,28 +205,22 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
if userVo.Status == false {
utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
return nil
return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!")
}
if userVo.Power < session.Model.Power {
utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余算力(%d已不足以支付当前模型的单次对话需要消耗的算力%d", userVo.Power, session.Model.Power))
utils.ReplyMessage(ws, ErrImg)
return nil
return fmt.Errorf("您当前剩余算力(%d已不足以支付当前模型的单次对话需要消耗的算力%d", userVo.Power, session.Model.Power)
}
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
return nil
return errors.New("您的账号已经过期,请联系管理员!")
}
// check whether the prompt length exceeds the maximum context length allowed by the current model
promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
if promptTokens > session.Model.MaxContext {
utils.ReplyMessage(ws, "对话内容超出了当前模型允许的最大上下文长度!")
return nil
return errors.New("对话内容超出了当前模型允许的最大上下文长度!")
}
var req = types.ApiRequest{
@@ -286,9 +279,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
break
default:
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
utils.ReplyMessage(ws, ErrImg)
return nil
return fmt.Errorf("不支持的平台:%s", session.Model.Platform)
}
// load the chat context
@@ -402,10 +393,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
case types.QWen:
return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: fmt.Sprintf("Not supported platform: %s", session.Model.Platform),
})
return nil
}

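The recurring change across handler.go and the vendor-specific send* functions in this commit is that validation and API failures are now returned as errors instead of being written to the WebSocket inline, leaving ChatHandle as the single place that reports them and closes the stream. A minimal sketch of that shape, with hypothetical reply helpers standing in for utils.ReplyMessage / utils.ReplyChunkMessage and a simulated vendor failure:

```go
package main

import (
	"errors"
	"fmt"
)

// wsClient is a hypothetical stand-in for *types.WsClient.
type wsClient struct{}

// reply and replyEnd stand in for utils.ReplyMessage and
// utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd}).
func reply(ws *wsClient, msg string) { fmt.Println("reply:", msg) }
func replyEnd(ws *wsClient)          { fmt.Println("chunk: WsEnd") }

// sendMessage mirrors the new pattern: validation failures become returned
// errors instead of ad-hoc WebSocket replies followed by "return nil".
func sendMessage(ws *wsClient, power, cost int) error {
	if power < cost {
		return fmt.Errorf("remaining power (%d) is not enough for this model (%d)", power, cost)
	}
	// ... the concrete vendor sender (OpenAI, Azure, Baidu, ...) would run here ...
	return errors.New("no available key") // simulated downstream failure
}

// chatHandle is the single caller that decides what the client sees.
func chatHandle(ws *wsClient) {
	if err := sendMessage(ws, 10, 1); err != nil {
		reply(ws, err.Error()) // surface the error text to the client
		return
	}
	replyEnd(ws) // close the stream normally
}

func main() {
	chatHandle(&wsClient{})
}
```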
View File

@@ -9,13 +9,14 @@ package chatimpl
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
"github.com/golang-jwt/jwt/v5"
"html/template"
"io"
@@ -45,14 +46,8 @@ func (h *ChatHandler) sendChatGLMMessage(
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()

View File

@@ -11,6 +11,7 @@ import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"geekai/core/types"
"geekai/store/model"
@@ -76,10 +77,7 @@ func (h *ChatHandler) sendOpenAiMessage(
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil { // data parsing error
logger.Error(err, line)
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
break
return errors.New(line)
}
if len(responseBody.Choices) == 0 { // Fixed: tolerate the empty first line returned by the Azure API
continue

View File

@@ -9,13 +9,14 @@ package chatimpl
import (
"bufio"
"context"
"encoding/json"
"fmt"
"geekai/core/types"
"geekai/store/model"
"geekai/store/vo"
"geekai/utils"
"context"
"encoding/json"
"fmt"
"github.com/syndtr/goleveldb/leveldb/errors"
"html/template"
"io"
"strings"
@@ -59,14 +60,8 @@ func (h *ChatHandler) sendQWenMessage(
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
} else {
logger.Error(err)
return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
}
utils.ReplyMessage(ws, ErrorMsg)
utils.ReplyMessage(ws, ErrImg)
return err
} else {
defer response.Body.Close()

View File

@@ -83,7 +83,7 @@ func (h *ChatHandler) sendXunFeiMessage(
res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey)
}
// use the least recently used key
if res.Error != nil {
if apiKey.Id == 0 {
res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
}
if res.Error != nil {

View File

@@ -92,8 +92,7 @@ func (h *DallJobHandler) preCheck(c *gin.Context) bool {
resp.NotAuth(c)
return false
}
if user.Power < h.App.SysConfig.SdPower {
if user.Power < h.App.SysConfig.DallPower {
resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
return false
}

View File

@@ -148,7 +148,6 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
contentType := response.Header.Get("Content-Type")
if strings.Contains(contentType, "text/event-stream") {
// read chunk messages in a loop
var message = types.Message{}
scanner := bufio.NewScanner(response.Body)
var isNew = true
for scanner.Scan() {
@@ -159,26 +158,26 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
var responseBody = types.ApiResponse{}
err = json.Unmarshal([]byte(line[6:]), &responseBody)
if err != nil || len(responseBody.Choices) == 0 { // data parsing error
return fmt.Errorf("error with decode data: %v", err)
if err != nil { // data parsing error
return fmt.Errorf("error with decode data: %v", line)
}
// initialize role
if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
message.Role = responseBody.Choices[0].Delta.Role
if len(responseBody.Choices) == 0 { // Fixed: tolerate the empty first line returned by the Azure API
continue
} else if responseBody.Choices[0].FinishReason != "" {
break // output finished or was interrupted
} else {
if isNew {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(client, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
}
if responseBody.Choices[0].FinishReason == "stop" {
break
}
if isNew {
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
isNew = false
}
utils.ReplyChunkMessage(client, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
})
} // end for
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
@@ -206,6 +205,25 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
}
}
// deduct power
res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power))
if res.Error == nil {
// record the power consumption log
var u model.User
h.DB.Where("id", userId).First(&u)
h.DB.Create(&model.PowerLog{
UserId: u.Id,
Username: u.Username,
Type: types.PowerConsume,
Amount: chatModel.Power,
Mark: types.PowerSub,
Balance: u.Power,
Model: chatModel.Value,
Remark: fmt.Sprintf("AI绘制思维导图模型名称%s, ", chatModel.Value),
CreatedAt: time.Now(),
})
}
return nil
}

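The power deduction added to MarkMapHandler uses gorm.Expr so the decrement runs in SQL (power = power - ?) rather than as a read-modify-write in Go, and then records a PowerLog row for auditing. A compact, runnable sketch of the same pattern; the trimmed User/PowerLog models, the SQLite driver, and the sample values are assumptions of the sketch.

```go
package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// User and PowerLog are trimmed stand-ins for the project's models.
type User struct {
	Id       uint
	Username string
	Power    int
}

type PowerLog struct {
	Id      uint
	UserId  uint
	Amount  int
	Balance int
	Remark  string
}

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&User{}, &PowerLog{})
	db.Create(&User{Id: 1, Username: "demo", Power: 100})

	cost := 2
	// atomic decrement performed by the database itself
	res := db.Model(&User{}).Where("id", 1).UpdateColumn("power", gorm.Expr("power - ?", cost))
	if res.Error == nil {
		var u User
		db.Where("id", 1).First(&u)
		db.Create(&PowerLog{UserId: u.Id, Amount: cost, Balance: u.Power,
			Remark: fmt.Sprintf("AI mind-map generation, model: %s", "demo-model")})
		fmt.Println("remaining power:", u.Power) // 98
	}
}
```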
View File

@@ -171,7 +171,6 @@ func main() {
// email service
fx.Provide(service.NewSmtpService),
// License service
// WeChat bot service
fx.Provide(wx.NewWeChatBot),
@@ -186,7 +185,8 @@ func main() {
// MidJourney service pool
fx.Provide(mj.NewServicePool),
fx.Invoke(func(pool *mj.ServicePool) {
fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) {
pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs)
if pool.HasAvailableService() {
pool.DownloadImages()
pool.CheckTaskNotify()
@@ -196,7 +196,8 @@ func main() {
// Stable Diffusion bot
fx.Provide(sd.NewServicePool),
fx.Invoke(func(pool *sd.ServicePool) {
fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) {
pool.InitServices(config.SdConfigs)
if pool.HasAvailableService() {
pool.CheckTaskNotify()
pool.CheckTaskStatus()
@@ -329,7 +330,7 @@ func main() {
group.POST("save", h.Save)
group.POST("sort", h.Sort)
group.POST("set", h.Set)
group.POST("remove", h.Remove)
group.GET("remove", h.Remove)
}),
fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
group := s.Engine.Group("/api/admin/reward/")

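main.go now asks fx to inject *types.AppConfig into the Invoke callbacks, so the MidJourney and Stable Diffusion pools are constructed empty and then wired from configuration via InitServices. A runnable sketch of that dependency-injection shape with go.uber.org/fx, using hypothetical Config/Pool types in place of types.AppConfig and the service pools:

```go
package main

import (
	"fmt"

	"go.uber.org/fx"
)

// Config and Pool are hypothetical stand-ins for types.AppConfig and mj.ServicePool.
type Config struct{ MjPlusConfigs []string }

type Pool struct{ services []string }

// InitServices mirrors the new two-step wiring: construct the pool empty,
// then populate it from configuration once fx has both values.
func (p *Pool) InitServices(configs []string) { p.services = configs }

func main() {
	app := fx.New(
		fx.NopLogger,
		fx.Provide(func() *Config { return &Config{MjPlusConfigs: []string{"mj-plus-0"}} }),
		fx.Provide(func() *Pool { return &Pool{} }),
		// fx resolves every parameter of the Invoke function from the container,
		// so adding *Config to the signature is all it takes to receive it here.
		fx.Invoke(func(pool *Pool, cfg *Config) {
			pool.InitServices(cfg.MjPlusConfigs)
			fmt.Println("initialized services:", pool.services)
		}),
	)
	if err := app.Err(); err != nil {
		panic(err)
	}
}
```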
View File

@@ -8,6 +8,8 @@ package dalle
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"geekai/service"
@@ -16,8 +18,6 @@ import (
"geekai/store"
"geekai/store/model"
"geekai/utils"
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"time"
@@ -261,7 +261,7 @@ func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string,
if res.Error != nil {
return "", err
}
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Failed})
s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished})
return imgURL, nil
}
@@ -294,7 +294,7 @@ func (s *Service) CheckTaskStatus() {
Balance: user.Power + job.Power,
Mark: types.PowerAdd,
Model: "dall-e-3",
Remark: fmt.Sprintf("任务失败退回算力。任务ID%d", job.Id),
Remark: fmt.Sprintf("任务失败退回算力。任务ID%d", job.Id),
CreatedAt: time.Now(),
})
}

View File

@@ -16,6 +16,7 @@ import (
"geekai/store"
"geekai/store/model"
"github.com/go-redis/redis/v8"
"strings"
"time"
"gorm.io/gorm"
@@ -33,38 +34,10 @@ type ServicePool struct {
var logger = logger2.GetLogger()
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
for k, config := range appConfig.MjPlusConfigs {
if config.Enabled == false {
continue
}
cli := NewPlusClient(config)
name := fmt.Sprintf("mj-plus-service-%d", k)
plusService := NewService(name, taskQueue, notifyQueue, db, cli)
go func() {
plusService.Run()
}()
services = append(services, plusService)
}
for k, config := range appConfig.MjProxyConfigs {
if config.Enabled == false {
continue
}
cli := NewProxyClient(config)
name := fmt.Sprintf("mj-proxy-service-%d", k)
proxyService := NewService(name, taskQueue, notifyQueue, db, cli)
go func() {
proxyService.Run()
}()
services = append(services, proxyService)
}
return &ServicePool{
taskQueue: taskQueue,
notifyQueue: notifyQueue,
@@ -75,6 +48,42 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
}
}
func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) {
// stop old services
for _, s := range p.services {
s.Stop()
}
p.services = make([]*Service, 0)
for k, config := range plusConfigs {
if config.Enabled == false {
continue
}
cli := NewPlusClient(config)
name := fmt.Sprintf("mj-plus-service-%d", k)
plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() {
plusService.Run()
}()
p.services = append(p.services, plusService)
}
// for mid-journey proxy
for k, config := range proxyConfigs {
if config.Enabled == false {
continue
}
cli := NewProxyClient(config)
name := fmt.Sprintf("mj-proxy-service-%d", k)
proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
go func() {
proxyService.Run()
}()
p.services = append(p.services, proxyService)
}
}
func (p *ServicePool) CheckTaskNotify() {
go func() {
for {
@@ -111,17 +120,23 @@ func (p *ServicePool) DownloadImages() {
}
logger.Infof("try to download image: %s", v.OrgURL)
var imgURL string
var err error
if servicePlus := p.getService(v.ChannelId); servicePlus != nil {
task, _ := servicePlus.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = GetImageHash(task.Buttons[0].CustomId)
}
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, false)
} else {
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
mjService := p.getService(v.ChannelId)
if mjService == nil {
logger.Errorf("Invalid task: %+v", v)
continue
}
task, _ := mjService.Client.QueryTask(v.TaskId)
if len(task.Buttons) > 0 {
v.Hash = GetImageHash(task.Buttons[0].CustomId)
}
// if the returned image URL is a Discord address, download it through a proxy
proxy := false
if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
proxy = true
}
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, proxy)
if err != nil {
logger.Errorf("error with download image %s, %v", v.OrgURL, err)
continue

View File

@@ -28,6 +28,7 @@ type Service struct {
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
running bool
}
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
@@ -37,12 +38,13 @@ func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.Red
taskQueue: taskQueue,
notifyQueue: notifyQueue,
Client: cli,
running: true,
}
}
func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.Name)
for {
for s.running {
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
@@ -125,6 +127,10 @@ func (s *Service) Run() {
}
}
func (s *Service) Stop() {
s.running = false
}
type CBReq struct {
Id string `json:"id"`
Action string `json:"action"`

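Both service types gain a running flag and a Stop() method so InitServices can retire old consumers before building new ones from the latest config. The sketch below mirrors that loop but stores the flag in an atomic.Bool to keep the example race-free (the diff uses a plain bool), and a closure stands in for the Redis task queue.

```go
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// Service is a trimmed stand-in for the mj/sd Service types in this commit.
type Service struct {
	Name    string
	running atomic.Bool // the diff uses a plain bool; atomic keeps this sketch race-free
}

func NewService(name string) *Service {
	s := &Service{Name: name}
	s.running.Store(true)
	return s
}

// Run consumes tasks until Stop() flips the flag, mirroring `for s.running { ... }`.
func (s *Service) Run(popTask func() (string, error)) {
	fmt.Printf("Starting job consumer for %s\n", s.Name)
	for s.running.Load() {
		task, err := popTask() // stand-in for s.taskQueue.LPop(&task)
		if err != nil {
			time.Sleep(100 * time.Millisecond) // back off; the real services wait longer
			continue
		}
		fmt.Printf("%s handling task %s\n", s.Name, task)
	}
	fmt.Printf("%s stopped\n", s.Name)
}

func (s *Service) Stop() { s.running.Store(false) }

func main() {
	svc := NewService("mj-plus-service-0")
	tasks := []string{"demo-task-1", "demo-task-2"}
	go svc.Run(func() (string, error) {
		if len(tasks) == 0 {
			return "", fmt.Errorf("queue is empty")
		}
		t := tasks[0]
		tasks = tasks[1:]
		return t, nil
	})
	time.Sleep(150 * time.Millisecond)
	svc.Stop()
	time.Sleep(300 * time.Millisecond) // give the consumer time to observe the flag
}
```

As in the diff, Stop() only takes effect the next time the loop re-checks the flag, i.e. after the current pop or back-off completes.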
View File

@@ -25,28 +25,14 @@ type ServicePool struct {
notifyQueue *store.RedisQueue
db *gorm.DB
Clients *types.LMap[uint, *types.WsClient] // UserId => Client
uploader *oss.UploaderManager
levelDB *store.LevelDB
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig, levelDB *store.LevelDB) *ServicePool {
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool {
services := make([]*Service, 0)
taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli)
// create mj client and service
for _, config := range appConfig.SdConfigs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf("StableDifffusion Service-%s", config.Model)
service := NewService(name, config, taskQueue, notifyQueue, db, manager, levelDB)
// run sd service
go func() {
service.Run()
}()
services = append(services, service)
}
return &ServicePool{
taskQueue: taskQueue,
@@ -54,6 +40,32 @@ func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderMa
services: services,
db: db,
Clients: types.NewLMap[uint, *types.WsClient](),
uploader: manager,
levelDB: levelDB,
}
}
func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) {
// stop old services
for _, s := range p.services {
s.Stop()
}
p.services = make([]*Service, 0)
for k, config := range configs {
if config.Enabled == false {
continue
}
// create sd service
name := fmt.Sprintf(" sd-service-%d", k)
service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB)
// run sd service
go func() {
service.Run()
}()
p.services = append(p.services, service)
}
}

View File

@@ -33,6 +33,7 @@ type Service struct {
uploadManager *oss.UploaderManager
name string // service name
leveldb *store.LevelDB
running bool // running state
}
func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service {
@@ -46,11 +47,13 @@ func NewService(name string, config types.StableDiffusionConfig, taskQueue *stor
db: db,
leveldb: levelDB,
uploadManager: manager,
running: true,
}
}
func (s *Service) Run() {
for {
logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name)
for s.running {
var task types.SdTask
err := s.taskQueue.LPop(&task)
if err != nil {
@@ -94,6 +97,10 @@ func (s *Service) Run() {
}
}
func (s *Service) Stop() {
s.running = false
}
// Txt2ImgReq is the text-to-image request entity
type Txt2ImgReq struct {
Prompt string `json:"prompt"`
@@ -104,6 +111,7 @@ type Txt2ImgReq struct {
Width int `json:"width"`
Height int `json:"height"`
SamplerName string `json:"sampler_name"`
Scheduler string `json:"scheduler"`
EnableHr bool `json:"enable_hr,omitempty"`
HrScale int `json:"hr_scale,omitempty"`
HrUpscaler string `json:"hr_upscaler,omitempty"`
@@ -137,6 +145,7 @@ func (s *Service) Txt2Img(task types.SdTask) error {
Width: task.Params.Width,
Height: task.Params.Height,
SamplerName: task.Params.Sampler,
Scheduler: task.Params.Scheduler,
ForceTaskId: task.Params.TaskId,
}
if task.Params.Seed > 0 {
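The new Scheduler field flows from SdTaskParams through Txt2ImgReq into the JSON body POSTed to the Stable Diffusion WebUI's txt2img API (an /sdapi/v1/txt2img-style endpoint in AUTOMATIC1111-compatible servers; the exact endpoint is not shown in this diff). A small sketch of the resulting payload, with a trimmed request type and illustrative sampler/scheduler values:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// txt2ImgReq is a trimmed stand-in for the service's Txt2ImgReq.
type txt2ImgReq struct {
	Prompt      string  `json:"prompt"`
	Steps       int     `json:"steps"`
	CfgScale    float32 `json:"cfg_scale"`
	SamplerName string  `json:"sampler_name"`
	Scheduler   string  `json:"scheduler"` // new in this commit: forwarded from SdTaskParams.Scheduler
	Width       int     `json:"width"`
	Height      int     `json:"height"`
	Seed        int64   `json:"seed,omitempty"` // only set when task.Params.Seed > 0
}

func main() {
	req := txt2ImgReq{
		Prompt:      "a lighthouse at dusk, volumetric light",
		Steps:       20, // iteration steps, default 20
		CfgScale:    7,  // CFG guidance scale, default 7
		SamplerName: "DPM++ 2M",
		Scheduler:   "Karras",
		Width:       1024,
		Height:      1024,
	}
	body, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(body)) // this JSON is what gets POSTed to the txt2img endpoint
}
```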