mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 16:06:38 +08:00
feat: 将操作拆分成单独的模型
This commit is contained in:
parent
d5ffaf2502
commit
3d10c9f090
@ -94,17 +94,30 @@ var ModelRatio = map[string]float64{
|
|||||||
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||||
}
|
}
|
||||||
|
|
||||||
var ModelPrice = map[string]float64{
|
var DefaultModelPrice = map[string]float64{
|
||||||
"gpt-4-gizmo-*": 0.1,
|
"gpt-4-gizmo-*": 0.1,
|
||||||
"mj_imagine": 0.1,
|
"mj_imagine": 0.1,
|
||||||
"mj_variation": 0.1,
|
"mj_variation": 0.1,
|
||||||
"mj_reroll": 0.1,
|
"mj_reroll": 0.1,
|
||||||
"mj_blend": 0.1,
|
"mj_blend": 0.1,
|
||||||
"mj_describe": 0.05,
|
"mj_inpaint": 0.1,
|
||||||
"mj_upscale": 0.05,
|
"mj_zoom": 0.1,
|
||||||
|
"mj_shorten": 0.1,
|
||||||
|
"mj_high_variation": 0.1,
|
||||||
|
"mj_low_variation": 0.1,
|
||||||
|
"mj_pan": 0.1,
|
||||||
|
"mj_inpaint_pre": 0,
|
||||||
|
"mj_describe": 0.05,
|
||||||
|
"mj_upscale": 0.05,
|
||||||
|
"swap_face": 0.05,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ModelPrice = map[string]float64{}
|
||||||
|
|
||||||
func ModelPrice2JSONString() string {
|
func ModelPrice2JSONString() string {
|
||||||
|
if len(ModelPrice) == 0 {
|
||||||
|
ModelPrice = DefaultModelPrice
|
||||||
|
}
|
||||||
jsonBytes, err := json.Marshal(ModelPrice)
|
jsonBytes, err := json.Marshal(ModelPrice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SysError("error marshalling model price: " + err.Error())
|
SysError("error marshalling model price: " + err.Error())
|
||||||
@ -118,6 +131,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetModelPrice(name string, printErr bool) float64 {
|
func GetModelPrice(name string, printErr bool) float64 {
|
||||||
|
if len(ModelPrice) == 0 {
|
||||||
|
ModelPrice = DefaultModelPrice
|
||||||
|
}
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||||
name = "gpt-4-gizmo-*"
|
name = "gpt-4-gizmo-*"
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ const (
|
|||||||
MjActionBlend = "BLEND"
|
MjActionBlend = "BLEND"
|
||||||
MjActionUpscale = "UPSCALE"
|
MjActionUpscale = "UPSCALE"
|
||||||
MjActionVariation = "VARIATION"
|
MjActionVariation = "VARIATION"
|
||||||
|
MjActionReRoll = "REROLL"
|
||||||
MjActionInPaint = "INPAINT"
|
MjActionInPaint = "INPAINT"
|
||||||
MjActionInPaintPre = "INPAINT_PRE"
|
MjActionInPaintPre = "INPAINT_PRE"
|
||||||
MjActionZoom = "ZOOM"
|
MjActionZoom = "ZOOM"
|
||||||
@ -20,3 +21,20 @@ const (
|
|||||||
MjActionPan = "PAN"
|
MjActionPan = "PAN"
|
||||||
SwapFace = "SWAP_FACE"
|
SwapFace = "SWAP_FACE"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var MidjourneyModel2Action = map[string]string{
|
||||||
|
"mj_imagine": MjActionImagine,
|
||||||
|
"mj_describe": MjActionDescribe,
|
||||||
|
"mj_blend": MjActionBlend,
|
||||||
|
"mj_upscale": MjActionUpscale,
|
||||||
|
"mj_variation": MjActionVariation,
|
||||||
|
"mj_reroll": MjActionReRoll,
|
||||||
|
"mj_inpaint": MjActionInPaint,
|
||||||
|
"mj_inpaint_pre": MjActionInPaintPre,
|
||||||
|
"mj_zoom": MjActionZoom,
|
||||||
|
"mj_shorten": MjActionShorten,
|
||||||
|
"mj_high_variation": MjActionHighVariation,
|
||||||
|
"mj_low_variation": MjActionLowVariation,
|
||||||
|
"mj_pan": MjActionPan,
|
||||||
|
"swap_face": SwapFace,
|
||||||
|
}
|
||||||
|
@ -18,137 +18,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*func UpdateMidjourneyTask() {
|
|
||||||
//revocer
|
|
||||||
//imageModel := "midjourney"
|
|
||||||
ctx := context.TODO()
|
|
||||||
imageModel := "midjourney"
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
log.Printf("UpdateMidjourneyTask panic: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
for {
|
|
||||||
time.Sleep(time.Duration(15) * time.Second)
|
|
||||||
tasks := model.GetAllUnFinishTasks()
|
|
||||||
if len(tasks) != 0 {
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
|
||||||
for _, task := range tasks {
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
|
|
||||||
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
|
|
||||||
task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
|
|
||||||
task.Status = "FAILURE"
|
|
||||||
task.Progress = "100%"
|
|
||||||
err := task.Update()
|
|
||||||
if err != nil {
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
|
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
|
|
||||||
if err != nil {
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置超时时间
|
|
||||||
timeout := time.Second * 5
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
|
|
||||||
// 使用带有超时的 context 创建新的请求
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
|
|
||||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("UpdateMidjourneyTask error: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
log.Printf("responseBody: %s", string(responseBody))
|
|
||||||
var responseItem MidjourneyDto
|
|
||||||
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
|
|
||||||
err = json.Unmarshal(responseBody, &responseItem)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field MidjourneyDto.status of type string") {
|
|
||||||
var responseWithoutStatus MidjourneyWithoutStatus
|
|
||||||
var responseStatus MidjourneyStatus
|
|
||||||
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
|
|
||||||
err2 := json.Unmarshal(responseBody, &responseStatus)
|
|
||||||
if err1 == nil && err2 == nil {
|
|
||||||
jsonData, err3 := json.Marshal(responseWithoutStatus)
|
|
||||||
if err3 != nil {
|
|
||||||
log.Printf("UpdateMidjourneyTask error1: %v", err3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err4 := json.Unmarshal(jsonData, &responseStatus)
|
|
||||||
if err4 != nil {
|
|
||||||
log.Printf("UpdateMidjourneyTask error2: %v", err4)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
responseItem.Status = strconv.Itoa(responseStatus.Status)
|
|
||||||
} else {
|
|
||||||
log.Printf("UpdateMidjourneyTask error3: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Printf("UpdateMidjourneyTask error4: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
task.Code = 1
|
|
||||||
task.Progress = responseItem.Progress
|
|
||||||
task.PromptEn = responseItem.PromptEn
|
|
||||||
task.State = responseItem.State
|
|
||||||
task.SubmitTime = responseItem.SubmitTime
|
|
||||||
task.StartTime = responseItem.StartTime
|
|
||||||
task.FinishTime = responseItem.FinishTime
|
|
||||||
task.ImageUrl = responseItem.ImageUrl
|
|
||||||
task.Status = responseItem.Status
|
|
||||||
task.FailReason = responseItem.FailReason
|
|
||||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
|
||||||
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
|
|
||||||
task.Progress = "100%"
|
|
||||||
err = model.CacheUpdateUserQuota(task.UserId)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("error update user quota cache: " + err.Error())
|
|
||||||
} else {
|
|
||||||
modelRatio := common.GetModelRatio(imageModel)
|
|
||||||
groupRatio := common.GetGroupRatio("default")
|
|
||||||
ratio := modelRatio * groupRatio
|
|
||||||
quota := int(ratio * 1 * 1000)
|
|
||||||
if quota != 0 {
|
|
||||||
err := model.IncreaseUserQuota(task.UserId, quota)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("fail to increase user quota")
|
|
||||||
}
|
|
||||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = task.Update()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("UpdateMidjourneyTask error5: %v", err)
|
|
||||||
}
|
|
||||||
log.Printf("UpdateMidjourneyTask success: %v", task)
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
func UpdateMidjourneyTaskBulk() {
|
func UpdateMidjourneyTaskBulk() {
|
||||||
//imageModel := "midjourney"
|
//imageModel := "midjourney"
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
|
@ -4,12 +4,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/channel/ai360"
|
"one-api/relay/channel/ai360"
|
||||||
"one-api/relay/channel/moonshot"
|
"one-api/relay/channel/moonshot"
|
||||||
"one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/models/list
|
// https://platform.openai.com/docs/api-reference/models/list
|
||||||
@ -59,8 +60,8 @@ func init() {
|
|||||||
IsBlocking: false,
|
IsBlocking: false,
|
||||||
})
|
})
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
for i := 0; i < constant.APITypeDummy; i++ {
|
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
||||||
if i == constant.APITypeAIProxyLibrary {
|
if i == relayconstant.APITypeAIProxyLibrary {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
adaptor := relay.GetAdaptor(i)
|
adaptor := relay.GetAdaptor(i)
|
||||||
@ -100,6 +101,17 @@ func init() {
|
|||||||
Parent: nil,
|
Parent: nil,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||||
|
openAIModels = append(openAIModels, OpenAIModels{
|
||||||
|
Id: modelName,
|
||||||
|
Object: "model",
|
||||||
|
Created: 1626777600,
|
||||||
|
OwnedBy: "midjourney",
|
||||||
|
Permission: permission,
|
||||||
|
Root: modelName,
|
||||||
|
Parent: nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
openAIModelsMap = make(map[string]OpenAIModels)
|
openAIModelsMap = make(map[string]OpenAIModels)
|
||||||
for _, model := range openAIModels {
|
for _, model := range openAIModels {
|
||||||
openAIModelsMap[model.Id] = model
|
openAIModelsMap[model.Id] = model
|
||||||
|
@ -60,7 +60,7 @@ func Relay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourney(c *gin.Context) {
|
func RelayMidjourney(c *gin.Context) {
|
||||||
relayMode := constant.Path2RelayModeMidjourney(c.Request.URL.Path)
|
relayMode := c.GetInt("relay_mode")
|
||||||
var err *dto.MidjourneyResponse
|
var err *dto.MidjourneyResponse
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeMidjourneyNotify:
|
case relayconstant.RelayModeMidjourneyNotify:
|
||||||
@ -73,13 +73,15 @@ func RelayMidjourney(c *gin.Context) {
|
|||||||
//err = relayMidjourneySubmit(c, relayMode)
|
//err = relayMidjourneySubmit(c, relayMode)
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
statusCode := http.StatusBadRequest
|
||||||
if err.Code == 30 {
|
if err.Code == 30 {
|
||||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||||
|
statusCode = http.StatusTooManyRequests
|
||||||
}
|
}
|
||||||
c.JSON(429, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"error": fmt.Sprintf("%s %s", err.Description, err.Result),
|
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
||||||
"type": "upstream_error",
|
"type": "upstream_error",
|
||||||
"code": err.Code,
|
"code": err.Code,
|
||||||
})
|
})
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
||||||
|
@ -1,5 +1,12 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
//type SimpleMjRequest struct {
|
||||||
|
// Prompt string `json:"prompt"`
|
||||||
|
// CustomId string `json:"customId"`
|
||||||
|
// Action string `json:"action"`
|
||||||
|
// Content string `json:"content"`
|
||||||
|
//}
|
||||||
|
|
||||||
type MidjourneyRequest struct {
|
type MidjourneyRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
CustomId string `json:"customId"`
|
CustomId string `json:"customId"`
|
||||||
|
@ -100,16 +100,16 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
token, err := model.ValidateUserToken(key)
|
token, err := model.ValidateUserToken(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !userEnabled {
|
if !userEnabled {
|
||||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("id", token.UserId)
|
c.Set("id", token.UserId)
|
||||||
@ -129,7 +129,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set("channelId", parts[1])
|
c.Set("channelId", parts[1])
|
||||||
} else {
|
} else {
|
||||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -23,32 +27,58 @@ func Distribute() func(c *gin.Context) {
|
|||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err = model.GetChannelById(id, true)
|
channel, err = model.GetChannelById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
shouldSelectChannel := true
|
||||||
// Select a channel for the user
|
// Select a channel for the user
|
||||||
var modelRequest ModelRequest
|
var modelRequest ModelRequest
|
||||||
var err error
|
var err error
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
|
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
|
||||||
// Midjourney
|
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
|
||||||
if modelRequest.Model == "" {
|
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
|
||||||
modelRequest.Model = "midjourney"
|
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
|
||||||
|
relayMode == relayconstant.RelayModeMidjourneyNotify {
|
||||||
|
shouldSelectChannel = false
|
||||||
|
} else {
|
||||||
|
midjourneyRequest := dto.MidjourneyRequest{}
|
||||||
|
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
|
||||||
|
if err != nil {
|
||||||
|
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
||||||
|
if mjErr != nil {
|
||||||
|
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if midjourneyModel == "" {
|
||||||
|
if !success {
|
||||||
|
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
// task fetch, task fetch by condition, notify
|
||||||
|
shouldSelectChannel = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelRequest.Model = midjourneyModel
|
||||||
}
|
}
|
||||||
|
c.Set("relay_mode", relayMode)
|
||||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||||
@ -87,60 +117,61 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if tokenModelLimit != nil {
|
if tokenModelLimit != nil {
|
||||||
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
|
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
|
||||||
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// token model limit is empty, all models are not allowed
|
// token model limit is empty, all models are not allowed
|
||||||
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||||
c.Set("group", userGroup)
|
c.Set("group", userGroup)
|
||||||
|
if shouldSelectChannel {
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
message = "数据库一致性已被破坏,请联系管理员"
|
||||||
|
}
|
||||||
|
// 如果错误,而且渠道为空,说明是没有可用渠道
|
||||||
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if channel == nil {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set("channel", channel.Type)
|
||||||
|
c.Set("channel_id", channel.Id)
|
||||||
|
c.Set("channel_name", channel.Name)
|
||||||
|
ban := true
|
||||||
|
// parse *int to bool
|
||||||
|
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||||
|
ban = false
|
||||||
|
}
|
||||||
|
c.Set("auto_ban", ban)
|
||||||
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
|
// TODO: api_version统一
|
||||||
|
switch channel.Type {
|
||||||
|
case common.ChannelTypeAzure:
|
||||||
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeXunfei:
|
||||||
|
c.Set("api_version", channel.Other)
|
||||||
|
//case common.ChannelTypeAIProxyLibrary:
|
||||||
|
// c.Set("library_id", channel.Other)
|
||||||
|
case common.ChannelTypeGemini:
|
||||||
|
c.Set("api_version", channel.Other)
|
||||||
|
case common.ChannelTypeAli:
|
||||||
|
c.Set("plugin", channel.Other)
|
||||||
}
|
}
|
||||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
|
||||||
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if channel == nil {
|
|
||||||
abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Set("channel", channel.Type)
|
|
||||||
c.Set("channel_id", channel.Id)
|
|
||||||
c.Set("channel_name", channel.Name)
|
|
||||||
ban := true
|
|
||||||
// parse *int to bool
|
|
||||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
|
||||||
ban = false
|
|
||||||
}
|
|
||||||
c.Set("auto_ban", ban)
|
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
|
||||||
// TODO: api_version统一
|
|
||||||
switch channel.Type {
|
|
||||||
case common.ChannelTypeAzure:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeXunfei:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
//case common.ChannelTypeAIProxyLibrary:
|
|
||||||
// c.Set("library_id", channel.Other)
|
|
||||||
case common.ChannelTypeGemini:
|
|
||||||
c.Set("api_version", channel.Other)
|
|
||||||
case common.ChannelTypeAli:
|
|
||||||
c.Set("plugin", channel.Other)
|
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||||
@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
common.LogError(c.Request.Context(), message)
|
common.LogError(c.Request.Context(), message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
|
||||||
|
c.JSON(statusCode, gin.H{
|
||||||
|
"description": description,
|
||||||
|
"type": "new_api_error",
|
||||||
|
"code": code,
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
common.LogError(c.Request.Context(), description)
|
||||||
|
}
|
||||||
|
@ -21,23 +21,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
var DefaultModelPrice = map[string]float64{
|
|
||||||
"mj_imagine": 0.1,
|
|
||||||
"mj_variation": 0.1,
|
|
||||||
"mj_reroll": 0.1,
|
|
||||||
"mj_blend": 0.1,
|
|
||||||
"mj_inpaint": 0.1,
|
|
||||||
"mj_zoom": 0.1,
|
|
||||||
"mj_shorten": 0.1,
|
|
||||||
"mj_high_variation": 0.1,
|
|
||||||
"mj_low_variation": 0.1,
|
|
||||||
"mj_pan": 0.1,
|
|
||||||
"mj_inpaint_pre": 0,
|
|
||||||
"mj_describe": 0.05,
|
|
||||||
"mj_upscale": 0.05,
|
|
||||||
"swap_face": 0.05,
|
|
||||||
}
|
|
||||||
|
|
||||||
func RelayMidjourneyImage(c *gin.Context) {
|
func RelayMidjourneyImage(c *gin.Context) {
|
||||||
taskId := c.Param("id")
|
taskId := c.Param("id")
|
||||||
midjourneyTask := model.GetByOnlyMJId(taskId)
|
midjourneyTask := model.GetByOnlyMJId(taskId)
|
||||||
@ -221,10 +204,9 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
||||||
imageModel := "midjourney"
|
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
channelType := c.GetInt("channel")
|
//channelType := c.GetInt("channel")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
@ -236,7 +218,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
}
|
}
|
||||||
|
|
||||||
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||||||
mjErr := coverPlusActionToNormalAction(&midjRequest)
|
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
|
||||||
if mjErr != nil {
|
if mjErr != nil {
|
||||||
return mjErr
|
return mjErr
|
||||||
}
|
}
|
||||||
@ -270,11 +252,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
if midjRequest.Content == "" {
|
if midjRequest.Content == "" {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
||||||
}
|
}
|
||||||
params := convertSimpleChangeParams(midjRequest.Content)
|
params := service.ConvertSimpleChangeParams(midjRequest.Content)
|
||||||
if params == nil {
|
if params == nil {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
|
||||||
}
|
}
|
||||||
mjId = params.ID
|
mjId = params.TaskId
|
||||||
midjRequest.Action = params.Action
|
midjRequest.Action = params.Action
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
|
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
|
||||||
if midjRequest.MaskBase64 == "" {
|
if midjRequest.MaskBase64 == "" {
|
||||||
@ -294,18 +276,21 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||||||
}
|
}
|
||||||
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
||||||
|
}
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
c.Set("channel_id", originTask.ChannelId)
|
c.Set("channel_id", originTask.ChannelId)
|
||||||
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
||||||
}
|
}
|
||||||
midjRequest.Prompt = originTask.Prompt
|
midjRequest.Prompt = originTask.Prompt
|
||||||
|
|
||||||
if channelType == common.ChannelTypeMidjourneyPlus {
|
//if channelType == common.ChannelTypeMidjourneyPlus {
|
||||||
// plus
|
// // plus
|
||||||
} else {
|
//} else {
|
||||||
// 普通版渠道
|
// // 普通版渠道
|
||||||
|
//
|
||||||
}
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
if midjRequest.Action == constant.MjActionInPaintPre {
|
if midjRequest.Action == constant.MjActionInPaintPre {
|
||||||
@ -313,54 +298,52 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
}
|
}
|
||||||
|
|
||||||
// map model name
|
// map model name
|
||||||
modelMapping := c.GetString("model_mapping")
|
//modelMapping := c.GetString("model_mapping")
|
||||||
isModelMapped := false
|
//isModelMapped := false
|
||||||
if modelMapping != "" {
|
//if modelMapping != "" {
|
||||||
modelMap := make(map[string]string)
|
// modelMap := make(map[string]string)
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
// err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
// //return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||||
return &dto.MidjourneyResponse{
|
// return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
// Code: 4,
|
||||||
Description: "unmarshal_model_mapping_failed",
|
// Description: "unmarshal_model_mapping_failed",
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
if modelMap[imageModel] != "" {
|
// if modelMap[imageModel] != "" {
|
||||||
imageModel = modelMap[imageModel]
|
// imageModel = modelMap[imageModel]
|
||||||
isModelMapped = true
|
// isModelMapped = true
|
||||||
}
|
// }
|
||||||
}
|
//}
|
||||||
|
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
//baseURL := common.ChannelBaseURLs[channelType]
|
||||||
requestURL := c.Request.URL.String()
|
requestURL := c.Request.URL.String()
|
||||||
|
|
||||||
if c.GetString("base_url") != "" {
|
baseURL := c.GetString("base_url")
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
}
|
|
||||||
|
|
||||||
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||||||
|
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
|
||||||
var requestBody io.Reader
|
var requestBody io.Reader
|
||||||
if isModelMapped {
|
//if isModelMapped {
|
||||||
jsonStr, err := json.Marshal(midjRequest)
|
// jsonStr, err := json.Marshal(midjRequest)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return &dto.MidjourneyResponse{
|
// return &dto.MidjourneyResponse{
|
||||||
Code: 4,
|
// Code: 4,
|
||||||
Description: "marshal_text_request_failed",
|
// Description: "marshal_text_request_failed",
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
// requestBody = bytes.NewBuffer(jsonStr)
|
||||||
} else {
|
//} else {
|
||||||
requestBody = c.Request.Body
|
//}
|
||||||
}
|
requestBody = c.Request.Body
|
||||||
|
|
||||||
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
|
modelName := service.CoverActionToModelName(midjRequest.Action)
|
||||||
modelPrice := common.GetModelPrice(mjAction, true)
|
modelPrice := common.GetModelPrice(modelName, true)
|
||||||
// 如果没有配置价格,则使用默认价格
|
// 如果没有配置价格,则使用默认价格
|
||||||
if modelPrice == -1 {
|
if modelPrice == -1 {
|
||||||
defaultPrice, ok := DefaultModelPrice[mjAction]
|
defaultPrice, ok := common.DefaultModelPrice[modelName]
|
||||||
if !ok {
|
if !ok {
|
||||||
modelPrice = 0.1
|
modelPrice = 0.1
|
||||||
} else {
|
} else {
|
||||||
@ -433,7 +416,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
@ -558,85 +541,3 @@ type taskChangeParams struct {
|
|||||||
Action string
|
Action string
|
||||||
Index int
|
Index int
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertSimpleChangeParams(content string) *taskChangeParams {
|
|
||||||
split := strings.Split(content, " ")
|
|
||||||
if len(split) != 2 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
action := strings.ToLower(split[1])
|
|
||||||
changeParams := &taskChangeParams{}
|
|
||||||
changeParams.ID = split[0]
|
|
||||||
|
|
||||||
if action[0] == 'u' {
|
|
||||||
changeParams.Action = "UPSCALE"
|
|
||||||
} else if action[0] == 'v' {
|
|
||||||
changeParams.Action = "VARIATION"
|
|
||||||
} else if action == "r" {
|
|
||||||
changeParams.Action = "REROLL"
|
|
||||||
return changeParams
|
|
||||||
} else {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
index, err := strconv.Atoi(action[1:2])
|
|
||||||
if err != nil || index < 1 || index > 4 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
changeParams.Index = index
|
|
||||||
return changeParams
|
|
||||||
}
|
|
||||||
|
|
||||||
func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
|
|
||||||
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
|
|
||||||
customId := midjRequest.CustomId
|
|
||||||
if customId == "" {
|
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
|
|
||||||
}
|
|
||||||
splits := strings.Split(customId, "::")
|
|
||||||
var action string
|
|
||||||
if splits[1] == "JOB" {
|
|
||||||
action = splits[2]
|
|
||||||
} else {
|
|
||||||
action = splits[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
if action == "" {
|
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
|
||||||
}
|
|
||||||
if strings.Contains(action, "upsample") {
|
|
||||||
index, err := strconv.Atoi(splits[3])
|
|
||||||
if err != nil {
|
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
|
||||||
}
|
|
||||||
midjRequest.Index = index
|
|
||||||
midjRequest.Action = constant.MjActionUpscale
|
|
||||||
} else if strings.Contains(action, "variation") {
|
|
||||||
midjRequest.Index = 1
|
|
||||||
if action == "variation" {
|
|
||||||
index, err := strconv.Atoi(splits[3])
|
|
||||||
if err != nil {
|
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
|
||||||
}
|
|
||||||
midjRequest.Index = index
|
|
||||||
midjRequest.Action = constant.MjActionVariation
|
|
||||||
} else if action == "low_variation" {
|
|
||||||
midjRequest.Action = constant.MjActionLowVariation
|
|
||||||
} else if action == "high_variation" {
|
|
||||||
midjRequest.Action = constant.MjActionHighVariation
|
|
||||||
}
|
|
||||||
} else if strings.Contains(action, "pan") {
|
|
||||||
midjRequest.Action = constant.MjActionPan
|
|
||||||
midjRequest.Index = 1
|
|
||||||
} else if action == "Outpaint" || action == "CustomZoom" {
|
|
||||||
midjRequest.Action = constant.MjActionZoom
|
|
||||||
midjRequest.Index = 1
|
|
||||||
} else if action == "Inpaint" {
|
|
||||||
midjRequest.Action = constant.MjActionInPaintPre
|
|
||||||
midjRequest.Index = 1
|
|
||||||
} else {
|
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
135
service/midjourney.go
Normal file
135
service/midjourney.go
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
|
relayconstant "one-api/relay/constant"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CoverActionToModelName(mjAction string) string {
|
||||||
|
modelName := "mj_" + strings.ToLower(mjAction)
|
||||||
|
return modelName
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
|
||||||
|
action := ""
|
||||||
|
if relayMode == relayconstant.RelayModeMidjourneyAction {
|
||||||
|
// plus request
|
||||||
|
err := CoverPlusActionToNormalAction(midjRequest)
|
||||||
|
if err != nil {
|
||||||
|
return "", err, false
|
||||||
|
}
|
||||||
|
action = midjRequest.Action
|
||||||
|
} else {
|
||||||
|
switch relayMode {
|
||||||
|
case relayconstant.RelayModeMidjourneyImagine:
|
||||||
|
action = constant.MjActionImagine
|
||||||
|
case relayconstant.RelayModeMidjourneyDescribe:
|
||||||
|
action = constant.MjActionDescribe
|
||||||
|
case relayconstant.RelayModeMidjourneyBlend:
|
||||||
|
action = constant.MjActionBlend
|
||||||
|
case relayconstant.RelayModeMidjourneyShorten:
|
||||||
|
action = constant.MjActionShorten
|
||||||
|
case relayconstant.RelayModeMidjourneyChange:
|
||||||
|
action = midjRequest.Action
|
||||||
|
case relayconstant.RelayModeMidjourneyModal:
|
||||||
|
action = constant.MjActionInPaint
|
||||||
|
case relayconstant.RelayModeMidjourneySimpleChange:
|
||||||
|
params := ConvertSimpleChangeParams(midjRequest.Content)
|
||||||
|
if params == nil {
|
||||||
|
return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
|
||||||
|
}
|
||||||
|
action = params.Action
|
||||||
|
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
|
||||||
|
return "", nil, true
|
||||||
|
default:
|
||||||
|
return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action"), false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelName := CoverActionToModelName(action)
|
||||||
|
return modelName, nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
|
||||||
|
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
|
||||||
|
customId := midjRequest.CustomId
|
||||||
|
if customId == "" {
|
||||||
|
return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
|
||||||
|
}
|
||||||
|
splits := strings.Split(customId, "::")
|
||||||
|
var action string
|
||||||
|
if splits[1] == "JOB" {
|
||||||
|
action = splits[2]
|
||||||
|
} else {
|
||||||
|
action = splits[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if action == "" {
|
||||||
|
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
||||||
|
}
|
||||||
|
if strings.Contains(action, "upsample") {
|
||||||
|
index, err := strconv.Atoi(splits[3])
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
||||||
|
}
|
||||||
|
midjRequest.Index = index
|
||||||
|
midjRequest.Action = constant.MjActionUpscale
|
||||||
|
} else if strings.Contains(action, "variation") {
|
||||||
|
midjRequest.Index = 1
|
||||||
|
if action == "variation" {
|
||||||
|
index, err := strconv.Atoi(splits[3])
|
||||||
|
if err != nil {
|
||||||
|
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
||||||
|
}
|
||||||
|
midjRequest.Index = index
|
||||||
|
midjRequest.Action = constant.MjActionVariation
|
||||||
|
} else if action == "low_variation" {
|
||||||
|
midjRequest.Action = constant.MjActionLowVariation
|
||||||
|
} else if action == "high_variation" {
|
||||||
|
midjRequest.Action = constant.MjActionHighVariation
|
||||||
|
}
|
||||||
|
} else if strings.Contains(action, "pan") {
|
||||||
|
midjRequest.Action = constant.MjActionPan
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else if action == "Outpaint" || action == "CustomZoom" {
|
||||||
|
midjRequest.Action = constant.MjActionZoom
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else if action == "Inpaint" {
|
||||||
|
midjRequest.Action = constant.MjActionInPaintPre
|
||||||
|
midjRequest.Index = 1
|
||||||
|
} else {
|
||||||
|
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
|
||||||
|
split := strings.Split(content, " ")
|
||||||
|
if len(split) != 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
action := strings.ToLower(split[1])
|
||||||
|
changeParams := &dto.MidjourneyRequest{}
|
||||||
|
changeParams.TaskId = split[0]
|
||||||
|
|
||||||
|
if action[0] == 'u' {
|
||||||
|
changeParams.Action = "UPSCALE"
|
||||||
|
} else if action[0] == 'v' {
|
||||||
|
changeParams.Action = "VARIATION"
|
||||||
|
} else if action == "r" {
|
||||||
|
changeParams.Action = "REROLL"
|
||||||
|
return changeParams
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
index, err := strconv.Atoi(action[1:2])
|
||||||
|
if err != nil || index < 1 || index > 4 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
changeParams.Index = index
|
||||||
|
return changeParams
|
||||||
|
}
|
@ -96,10 +96,25 @@ const EditChannel = (props) => {
|
|||||||
localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
|
localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
localModels = ['midjourney'];
|
localModels = ['mj_imagine', 'mj_variation', 'mj_reroll', 'mj_blend', 'mj_upscale', 'mj_describe'];
|
||||||
break;
|
break;
|
||||||
case 5:
|
case 5:
|
||||||
localModels = ['midjourney'];
|
localModels = [
|
||||||
|
'swap_face',
|
||||||
|
'mj_imagine',
|
||||||
|
'mj_variation',
|
||||||
|
'mj_reroll',
|
||||||
|
'mj_blend',
|
||||||
|
'mj_upscale',
|
||||||
|
'mj_describe',
|
||||||
|
'mj_zoom',
|
||||||
|
'mj_shorten',
|
||||||
|
'mj_inpaint_pre',
|
||||||
|
'mj_inpaint_pre',
|
||||||
|
'mj_high_variation',
|
||||||
|
'mj_low_variation',
|
||||||
|
'mj_pan',
|
||||||
|
];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
setInputs((inputs) => ({...inputs, models: localModels}));
|
setInputs((inputs) => ({...inputs, models: localModels}));
|
||||||
|
Loading…
Reference in New Issue
Block a user