mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 13:23:42 +08:00 
			
		
		
		
	feat: 将操作拆分成单独的模型
This commit is contained in:
		@@ -94,17 +94,30 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"hunyuan":                   7.143,  // ¥0.1 / 1k tokens  // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var ModelPrice = map[string]float64{
 | 
			
		||||
	"gpt-4-gizmo-*": 0.1,
 | 
			
		||||
	"mj_imagine":    0.1,
 | 
			
		||||
	"mj_variation":  0.1,
 | 
			
		||||
	"mj_reroll":     0.1,
 | 
			
		||||
	"mj_blend":      0.1,
 | 
			
		||||
	"mj_describe":   0.05,
 | 
			
		||||
	"mj_upscale":    0.05,
 | 
			
		||||
var DefaultModelPrice = map[string]float64{
 | 
			
		||||
	"gpt-4-gizmo-*":     0.1,
 | 
			
		||||
	"mj_imagine":        0.1,
 | 
			
		||||
	"mj_variation":      0.1,
 | 
			
		||||
	"mj_reroll":         0.1,
 | 
			
		||||
	"mj_blend":          0.1,
 | 
			
		||||
	"mj_inpaint":        0.1,
 | 
			
		||||
	"mj_zoom":           0.1,
 | 
			
		||||
	"mj_shorten":        0.1,
 | 
			
		||||
	"mj_high_variation": 0.1,
 | 
			
		||||
	"mj_low_variation":  0.1,
 | 
			
		||||
	"mj_pan":            0.1,
 | 
			
		||||
	"mj_inpaint_pre":    0,
 | 
			
		||||
	"mj_describe":       0.05,
 | 
			
		||||
	"mj_upscale":        0.05,
 | 
			
		||||
	"swap_face":         0.05,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var ModelPrice = map[string]float64{}
 | 
			
		||||
 | 
			
		||||
func ModelPrice2JSONString() string {
 | 
			
		||||
	if len(ModelPrice) == 0 {
 | 
			
		||||
		ModelPrice = DefaultModelPrice
 | 
			
		||||
	}
 | 
			
		||||
	jsonBytes, err := json.Marshal(ModelPrice)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		SysError("error marshalling model price: " + err.Error())
 | 
			
		||||
@@ -118,6 +131,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetModelPrice(name string, printErr bool) float64 {
 | 
			
		||||
	if len(ModelPrice) == 0 {
 | 
			
		||||
		ModelPrice = DefaultModelPrice
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(name, "gpt-4-gizmo") {
 | 
			
		||||
		name = "gpt-4-gizmo-*"
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -11,6 +11,7 @@ const (
 | 
			
		||||
	MjActionBlend         = "BLEND"
 | 
			
		||||
	MjActionUpscale       = "UPSCALE"
 | 
			
		||||
	MjActionVariation     = "VARIATION"
 | 
			
		||||
	MjActionReRoll        = "REROLL"
 | 
			
		||||
	MjActionInPaint       = "INPAINT"
 | 
			
		||||
	MjActionInPaintPre    = "INPAINT_PRE"
 | 
			
		||||
	MjActionZoom          = "ZOOM"
 | 
			
		||||
@@ -20,3 +21,20 @@ const (
 | 
			
		||||
	MjActionPan           = "PAN"
 | 
			
		||||
	SwapFace              = "SWAP_FACE"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var MidjourneyModel2Action = map[string]string{
 | 
			
		||||
	"mj_imagine":        MjActionImagine,
 | 
			
		||||
	"mj_describe":       MjActionDescribe,
 | 
			
		||||
	"mj_blend":          MjActionBlend,
 | 
			
		||||
	"mj_upscale":        MjActionUpscale,
 | 
			
		||||
	"mj_variation":      MjActionVariation,
 | 
			
		||||
	"mj_reroll":         MjActionReRoll,
 | 
			
		||||
	"mj_inpaint":        MjActionInPaint,
 | 
			
		||||
	"mj_inpaint_pre":    MjActionInPaintPre,
 | 
			
		||||
	"mj_zoom":           MjActionZoom,
 | 
			
		||||
	"mj_shorten":        MjActionShorten,
 | 
			
		||||
	"mj_high_variation": MjActionHighVariation,
 | 
			
		||||
	"mj_low_variation":  MjActionLowVariation,
 | 
			
		||||
	"mj_pan":            MjActionPan,
 | 
			
		||||
	"swap_face":         SwapFace,
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -18,137 +18,6 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/*func UpdateMidjourneyTask() {
 | 
			
		||||
	//revocer
 | 
			
		||||
	//imageModel := "midjourney"
 | 
			
		||||
	ctx := context.TODO()
 | 
			
		||||
	imageModel := "midjourney"
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err := recover(); err != nil {
 | 
			
		||||
			log.Printf("UpdateMidjourneyTask panic: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	for {
 | 
			
		||||
		time.Sleep(time.Duration(15) * time.Second)
 | 
			
		||||
		tasks := model.GetAllUnFinishTasks()
 | 
			
		||||
		if len(tasks) != 0 {
 | 
			
		||||
			common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
 | 
			
		||||
			for _, task := range tasks {
 | 
			
		||||
				common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
 | 
			
		||||
				midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
 | 
			
		||||
					task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
 | 
			
		||||
					task.Status = "FAILURE"
 | 
			
		||||
					task.Progress = "100%"
 | 
			
		||||
					err := task.Update()
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
 | 
			
		||||
				common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
 | 
			
		||||
 | 
			
		||||
				req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 设置超时时间
 | 
			
		||||
				timeout := time.Second * 5
 | 
			
		||||
				ctx, cancel := context.WithTimeout(context.Background(), timeout)
 | 
			
		||||
 | 
			
		||||
				// 使用带有超时的 context 创建新的请求
 | 
			
		||||
				req = req.WithContext(ctx)
 | 
			
		||||
 | 
			
		||||
				req.Header.Set("Content-Type", "application/json")
 | 
			
		||||
				//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
 | 
			
		||||
				req.Header.Set("mj-api-secret", midjourneyChannel.Key)
 | 
			
		||||
				resp, err := httpClient.Do(req)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Printf("UpdateMidjourneyTask error: %v", err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
				resp.Body.Close()
 | 
			
		||||
				log.Printf("responseBody: %s", string(responseBody))
 | 
			
		||||
				var responseItem MidjourneyDto
 | 
			
		||||
				// err = json.NewDecoder(resp.Body).Decode(&responseItem)
 | 
			
		||||
				err = json.Unmarshal(responseBody, &responseItem)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field MidjourneyDto.status of type string") {
 | 
			
		||||
						var responseWithoutStatus MidjourneyWithoutStatus
 | 
			
		||||
						var responseStatus MidjourneyStatus
 | 
			
		||||
						err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
 | 
			
		||||
						err2 := json.Unmarshal(responseBody, &responseStatus)
 | 
			
		||||
						if err1 == nil && err2 == nil {
 | 
			
		||||
							jsonData, err3 := json.Marshal(responseWithoutStatus)
 | 
			
		||||
							if err3 != nil {
 | 
			
		||||
								log.Printf("UpdateMidjourneyTask error1: %v", err3)
 | 
			
		||||
								continue
 | 
			
		||||
							}
 | 
			
		||||
							err4 := json.Unmarshal(jsonData, &responseStatus)
 | 
			
		||||
							if err4 != nil {
 | 
			
		||||
								log.Printf("UpdateMidjourneyTask error2: %v", err4)
 | 
			
		||||
								continue
 | 
			
		||||
							}
 | 
			
		||||
							responseItem.Status = strconv.Itoa(responseStatus.Status)
 | 
			
		||||
						} else {
 | 
			
		||||
							log.Printf("UpdateMidjourneyTask error3: %v", err)
 | 
			
		||||
							continue
 | 
			
		||||
						}
 | 
			
		||||
					} else {
 | 
			
		||||
						log.Printf("UpdateMidjourneyTask error4: %v", err)
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				task.Code = 1
 | 
			
		||||
				task.Progress = responseItem.Progress
 | 
			
		||||
				task.PromptEn = responseItem.PromptEn
 | 
			
		||||
				task.State = responseItem.State
 | 
			
		||||
				task.SubmitTime = responseItem.SubmitTime
 | 
			
		||||
				task.StartTime = responseItem.StartTime
 | 
			
		||||
				task.FinishTime = responseItem.FinishTime
 | 
			
		||||
				task.ImageUrl = responseItem.ImageUrl
 | 
			
		||||
				task.Status = responseItem.Status
 | 
			
		||||
				task.FailReason = responseItem.FailReason
 | 
			
		||||
				if task.Progress != "100%" && responseItem.FailReason != "" {
 | 
			
		||||
					common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
 | 
			
		||||
					task.Progress = "100%"
 | 
			
		||||
					err = model.CacheUpdateUserQuota(task.UserId)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						log.Println("error update user quota cache: " + err.Error())
 | 
			
		||||
					} else {
 | 
			
		||||
						modelRatio := common.GetModelRatio(imageModel)
 | 
			
		||||
						groupRatio := common.GetGroupRatio("default")
 | 
			
		||||
						ratio := modelRatio * groupRatio
 | 
			
		||||
						quota := int(ratio * 1 * 1000)
 | 
			
		||||
						if quota != 0 {
 | 
			
		||||
							err := model.IncreaseUserQuota(task.UserId, quota)
 | 
			
		||||
							if err != nil {
 | 
			
		||||
								log.Println("fail to increase user quota")
 | 
			
		||||
							}
 | 
			
		||||
							logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
 | 
			
		||||
							model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				err = task.Update()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Printf("UpdateMidjourneyTask error5: %v", err)
 | 
			
		||||
				}
 | 
			
		||||
				log.Printf("UpdateMidjourneyTask success: %v", task)
 | 
			
		||||
				cancel()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
func UpdateMidjourneyTaskBulk() {
 | 
			
		||||
	//imageModel := "midjourney"
 | 
			
		||||
	ctx := context.TODO()
 | 
			
		||||
 
 | 
			
		||||
@@ -4,12 +4,13 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"one-api/relay"
 | 
			
		||||
	"one-api/relay/channel/ai360"
 | 
			
		||||
	"one-api/relay/channel/moonshot"
 | 
			
		||||
	"one-api/relay/constant"
 | 
			
		||||
	relayconstant "one-api/relay/constant"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://platform.openai.com/docs/api-reference/models/list
 | 
			
		||||
@@ -59,8 +60,8 @@ func init() {
 | 
			
		||||
		IsBlocking:         false,
 | 
			
		||||
	})
 | 
			
		||||
	// https://platform.openai.com/docs/models/model-endpoint-compatibility
 | 
			
		||||
	for i := 0; i < constant.APITypeDummy; i++ {
 | 
			
		||||
		if i == constant.APITypeAIProxyLibrary {
 | 
			
		||||
	for i := 0; i < relayconstant.APITypeDummy; i++ {
 | 
			
		||||
		if i == relayconstant.APITypeAIProxyLibrary {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		adaptor := relay.GetAdaptor(i)
 | 
			
		||||
@@ -100,6 +101,17 @@ func init() {
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	for modelName, _ := range constant.MidjourneyModel2Action {
 | 
			
		||||
		openAIModels = append(openAIModels, OpenAIModels{
 | 
			
		||||
			Id:         modelName,
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1626777600,
 | 
			
		||||
			OwnedBy:    "midjourney",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       modelName,
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	openAIModelsMap = make(map[string]OpenAIModels)
 | 
			
		||||
	for _, model := range openAIModels {
 | 
			
		||||
		openAIModelsMap[model.Id] = model
 | 
			
		||||
 
 | 
			
		||||
@@ -60,7 +60,7 @@ func Relay(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayMidjourney(c *gin.Context) {
 | 
			
		||||
	relayMode := constant.Path2RelayModeMidjourney(c.Request.URL.Path)
 | 
			
		||||
	relayMode := c.GetInt("relay_mode")
 | 
			
		||||
	var err *dto.MidjourneyResponse
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case relayconstant.RelayModeMidjourneyNotify:
 | 
			
		||||
@@ -73,13 +73,15 @@ func RelayMidjourney(c *gin.Context) {
 | 
			
		||||
	//err = relayMidjourneySubmit(c, relayMode)
 | 
			
		||||
	log.Println(err)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		statusCode := http.StatusBadRequest
 | 
			
		||||
		if err.Code == 30 {
 | 
			
		||||
			err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
 | 
			
		||||
			statusCode = http.StatusTooManyRequests
 | 
			
		||||
		}
 | 
			
		||||
		c.JSON(429, gin.H{
 | 
			
		||||
			"error": fmt.Sprintf("%s %s", err.Description, err.Result),
 | 
			
		||||
			"type":  "upstream_error",
 | 
			
		||||
			"code":  err.Code,
 | 
			
		||||
		c.JSON(statusCode, gin.H{
 | 
			
		||||
			"description": fmt.Sprintf("%s %s", err.Description, err.Result),
 | 
			
		||||
			"type":        "upstream_error",
 | 
			
		||||
			"code":        err.Code,
 | 
			
		||||
		})
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,12 @@
 | 
			
		||||
package dto
 | 
			
		||||
 | 
			
		||||
//type SimpleMjRequest struct {
 | 
			
		||||
//	Prompt   string `json:"prompt"`
 | 
			
		||||
//	CustomId string `json:"customId"`
 | 
			
		||||
//	Action   string `json:"action"`
 | 
			
		||||
//	Content  string `json:"content"`
 | 
			
		||||
//}
 | 
			
		||||
 | 
			
		||||
type MidjourneyRequest struct {
 | 
			
		||||
	Prompt      string   `json:"prompt"`
 | 
			
		||||
	CustomId    string   `json:"customId"`
 | 
			
		||||
 
 | 
			
		||||
@@ -100,16 +100,16 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
		}
 | 
			
		||||
		token, err := model.ValidateUserToken(key)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			abortWithMessage(c, http.StatusUnauthorized, err.Error())
 | 
			
		||||
			abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		userEnabled, err := model.CacheIsUserEnabled(token.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			abortWithMessage(c, http.StatusInternalServerError, err.Error())
 | 
			
		||||
			abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if !userEnabled {
 | 
			
		||||
			abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
 | 
			
		||||
			abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("id", token.UserId)
 | 
			
		||||
@@ -129,7 +129,7 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
			if model.IsAdmin(token.UserId) {
 | 
			
		||||
				c.Set("channelId", parts[1])
 | 
			
		||||
			} else {
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 | 
			
		||||
				abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,11 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	relayconstant "one-api/relay/constant"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
@@ -23,32 +27,58 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
		if ok {
 | 
			
		||||
			id, err := strconv.Atoi(channelId.(string))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
 | 
			
		||||
				abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			channel, err = model.GetChannelById(id, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
 | 
			
		||||
				abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
 | 
			
		||||
				abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			shouldSelectChannel := true
 | 
			
		||||
			// Select a channel for the user
 | 
			
		||||
			var modelRequest ModelRequest
 | 
			
		||||
			var err error
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/mj") {
 | 
			
		||||
				// Midjourney
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					modelRequest.Model = "midjourney"
 | 
			
		||||
				relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
 | 
			
		||||
				if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
 | 
			
		||||
					relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
 | 
			
		||||
					relayMode == relayconstant.RelayModeMidjourneyNotify {
 | 
			
		||||
					shouldSelectChannel = false
 | 
			
		||||
				} else {
 | 
			
		||||
					midjourneyRequest := dto.MidjourneyRequest{}
 | 
			
		||||
					err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
					midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
 | 
			
		||||
					if mjErr != nil {
 | 
			
		||||
						abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
					if midjourneyModel == "" {
 | 
			
		||||
						if !success {
 | 
			
		||||
							abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
 | 
			
		||||
							return
 | 
			
		||||
						} else {
 | 
			
		||||
							// task fetch, task fetch by condition, notify
 | 
			
		||||
							shouldSelectChannel = false
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					modelRequest.Model = midjourneyModel
 | 
			
		||||
				}
 | 
			
		||||
				c.Set("relay_mode", relayMode)
 | 
			
		||||
			} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 | 
			
		||||
				err = common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
 | 
			
		||||
				abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 | 
			
		||||
@@ -87,60 +117,61 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
				}
 | 
			
		||||
				if tokenModelLimit != nil {
 | 
			
		||||
					if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
 | 
			
		||||
						abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
 | 
			
		||||
						abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
				} else {
 | 
			
		||||
					// token model limit is empty, all models are not allowed
 | 
			
		||||
					abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
 | 
			
		||||
					abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			userGroup, _ := model.CacheGetUserGroup(userId)
 | 
			
		||||
			c.Set("group", userGroup)
 | 
			
		||||
 | 
			
		||||
			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
 | 
			
		||||
				// 如果错误,但是渠道不为空,说明是数据库一致性问题
 | 
			
		||||
				if channel != nil {
 | 
			
		||||
					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 | 
			
		||||
					message = "数据库一致性已被破坏,请联系管理员"
 | 
			
		||||
			if shouldSelectChannel {
 | 
			
		||||
				channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
 | 
			
		||||
					// 如果错误,但是渠道不为空,说明是数据库一致性问题
 | 
			
		||||
					if channel != nil {
 | 
			
		||||
						common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 | 
			
		||||
						message = "数据库一致性已被破坏,请联系管理员"
 | 
			
		||||
					}
 | 
			
		||||
					// 如果错误,而且渠道为空,说明是没有可用渠道
 | 
			
		||||
					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				if channel == nil {
 | 
			
		||||
					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				c.Set("channel", channel.Type)
 | 
			
		||||
				c.Set("channel_id", channel.Id)
 | 
			
		||||
				c.Set("channel_name", channel.Name)
 | 
			
		||||
				ban := true
 | 
			
		||||
				// parse *int to bool
 | 
			
		||||
				if channel.AutoBan != nil && *channel.AutoBan == 0 {
 | 
			
		||||
					ban = false
 | 
			
		||||
				}
 | 
			
		||||
				c.Set("auto_ban", ban)
 | 
			
		||||
				c.Set("model_mapping", channel.GetModelMapping())
 | 
			
		||||
				c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
			
		||||
				c.Set("base_url", channel.GetBaseURL())
 | 
			
		||||
				// TODO: api_version统一
 | 
			
		||||
				switch channel.Type {
 | 
			
		||||
				case common.ChannelTypeAzure:
 | 
			
		||||
					c.Set("api_version", channel.Other)
 | 
			
		||||
				case common.ChannelTypeXunfei:
 | 
			
		||||
					c.Set("api_version", channel.Other)
 | 
			
		||||
				//case common.ChannelTypeAIProxyLibrary:
 | 
			
		||||
				//	c.Set("library_id", channel.Other)
 | 
			
		||||
				case common.ChannelTypeGemini:
 | 
			
		||||
					c.Set("api_version", channel.Other)
 | 
			
		||||
				case common.ChannelTypeAli:
 | 
			
		||||
					c.Set("plugin", channel.Other)
 | 
			
		||||
				}
 | 
			
		||||
				// 如果错误,而且渠道为空,说明是没有可用渠道
 | 
			
		||||
				abortWithMessage(c, http.StatusServiceUnavailable, message)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if channel == nil {
 | 
			
		||||
				abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("channel", channel.Type)
 | 
			
		||||
		c.Set("channel_id", channel.Id)
 | 
			
		||||
		c.Set("channel_name", channel.Name)
 | 
			
		||||
		ban := true
 | 
			
		||||
		// parse *int to bool
 | 
			
		||||
		if channel.AutoBan != nil && *channel.AutoBan == 0 {
 | 
			
		||||
			ban = false
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("auto_ban", ban)
 | 
			
		||||
		c.Set("model_mapping", channel.GetModelMapping())
 | 
			
		||||
		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
			
		||||
		c.Set("base_url", channel.GetBaseURL())
 | 
			
		||||
		// TODO: api_version统一
 | 
			
		||||
		switch channel.Type {
 | 
			
		||||
		case common.ChannelTypeAzure:
 | 
			
		||||
			c.Set("api_version", channel.Other)
 | 
			
		||||
		case common.ChannelTypeXunfei:
 | 
			
		||||
			c.Set("api_version", channel.Other)
 | 
			
		||||
		//case common.ChannelTypeAIProxyLibrary:
 | 
			
		||||
		//	c.Set("library_id", channel.Other)
 | 
			
		||||
		case common.ChannelTypeGemini:
 | 
			
		||||
			c.Set("api_version", channel.Other)
 | 
			
		||||
		case common.ChannelTypeAli:
 | 
			
		||||
			c.Set("plugin", channel.Other)
 | 
			
		||||
		}
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,7 @@ import (
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
 | 
			
		||||
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
 | 
			
		||||
	c.JSON(statusCode, gin.H{
 | 
			
		||||
		"error": gin.H{
 | 
			
		||||
			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
 | 
			
		||||
@@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
 | 
			
		||||
	c.Abort()
 | 
			
		||||
	common.LogError(c.Request.Context(), message)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
 | 
			
		||||
	c.JSON(statusCode, gin.H{
 | 
			
		||||
		"description": description,
 | 
			
		||||
		"type":        "new_api_error",
 | 
			
		||||
		"code":        code,
 | 
			
		||||
	})
 | 
			
		||||
	c.Abort()
 | 
			
		||||
	common.LogError(c.Request.Context(), description)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -21,23 +21,6 @@ import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var DefaultModelPrice = map[string]float64{
 | 
			
		||||
	"mj_imagine":        0.1,
 | 
			
		||||
	"mj_variation":      0.1,
 | 
			
		||||
	"mj_reroll":         0.1,
 | 
			
		||||
	"mj_blend":          0.1,
 | 
			
		||||
	"mj_inpaint":        0.1,
 | 
			
		||||
	"mj_zoom":           0.1,
 | 
			
		||||
	"mj_shorten":        0.1,
 | 
			
		||||
	"mj_high_variation": 0.1,
 | 
			
		||||
	"mj_low_variation":  0.1,
 | 
			
		||||
	"mj_pan":            0.1,
 | 
			
		||||
	"mj_inpaint_pre":    0,
 | 
			
		||||
	"mj_describe":       0.05,
 | 
			
		||||
	"mj_upscale":        0.05,
 | 
			
		||||
	"swap_face":         0.05,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayMidjourneyImage(c *gin.Context) {
 | 
			
		||||
	taskId := c.Param("id")
 | 
			
		||||
	midjourneyTask := model.GetByOnlyMJId(taskId)
 | 
			
		||||
@@ -221,10 +204,9 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
 | 
			
		||||
	imageModel := "midjourney"
 | 
			
		||||
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	//channelType := c.GetInt("channel")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
@@ -236,7 +218,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
 | 
			
		||||
		mjErr := coverPlusActionToNormalAction(&midjRequest)
 | 
			
		||||
		mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
 | 
			
		||||
		if mjErr != nil {
 | 
			
		||||
			return mjErr
 | 
			
		||||
		}
 | 
			
		||||
@@ -270,11 +252,11 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 | 
			
		||||
			if midjRequest.Content == "" {
 | 
			
		||||
				return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
 | 
			
		||||
			}
 | 
			
		||||
			params := convertSimpleChangeParams(midjRequest.Content)
 | 
			
		||||
			params := service.ConvertSimpleChangeParams(midjRequest.Content)
 | 
			
		||||
			if params == nil {
 | 
			
		||||
				return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
 | 
			
		||||
			}
 | 
			
		||||
			mjId = params.ID
 | 
			
		||||
			mjId = params.TaskId
 | 
			
		||||
			midjRequest.Action = params.Action
 | 
			
		||||
		} else if relayMode == relayconstant.RelayModeMidjourneyModal {
 | 
			
		||||
			if midjRequest.MaskBase64 == "" {
 | 
			
		||||
@@ -294,18 +276,21 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
 | 
			
		||||
			}
 | 
			
		||||
			if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
				return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
 | 
			
		||||
			}
 | 
			
		||||
			c.Set("base_url", channel.GetBaseURL())
 | 
			
		||||
			c.Set("channel_id", originTask.ChannelId)
 | 
			
		||||
			log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
 | 
			
		||||
		}
 | 
			
		||||
		midjRequest.Prompt = originTask.Prompt
 | 
			
		||||
 | 
			
		||||
		if channelType == common.ChannelTypeMidjourneyPlus {
 | 
			
		||||
			// plus
 | 
			
		||||
		} else {
 | 
			
		||||
			// 普通版渠道
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
		//if channelType == common.ChannelTypeMidjourneyPlus {
 | 
			
		||||
		//	// plus
 | 
			
		||||
		//} else {
 | 
			
		||||
		//	// 普通版渠道
 | 
			
		||||
		//
 | 
			
		||||
		//}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if midjRequest.Action == constant.MjActionInPaintPre {
 | 
			
		||||
@@ -313,54 +298,52 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// map model name
 | 
			
		||||
	modelMapping := c.GetString("model_mapping")
 | 
			
		||||
	isModelMapped := false
 | 
			
		||||
	if modelMapping != "" {
 | 
			
		||||
		modelMap := make(map[string]string)
 | 
			
		||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "unmarshal_model_mapping_failed",
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if modelMap[imageModel] != "" {
 | 
			
		||||
			imageModel = modelMap[imageModel]
 | 
			
		||||
			isModelMapped = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	//modelMapping := c.GetString("model_mapping")
 | 
			
		||||
	//isModelMapped := false
 | 
			
		||||
	//if modelMapping != "" {
 | 
			
		||||
	//	modelMap := make(map[string]string)
 | 
			
		||||
	//	err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
			
		||||
	//	if err != nil {
 | 
			
		||||
	//		//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
			
		||||
	//		return &dto.MidjourneyResponse{
 | 
			
		||||
	//			Code:        4,
 | 
			
		||||
	//			Description: "unmarshal_model_mapping_failed",
 | 
			
		||||
	//		}
 | 
			
		||||
	//	}
 | 
			
		||||
	//	if modelMap[imageModel] != "" {
 | 
			
		||||
	//		imageModel = modelMap[imageModel]
 | 
			
		||||
	//		isModelMapped = true
 | 
			
		||||
	//	}
 | 
			
		||||
	//}
 | 
			
		||||
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	//baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
 | 
			
		||||
	if c.GetString("base_url") != "" {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	}
 | 
			
		||||
	baseURL := c.GetString("base_url")
 | 
			
		||||
 | 
			
		||||
	//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
 | 
			
		||||
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
			
		||||
 | 
			
		||||
	var requestBody io.Reader
 | 
			
		||||
	if isModelMapped {
 | 
			
		||||
		jsonStr, err := json.Marshal(midjRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return &dto.MidjourneyResponse{
 | 
			
		||||
				Code:        4,
 | 
			
		||||
				Description: "marshal_text_request_failed",
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	} else {
 | 
			
		||||
		requestBody = c.Request.Body
 | 
			
		||||
	}
 | 
			
		||||
	//if isModelMapped {
 | 
			
		||||
	//	jsonStr, err := json.Marshal(midjRequest)
 | 
			
		||||
	//	if err != nil {
 | 
			
		||||
	//		return &dto.MidjourneyResponse{
 | 
			
		||||
	//			Code:        4,
 | 
			
		||||
	//			Description: "marshal_text_request_failed",
 | 
			
		||||
	//		}
 | 
			
		||||
	//	}
 | 
			
		||||
	//	requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	//} else {
 | 
			
		||||
	//}
 | 
			
		||||
	requestBody = c.Request.Body
 | 
			
		||||
 | 
			
		||||
	mjAction := "mj_" + strings.ToLower(midjRequest.Action)
 | 
			
		||||
	modelPrice := common.GetModelPrice(mjAction, true)
 | 
			
		||||
	modelName := service.CoverActionToModelName(midjRequest.Action)
 | 
			
		||||
	modelPrice := common.GetModelPrice(modelName, true)
 | 
			
		||||
	// 如果没有配置价格,则使用默认价格
 | 
			
		||||
	if modelPrice == -1 {
 | 
			
		||||
		defaultPrice, ok := DefaultModelPrice[mjAction]
 | 
			
		||||
		defaultPrice, ok := common.DefaultModelPrice[modelName]
 | 
			
		||||
		if !ok {
 | 
			
		||||
			modelPrice = 0.1
 | 
			
		||||
		} else {
 | 
			
		||||
@@ -433,7 +416,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				tokenName := c.GetString("token_name")
 | 
			
		||||
				logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
 | 
			
		||||
				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
 | 
			
		||||
				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
				channelId := c.GetInt("channel_id")
 | 
			
		||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
@@ -558,85 +541,3 @@ type taskChangeParams struct {
 | 
			
		||||
	Action string
 | 
			
		||||
	Index  int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func convertSimpleChangeParams(content string) *taskChangeParams {
 | 
			
		||||
	split := strings.Split(content, " ")
 | 
			
		||||
	if len(split) != 2 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	action := strings.ToLower(split[1])
 | 
			
		||||
	changeParams := &taskChangeParams{}
 | 
			
		||||
	changeParams.ID = split[0]
 | 
			
		||||
 | 
			
		||||
	if action[0] == 'u' {
 | 
			
		||||
		changeParams.Action = "UPSCALE"
 | 
			
		||||
	} else if action[0] == 'v' {
 | 
			
		||||
		changeParams.Action = "VARIATION"
 | 
			
		||||
	} else if action == "r" {
 | 
			
		||||
		changeParams.Action = "REROLL"
 | 
			
		||||
		return changeParams
 | 
			
		||||
	} else {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	index, err := strconv.Atoi(action[1:2])
 | 
			
		||||
	if err != nil || index < 1 || index > 4 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	changeParams.Index = index
 | 
			
		||||
	return changeParams
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func coverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
 | 
			
		||||
	// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
 | 
			
		||||
	customId := midjRequest.CustomId
 | 
			
		||||
	if customId == "" {
 | 
			
		||||
		return service.MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
 | 
			
		||||
	}
 | 
			
		||||
	splits := strings.Split(customId, "::")
 | 
			
		||||
	var action string
 | 
			
		||||
	if splits[1] == "JOB" {
 | 
			
		||||
		action = splits[2]
 | 
			
		||||
	} else {
 | 
			
		||||
		action = splits[1]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if action == "" {
 | 
			
		||||
		return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(action, "upsample") {
 | 
			
		||||
		index, err := strconv.Atoi(splits[3])
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
 | 
			
		||||
		}
 | 
			
		||||
		midjRequest.Index = index
 | 
			
		||||
		midjRequest.Action = constant.MjActionUpscale
 | 
			
		||||
	} else if strings.Contains(action, "variation") {
 | 
			
		||||
		midjRequest.Index = 1
 | 
			
		||||
		if action == "variation" {
 | 
			
		||||
			index, err := strconv.Atoi(splits[3])
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
 | 
			
		||||
			}
 | 
			
		||||
			midjRequest.Index = index
 | 
			
		||||
			midjRequest.Action = constant.MjActionVariation
 | 
			
		||||
		} else if action == "low_variation" {
 | 
			
		||||
			midjRequest.Action = constant.MjActionLowVariation
 | 
			
		||||
		} else if action == "high_variation" {
 | 
			
		||||
			midjRequest.Action = constant.MjActionHighVariation
 | 
			
		||||
		}
 | 
			
		||||
	} else if strings.Contains(action, "pan") {
 | 
			
		||||
		midjRequest.Action = constant.MjActionPan
 | 
			
		||||
		midjRequest.Index = 1
 | 
			
		||||
	} else if action == "Outpaint" || action == "CustomZoom" {
 | 
			
		||||
		midjRequest.Action = constant.MjActionZoom
 | 
			
		||||
		midjRequest.Index = 1
 | 
			
		||||
	} else if action == "Inpaint" {
 | 
			
		||||
		midjRequest.Action = constant.MjActionInPaintPre
 | 
			
		||||
		midjRequest.Index = 1
 | 
			
		||||
	} else {
 | 
			
		||||
		return service.MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										135
									
								
								service/midjourney.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										135
									
								
								service/midjourney.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,135 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
	"one-api/dto"
 | 
			
		||||
	relayconstant "one-api/relay/constant"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func CoverActionToModelName(mjAction string) string {
 | 
			
		||||
	modelName := "mj_" + strings.ToLower(mjAction)
 | 
			
		||||
	return modelName
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
 | 
			
		||||
	action := ""
 | 
			
		||||
	if relayMode == relayconstant.RelayModeMidjourneyAction {
 | 
			
		||||
		// plus request
 | 
			
		||||
		err := CoverPlusActionToNormalAction(midjRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", err, false
 | 
			
		||||
		}
 | 
			
		||||
		action = midjRequest.Action
 | 
			
		||||
	} else {
 | 
			
		||||
		switch relayMode {
 | 
			
		||||
		case relayconstant.RelayModeMidjourneyImagine:
 | 
			
		||||
			action = constant.MjActionImagine
 | 
			
		||||
		case relayconstant.RelayModeMidjourneyDescribe:
 | 
			
		||||
			action = constant.MjActionDescribe
 | 
			
		||||
		case relayconstant.RelayModeMidjourneyBlend:
 | 
			
		||||
			action = constant.MjActionBlend
 | 
			
		||||
		case relayconstant.RelayModeMidjourneyShorten:
 | 
			
		||||
			action = constant.MjActionShorten
 | 
			
		||||
		case relayconstant.RelayModeMidjourneyChange:
 | 
			
		||||
			action = midjRequest.Action
 | 
			
		||||
		case relayconstant.RelayModeMidjourneyModal:
 | 
			
		||||
			action = constant.MjActionInPaint
 | 
			
		||||
		case relayconstant.RelayModeMidjourneySimpleChange:
 | 
			
		||||
			params := ConvertSimpleChangeParams(midjRequest.Content)
 | 
			
		||||
			if params == nil {
 | 
			
		||||
				return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
 | 
			
		||||
			}
 | 
			
		||||
			action = params.Action
 | 
			
		||||
		case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
 | 
			
		||||
			return "", nil, true
 | 
			
		||||
		default:
 | 
			
		||||
			return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action"), false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	modelName := CoverActionToModelName(action)
 | 
			
		||||
	return modelName, nil, true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
 | 
			
		||||
	// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
 | 
			
		||||
	customId := midjRequest.CustomId
 | 
			
		||||
	if customId == "" {
 | 
			
		||||
		return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
 | 
			
		||||
	}
 | 
			
		||||
	splits := strings.Split(customId, "::")
 | 
			
		||||
	var action string
 | 
			
		||||
	if splits[1] == "JOB" {
 | 
			
		||||
		action = splits[2]
 | 
			
		||||
	} else {
 | 
			
		||||
		action = splits[1]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if action == "" {
 | 
			
		||||
		return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(action, "upsample") {
 | 
			
		||||
		index, err := strconv.Atoi(splits[3])
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
 | 
			
		||||
		}
 | 
			
		||||
		midjRequest.Index = index
 | 
			
		||||
		midjRequest.Action = constant.MjActionUpscale
 | 
			
		||||
	} else if strings.Contains(action, "variation") {
 | 
			
		||||
		midjRequest.Index = 1
 | 
			
		||||
		if action == "variation" {
 | 
			
		||||
			index, err := strconv.Atoi(splits[3])
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
 | 
			
		||||
			}
 | 
			
		||||
			midjRequest.Index = index
 | 
			
		||||
			midjRequest.Action = constant.MjActionVariation
 | 
			
		||||
		} else if action == "low_variation" {
 | 
			
		||||
			midjRequest.Action = constant.MjActionLowVariation
 | 
			
		||||
		} else if action == "high_variation" {
 | 
			
		||||
			midjRequest.Action = constant.MjActionHighVariation
 | 
			
		||||
		}
 | 
			
		||||
	} else if strings.Contains(action, "pan") {
 | 
			
		||||
		midjRequest.Action = constant.MjActionPan
 | 
			
		||||
		midjRequest.Index = 1
 | 
			
		||||
	} else if action == "Outpaint" || action == "CustomZoom" {
 | 
			
		||||
		midjRequest.Action = constant.MjActionZoom
 | 
			
		||||
		midjRequest.Index = 1
 | 
			
		||||
	} else if action == "Inpaint" {
 | 
			
		||||
		midjRequest.Action = constant.MjActionInPaintPre
 | 
			
		||||
		midjRequest.Index = 1
 | 
			
		||||
	} else {
 | 
			
		||||
		return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
 | 
			
		||||
	split := strings.Split(content, " ")
 | 
			
		||||
	if len(split) != 2 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	action := strings.ToLower(split[1])
 | 
			
		||||
	changeParams := &dto.MidjourneyRequest{}
 | 
			
		||||
	changeParams.TaskId = split[0]
 | 
			
		||||
 | 
			
		||||
	if action[0] == 'u' {
 | 
			
		||||
		changeParams.Action = "UPSCALE"
 | 
			
		||||
	} else if action[0] == 'v' {
 | 
			
		||||
		changeParams.Action = "VARIATION"
 | 
			
		||||
	} else if action == "r" {
 | 
			
		||||
		changeParams.Action = "REROLL"
 | 
			
		||||
		return changeParams
 | 
			
		||||
	} else {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	index, err := strconv.Atoi(action[1:2])
 | 
			
		||||
	if err != nil || index < 1 || index > 4 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	changeParams.Index = index
 | 
			
		||||
	return changeParams
 | 
			
		||||
}
 | 
			
		||||
@@ -96,10 +96,25 @@ const EditChannel = (props) => {
 | 
			
		||||
                    localModels = ['glm-4', 'glm-4v', 'glm-3-turbo'];
 | 
			
		||||
                    break;
 | 
			
		||||
                case 2:
 | 
			
		||||
                    localModels = ['midjourney'];
 | 
			
		||||
                    localModels = ['mj_imagine', 'mj_variation', 'mj_reroll', 'mj_blend', 'mj_upscale', 'mj_describe'];
 | 
			
		||||
                    break;
 | 
			
		||||
                case 5:
 | 
			
		||||
                    localModels = ['midjourney'];
 | 
			
		||||
                    localModels = [
 | 
			
		||||
                        'swap_face',
 | 
			
		||||
                        'mj_imagine',
 | 
			
		||||
                        'mj_variation',
 | 
			
		||||
                        'mj_reroll',
 | 
			
		||||
                        'mj_blend',
 | 
			
		||||
                        'mj_upscale',
 | 
			
		||||
                        'mj_describe',
 | 
			
		||||
                        'mj_zoom',
 | 
			
		||||
                        'mj_shorten',
 | 
			
		||||
                        'mj_inpaint_pre',
 | 
			
		||||
                        'mj_inpaint_pre',
 | 
			
		||||
                        'mj_high_variation',
 | 
			
		||||
                        'mj_low_variation',
 | 
			
		||||
                        'mj_pan',
 | 
			
		||||
                    ];
 | 
			
		||||
                    break;
 | 
			
		||||
            }
 | 
			
		||||
            setInputs((inputs) => ({...inputs, models: localModels}));
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user