mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-10 08:03:41 +08:00
feat: 将操作拆分成单独的模型
This commit is contained in:
@@ -21,23 +21,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var DefaultModelPrice = map[string]float64{
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_inpaint": 0.1,
|
||||
"mj_zoom": 0.1,
|
||||
"mj_shorten": 0.1,
|
||||
"mj_high_variation": 0.1,
|
||||
"mj_low_variation": 0.1,
|
||||
"mj_pan": 0.1,
|
||||
"mj_inpaint_pre": 0,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05,
|
||||
"swap_face": 0.05,
|
||||
}
|
||||
|
||||
func RelayMidjourneyImage(c *gin.Context) {
|
||||
taskId := c.Param("id")
|
||||
midjourneyTask := model.GetByOnlyMJId(taskId)
|
||||
@@ -221,10 +204,9 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
||||
}
|
||||
|
||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
||||
imageModel := "midjourney"
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
//channelType := c.GetInt("channel")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
@@ -236,7 +218,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
|
||||
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||||
mjErr := coverPlusActionToNormalAction(&midjRequest)
|
||||
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
|
||||
if mjErr != nil {
|
||||
return mjErr
|
||||
}
|
||||
@@ -270,11 +252,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
if midjRequest.Content == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
||||
}
|
||||
params := convertSimpleChangeParams(midjRequest.Content)
|
||||
params := service.ConvertSimpleChangeParams(midjRequest.Content)
|
||||
if params == nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
|
||||
}
|
||||
mjId = params.ID
|
||||
mjId = params.TaskId
|
||||
midjRequest.Action = params.Action
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
|
||||
if midjRequest.MaskBase64 == "" {
|
||||
@@ -294,18 +276,21 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
||||
}
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
||||
}
|
||||
midjRequest.Prompt = originTask.Prompt
|
||||
|
||||
if channelType == common.ChannelTypeMidjourneyPlus {
|
||||
// plus
|
||||
} else {
|
||||
// 普通版渠道
|
||||
|
||||
}
|
||||
//if channelType == common.ChannelTypeMidjourneyPlus {
|
||||
// // plus
|
||||
//} else {
|
||||
// // 普通版渠道
|
||||
//
|
||||
//}
|
||||
}
|
||||
|
||||
if midjRequest.Action == constant.MjActionInPaintPre {
|
||||
@@ -313,54 +298,52 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_model_mapping_failed",
|
||||
}
|
||||
}
|
||||
if modelMap[imageModel] != "" {
|
||||
imageModel = modelMap[imageModel]
|
||||
isModelMapped = true
|
||||
}
|
||||
}
|
||||
//modelMapping := c.GetString("model_mapping")
|
||||
//isModelMapped := false
|
||||
//if modelMapping != "" {
|
||||
// modelMap := make(map[string]string)
|
||||
// err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
// if err != nil {
|
||||
// //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
// return &dto.MidjourneyResponse{
|
||||
// Code: 4,
|
||||
// Description: "unmarshal_model_mapping_failed",
|
||||
// }
|
||||
// }
|
||||
// if modelMap[imageModel] != "" {
|
||||
// imageModel = modelMap[imageModel]
|
||||
// isModelMapped = true
|
||||
// }
|
||||
//}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
//baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
baseURL := c.GetString("base_url")
|
||||
|
||||
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||||
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
var requestBody io.Reader
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(midjRequest)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "marshal_text_request_failed",
|
||||
}
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
//if isModelMapped {
|
||||
// jsonStr, err := json.Marshal(midjRequest)
|
||||
// if err != nil {
|
||||
// return &dto.MidjourneyResponse{
|
||||
// Code: 4,
|
||||
// Description: "marshal_text_request_failed",
|
||||
// }
|
||||
// }
|
||||
// requestBody = bytes.NewBuffer(jsonStr)
|
||||
//} else {
|
||||
//}
|
||||
requestBody = c.Request.Body
|
||||
|
||||
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
|
||||
modelPrice := common.GetModelPrice(mjAction, true)
|
||||
modelName := service.CoverActionToModelName(midjRequest.Action)
|
||||
modelPrice := common.GetModelPrice(modelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if modelPrice == -1 {
|
||||
defaultPrice, ok := DefaultModelPrice[mjAction]
|
||||
defaultPrice, ok := common.DefaultModelPrice[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
@@ -433,7 +416,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
@@ -558,85 +541,3 @@ type taskChangeParams struct {
|
||||
Action string
|
||||
Index int
|
||||
}
|
||||
|
||||
func convertSimpleChangeParams(content string) *taskChangeParams {
|
||||
split := strings.Split(content, " ")
|
||||
if len(split) != 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
action := strings.ToLower(split[1])
|
||||
changeParams := &taskChangeParams{}
|
||||
changeParams.ID = split[0]
|
||||
|
||||
if action[0] == 'u' {
|
||||
changeParams.Action = "UPSCALE"
|
||||
} else if action[0] == 'v' {
|
||||
changeParams.Action = "VARIATION"
|
||||
} else if action == "r" {
|
||||
changeParams.Action = "REROLL"
|
||||
return changeParams
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(action[1:2])
|
||||
if err != nil || index < 1 || index > 4 {
|
||||
return nil
|
||||
}
|
||||
changeParams.Index = index
|
||||
return changeParams
|
||||
}
|
||||
|
||||
func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
|
||||
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
|
||||
customId := midjRequest.CustomId
|
||||
if customId == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
|
||||
}
|
||||
splits := strings.Split(customId, "::")
|
||||
var action string
|
||||
if splits[1] == "JOB" {
|
||||
action = splits[2]
|
||||
} else {
|
||||
action = splits[1]
|
||||
}
|
||||
|
||||
if action == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
||||
}
|
||||
if strings.Contains(action, "upsample") {
|
||||
index, err := strconv.Atoi(splits[3])
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
||||
}
|
||||
midjRequest.Index = index
|
||||
midjRequest.Action = constant.MjActionUpscale
|
||||
} else if strings.Contains(action, "variation") {
|
||||
midjRequest.Index = 1
|
||||
if action == "variation" {
|
||||
index, err := strconv.Atoi(splits[3])
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
||||
}
|
||||
midjRequest.Index = index
|
||||
midjRequest.Action = constant.MjActionVariation
|
||||
} else if action == "low_variation" {
|
||||
midjRequest.Action = constant.MjActionLowVariation
|
||||
} else if action == "high_variation" {
|
||||
midjRequest.Action = constant.MjActionHighVariation
|
||||
}
|
||||
} else if strings.Contains(action, "pan") {
|
||||
midjRequest.Action = constant.MjActionPan
|
||||
midjRequest.Index = 1
|
||||
} else if action == "Outpaint" || action == "CustomZoom" {
|
||||
midjRequest.Action = constant.MjActionZoom
|
||||
midjRequest.Index = 1
|
||||
} else if action == "Inpaint" {
|
||||
midjRequest.Action = constant.MjActionInPaintPre
|
||||
midjRequest.Index = 1
|
||||
} else {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user