feat: add Midjourney (#138)

* 🚧 stash

*  feat: add Midjourney

* 📝 doc: update readme
This commit is contained in:
Buer
2024-04-05 04:03:46 +08:00
committed by GitHub
parent 87bfecf3e9
commit c1fc32add7
42 changed files with 2479 additions and 84 deletions

View File

@@ -27,7 +27,7 @@ type RelayBaseInterface interface {
}
func (r *relayBase) setProvider(modelName string) error {
provider, modelName, fail := getProvider(r.c, modelName)
provider, modelName, fail := GetProvider(r.c, modelName)
if fail != nil {
return fail
}

View File

@@ -45,7 +45,7 @@ func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
return nil
}
func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
func GetProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
channel, fail := fetchChannel(c, modeName)
if fail != nil {
return

19
relay/midjourney/LICENSE Normal file
View File

@@ -0,0 +1,19 @@
Copyright (c) 2024 Calcium-Ion
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,578 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: relay/relay-mj.go
package midjourney
import (
"bytes"
"encoding/json"
"io"
"log"
"net/http"
"one-api/common"
"one-api/controller"
"one-api/model"
providersBase "one-api/providers/base"
provider "one-api/providers/midjourney"
"one-api/relay"
"one-api/relay/util"
"one-api/types"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func RelayMidjourneyImage(c *gin.Context) {
taskId := c.Param("id")
midjourneyTask := model.GetByOnlyMJId(taskId)
if midjourneyTask == nil {
c.JSON(400, gin.H{
"error": "midjourney_task_not_found",
})
return
}
resp, err := http.Get(midjourneyTask.ImageUrl)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "http_get_image_failed",
})
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
c.JSON(resp.StatusCode, gin.H{
"error": string(responseBody),
})
return
}
// 从Content-Type头获取MIME类型
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
// 如果无法确定内容类型则默认为jpeg
contentType = "image/jpeg"
}
// 设置响应的内容类型
c.Writer.Header().Set("Content-Type", contentType)
// 将图片流式传输到响应体
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
log.Println("Failed to stream image:", err)
}
}
func RelayMidjourneyNotify(c *gin.Context) *provider.MidjourneyResponse {
var midjRequest provider.MidjourneyDto
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "bind_request_body_failed",
Properties: nil,
Result: "",
}
}
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
if midjourneyTask == nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "midjourney_task_not_found",
Properties: nil,
Result: "",
}
}
midjourneyTask.Progress = midjRequest.Progress
midjourneyTask.PromptEn = midjRequest.PromptEn
midjourneyTask.State = midjRequest.State
midjourneyTask.SubmitTime = midjRequest.SubmitTime
midjourneyTask.StartTime = midjRequest.StartTime
midjourneyTask.FinishTime = midjRequest.FinishTime
midjourneyTask.ImageUrl = midjRequest.ImageUrl
midjourneyTask.Status = midjRequest.Status
midjourneyTask.FailReason = midjRequest.FailReason
err = midjourneyTask.Update()
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "update_midjourney_task_failed",
}
}
return nil
}
func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provider.MidjourneyDto) {
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
midjourneyTask.State = originTask.State
midjourneyTask.SubmitTime = originTask.SubmitTime
midjourneyTask.StartTime = originTask.StartTime
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" {
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
}
}
midjourneyTask.Status = originTask.Status
midjourneyTask.FailReason = originTask.FailReason
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
if originTask.Buttons != "" {
var buttons []provider.ActionButton
err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
if err == nil {
midjourneyTask.Buttons = buttons
}
}
if originTask.Properties != "" {
var properties provider.Properties
err := json.Unmarshal([]byte(originTask.Properties), &properties)
if err == nil {
midjourneyTask.Properties = &properties
}
}
return
}
func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneySwapFace, 0, nil)
if errWithMJ != nil {
return errWithMJ
}
startTime := time.Now().UnixNano() / int64(time.Millisecond)
userId := c.GetInt("id")
var swapFaceRequest provider.SwapFaceRequest
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
}
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "sour_base64_and_target_base64_is_required")
}
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
if errWithOA != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: errWithOA.Message,
}
}
requestURL := getMjRequestPath(c.Request.URL.String())
mjResp, _, err := mjProvider.Send(60, requestURL)
if err != nil {
quotaInstance.Undo(c)
return &mjResp.Response
}
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1000, TotalTokens: 1000})
} else {
quotaInstance.Undo(c)
}
quota := int(quotaInstance.GetInputRatio() * 1000)
midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
Action: provider.MjActionSwapFace,
MjId: midjResponse.Result,
Prompt: "InsightFace",
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: startTime,
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
FinishTime: 0,
ImageUrl: "",
Status: "",
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
Quota: quota,
}
err = midjourneyTask.Insert()
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "insert_midjourney_task_failed")
}
// 开始激活任务
controller.ActivateUpdateMidjourneyTaskBulk()
c.Writer.WriteHeader(mjResp.StatusCode)
respBody, err := json.Marshal(midjResponse)
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed")
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed")
}
return nil
}
func RelayMidjourneyTaskImageSeed(c *gin.Context) *provider.MidjourneyResponse {
taskId := c.Param("id")
userId := c.GetInt("id")
originTask := model.GetByMJId(userId, taskId)
if originTask == nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_no_found")
}
mjProvider, errWithMJ := getMJProvider(c, provider.RelayModeMidjourneyTaskImageSeed, originTask.ChannelId, nil)
if errWithMJ != nil {
return errWithMJ
}
requestURL := getMjRequestPath(c.Request.URL.String())
midjResponseWithStatus, _, err := mjProvider.Send(30, requestURL)
if err != nil {
return &midjResponseWithStatus.Response
}
midjResponse := &midjResponseWithStatus.Response
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
respBody, err := json.Marshal(midjResponse)
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed")
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed")
}
return nil
}
func RelayMidjourneyTask(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
userId := c.GetInt("id")
var err error
var respBody []byte
switch relayMode {
case provider.RelayModeMidjourneyTaskFetch:
taskId := c.Param("id")
originTask := model.GetByMJId(userId, taskId)
if originTask == nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
}
midjourneyTask := coverMidjourneyTaskDto(originTask)
respBody, err = json.Marshal(midjourneyTask)
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
case provider.RelayModeMidjourneyTaskFetchByCondition:
var condition = struct {
IDs []string `json:"ids"`
}{}
err = c.BindJSON(&condition)
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "do_request_failed",
}
}
var tasks []provider.MidjourneyDto
if len(condition.IDs) != 0 {
originTasks := model.GetByMJIds(userId, condition.IDs)
for _, originTask := range originTasks {
midjourneyTask := coverMidjourneyTaskDto(originTask)
tasks = append(tasks, midjourneyTask)
}
}
if tasks == nil {
tasks = make([]provider.MidjourneyDto, 0)
}
respBody, err = json.Marshal(tasks)
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
}
c.Writer.Header().Set("Content-Type", "application/json")
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
return nil
}
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
channelId := 0
userId := c.GetInt("id")
consumeQuota := true
var midjRequest provider.MidjourneyRequest
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
}
if relayMode == provider.RelayModeMidjourneyAction { // midjourney plus需要从customId中获取任务信息
mjErr := CoverPlusActionToNormalAction(&midjRequest)
if mjErr != nil {
return mjErr
}
relayMode = provider.RelayModeMidjourneyChange
}
if relayMode == provider.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "prompt_is_required")
}
midjRequest.Action = provider.MjActionImagine
} else if relayMode == provider.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
midjRequest.Action = provider.MjActionDescribe
} else if relayMode == provider.RelayModeMidjourneyShorten { //缩短任务此类任务可重复plus only
midjRequest.Action = provider.MjActionShorten
} else if relayMode == provider.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = provider.MjActionBlend
} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
mjId := ""
if relayMode == provider.RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_id_is_required")
} else if midjRequest.Action == "" {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "action_is_required")
} else if midjRequest.Index == 0 {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "index_is_required")
}
//action = midjRequest.Action
mjId = midjRequest.TaskId
} else if relayMode == provider.RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_is_required")
}
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_parse_failed")
}
mjId = params.TaskId
midjRequest.Action = params.Action
} else if relayMode == provider.RelayModeMidjourneyModal {
//if midjRequest.MaskBase64 == "" {
// return provider.MidjourneyErrorWrapper(provider.MjRequestError, "mask_base64_is_required")
//}
mjId = midjRequest.TaskId
midjRequest.Action = provider.MjActionModal
}
originTask := model.GetByMJId(userId, mjId)
if originTask == nil {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_not_found")
} else if originTask.Status != "SUCCESS" && relayMode != provider.RelayModeMidjourneyModal {
return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
channelId = originTask.ChannelId
log.Printf("检测到此操作为放大、变换、重绘获取原channel信息: %d", originTask.ChannelId)
}
midjRequest.Prompt = originTask.Prompt
//if channelType == common.ChannelTypeMidjourneyPlus {
// // plus
//} else {
// // 普通版渠道
//
//}
}
if midjRequest.Action == provider.MjActionInPaint || midjRequest.Action == provider.MjActionCustomZoom {
consumeQuota = false
}
mjProvider, errWithMJ := getMJProvider(c, relayMode, channelId, &midjRequest)
if errWithMJ != nil {
return errWithMJ
}
//baseURL := common.ChannelBaseURLs[channelType]
requestURL := getMjRequestPath(c.Request.URL.String())
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
quotaInstance, errWithOA := getQuota(c, mjProvider.GetOriginalModel())
if errWithOA != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: errWithOA.Message,
}
}
midjResponseWithStatus, responseBody, err := mjProvider.Send(60, requestURL)
if err != nil {
quotaInstance.Undo(c)
return &midjResponseWithStatus.Response
}
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1, TotalTokens: 1})
} else {
quotaInstance.Undo(c)
}
quota := int(quotaInstance.GetInputRatio() * 1000)
midjResponse := &midjResponseWithStatus.Response
// 文档https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
// 22-排队中 {"code":22,"description":"排队中前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
// 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
// other: 提交错误description为错误描述
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
Action: midjRequest.Action,
MjId: midjResponse.Result,
Prompt: midjRequest.Prompt,
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
StartTime: 0,
FinishTime: 0,
ImageUrl: "",
Status: "",
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
Quota: quota,
}
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
midjourneyTask.FailReason = midjResponse.Description
consumeQuota = false
}
if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了)
// 将 properties 转换为一个 map
properties, ok := midjResponse.Properties.(map[string]interface{})
if ok {
imageUrl, ok1 := properties["imageUrl"].(string)
status, ok2 := properties["status"].(string)
if ok1 && ok2 {
midjourneyTask.ImageUrl = imageUrl
midjourneyTask.Status = status
if status == "SUCCESS" {
midjourneyTask.Progress = "100%"
midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
midjResponse.Code = 1
}
}
}
//修改返回值
if midjRequest.Action != provider.MjActionInPaint && midjRequest.Action != provider.MjActionCustomZoom {
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
responseBody = []byte(newBody)
}
}
err = midjourneyTask.Insert()
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "insert_midjourney_task_failed",
}
}
// 开始激活任务
controller.ActivateUpdateMidjourneyTaskBulk()
if midjResponse.Code == 22 { //22-排队中,说明任务已存在
//修改返回值
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
responseBody = []byte(newBody)
}
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
//for k, v := range resp.Header {
// c.Writer.Header().Set(k, v[0])
//}
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
_, err = io.Copy(c.Writer, bodyReader)
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
err = bodyReader.Close()
if err != nil {
return &provider.MidjourneyResponse{
Code: 4,
Description: "close_response_body_failed",
}
}
return nil
}
func getMjRequestPath(path string) string {
requestURL := path
if strings.Contains(requestURL, "/mj-") {
urls := strings.Split(requestURL, "/mj/")
if len(urls) < 2 {
return requestURL
}
requestURL = "/mj/" + urls[1]
}
return requestURL
}
func getQuota(c *gin.Context, modelName string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
// modelName = CoverActionToModelName(modelName)
return util.NewQuota(c, modelName, 1000)
}
func getMJProvider(c *gin.Context, relayMode, channel_id int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
var baseProvider providersBase.ProviderInterface
modelName := ""
if channel_id > 0 {
c.Set("specific_channel_id", channel_id)
}
if request != nil {
midjourneyModel, mjErr, _ := GetMjRequestModel(relayMode, request)
if mjErr != nil {
return nil, MidjourneyErrorFromInternal(mjErr.Code, mjErr.Description)
}
if midjourneyModel == "" {
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型")
}
modelName = midjourneyModel
}
var err error
baseProvider, _, err = relay.GetProvider(c, modelName)
if err != nil {
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无法获取provider:"+err.Error())
}
mjProvider, ok := baseProvider.(*provider.MidjourneyProvider)
if !ok {
return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法获取midjourney provider")
}
return mjProvider, nil
}

95
relay/midjourney/relay.go Normal file
View File

@@ -0,0 +1,95 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: controller/relay.go
package midjourney
import (
"fmt"
"net/http"
"one-api/common"
provider "one-api/providers/midjourney"
"strings"
"github.com/gin-gonic/gin"
)
func RelayMidjourney(c *gin.Context) {
relayMode := Path2RelayModeMidjourney(c.Request.URL.Path)
var err *provider.MidjourneyResponse
switch relayMode {
case provider.RelayModeMidjourneyNotify:
err = RelayMidjourneyNotify(c)
case provider.RelayModeMidjourneyTaskFetch, provider.RelayModeMidjourneyTaskFetchByCondition:
err = RelayMidjourneyTask(c, relayMode)
case provider.RelayModeMidjourneyTaskImageSeed:
err = RelayMidjourneyTaskImageSeed(c)
case provider.RelayModeMidjourneySwapFace:
err = RelaySwapFace(c)
default:
err = RelayMidjourneySubmit(c, relayMode)
}
if err != nil {
statusCode := http.StatusBadRequest
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
statusCode = http.StatusTooManyRequests
}
typeMsg := "upstream_error"
if err.Type != "" {
typeMsg = err.Type
}
c.JSON(statusCode, gin.H{
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
"type": typeMsg,
"code": err.Code,
})
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
}
}
func MidjourneyErrorFromInternal(code int, description string) *provider.MidjourneyResponse {
return &provider.MidjourneyResponse{
Code: code,
Description: description,
Type: "internal_error",
}
}
func Path2RelayModeMidjourney(path string) int {
relayMode := provider.RelayModeUnknown
if strings.HasSuffix(path, "/mj/submit/action") {
// midjourney plus
relayMode = provider.RelayModeMidjourneyAction
} else if strings.HasSuffix(path, "/mj/submit/modal") {
// midjourney plus
relayMode = provider.RelayModeMidjourneyModal
} else if strings.HasSuffix(path, "/mj/submit/shorten") {
// midjourney plus
relayMode = provider.RelayModeMidjourneyShorten
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
// midjourney plus
relayMode = provider.RelayModeMidjourneySwapFace
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
relayMode = provider.RelayModeMidjourneyImagine
} else if strings.HasSuffix(path, "/mj/submit/blend") {
relayMode = provider.RelayModeMidjourneyBlend
} else if strings.HasSuffix(path, "/mj/submit/describe") {
relayMode = provider.RelayModeMidjourneyDescribe
} else if strings.HasSuffix(path, "/mj/notify") {
relayMode = provider.RelayModeMidjourneyNotify
} else if strings.HasSuffix(path, "/mj/submit/change") {
relayMode = provider.RelayModeMidjourneyChange
} else if strings.HasSuffix(path, "/mj/submit/simple-change") {
relayMode = provider.RelayModeMidjourneyChange
} else if strings.HasSuffix(path, "/fetch") {
relayMode = provider.RelayModeMidjourneyTaskFetch
} else if strings.HasSuffix(path, "/image-seed") {
relayMode = provider.RelayModeMidjourneyTaskImageSeed
} else if strings.HasSuffix(path, "/list-by-condition") {
relayMode = provider.RelayModeMidjourneyTaskFetchByCondition
}
return relayMode
}

148
relay/midjourney/service.go Normal file
View File

@@ -0,0 +1,148 @@
// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: service/midjourney.go
package midjourney
import (
mjProvider "one-api/providers/midjourney"
"strconv"
"strings"
)
func CoverActionToModelName(mjAction string) string {
modelName := "mj_" + strings.ToLower(mjAction)
if mjAction == mjProvider.MjActionSwapFace {
modelName = "swap_face"
}
return modelName
}
func GetMjRequestModel(relayMode int, midjRequest *mjProvider.MidjourneyRequest) (string, *mjProvider.MidjourneyResponse, bool) {
action := ""
if relayMode == mjProvider.RelayModeMidjourneyAction {
// plus request
err := CoverPlusActionToNormalAction(midjRequest)
if err != nil {
return "", err, false
}
action = midjRequest.Action
} else {
switch relayMode {
case mjProvider.RelayModeMidjourneyImagine:
action = mjProvider.MjActionImagine
case mjProvider.RelayModeMidjourneyDescribe:
action = mjProvider.MjActionDescribe
case mjProvider.RelayModeMidjourneyBlend:
action = mjProvider.MjActionBlend
case mjProvider.RelayModeMidjourneyShorten:
action = mjProvider.MjActionShorten
case mjProvider.RelayModeMidjourneyChange:
action = midjRequest.Action
case mjProvider.RelayModeMidjourneyModal:
action = mjProvider.MjActionModal
case mjProvider.RelayModeMidjourneySwapFace:
action = mjProvider.MjActionSwapFace
case mjProvider.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
return "", mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "invalid_request"), false
}
action = params.Action
case mjProvider.RelayModeMidjourneyTaskFetch, mjProvider.RelayModeMidjourneyTaskFetchByCondition, mjProvider.RelayModeMidjourneyNotify:
return "", nil, true
default:
return "", mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_relay_action"), false
}
}
modelName := CoverActionToModelName(action)
return modelName, nil, true
}
func CoverPlusActionToNormalAction(midjRequest *mjProvider.MidjourneyRequest) *mjProvider.MidjourneyResponse {
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
customId := midjRequest.CustomId
if customId == "" {
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "custom_id_is_required")
}
splits := strings.Split(customId, "::")
var action string
if splits[1] == "JOB" {
action = splits[2]
} else {
action = splits[1]
}
if action == "" {
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_action")
}
if strings.Contains(action, "upsample") {
index, err := strconv.Atoi(splits[3])
if err != nil {
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "index_parse_failed")
}
midjRequest.Index = index
midjRequest.Action = mjProvider.MjActionUpscale
} else if strings.Contains(action, "variation") {
midjRequest.Index = 1
if action == "variation" {
index, err := strconv.Atoi(splits[3])
if err != nil {
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "index_parse_failed")
}
midjRequest.Index = index
midjRequest.Action = mjProvider.MjActionVariation
} else if action == "low_variation" {
midjRequest.Action = mjProvider.MjActionLowVariation
} else if action == "high_variation" {
midjRequest.Action = mjProvider.MjActionHighVariation
}
} else if strings.Contains(action, "pan") {
midjRequest.Action = mjProvider.MjActionPan
midjRequest.Index = 1
} else if strings.Contains(action, "reroll") {
midjRequest.Action = mjProvider.MjActionReRoll
midjRequest.Index = 1
} else if action == "Outpaint" {
midjRequest.Action = mjProvider.MjActionZoom
midjRequest.Index = 1
} else if action == "CustomZoom" {
midjRequest.Action = mjProvider.MjActionCustomZoom
midjRequest.Index = 1
} else if action == "Inpaint" {
midjRequest.Action = mjProvider.MjActionInPaint
midjRequest.Index = 1
} else {
return mjProvider.MidjourneyErrorWrapper(mjProvider.MjRequestError, "unknown_action:"+customId)
}
return nil
}
func ConvertSimpleChangeParams(content string) *mjProvider.MidjourneyRequest {
split := strings.Split(content, " ")
if len(split) != 2 {
return nil
}
action := strings.ToLower(split[1])
changeParams := &mjProvider.MidjourneyRequest{}
changeParams.TaskId = split[0]
if action[0] == 'u' {
changeParams.Action = "UPSCALE"
} else if action[0] == 'v' {
changeParams.Action = "VARIATION"
} else if action == "r" {
changeParams.Action = "REROLL"
return changeParams
} else {
return nil
}
index, err := strconv.Atoi(action[1:2])
if err != nil || index < 1 || index > 4 {
return nil
}
changeParams.Index = index
return changeParams
}

View File

@@ -170,3 +170,7 @@ func (q *Quota) Consume(c *gin.Context, usage *types.Usage) {
}
}(c.Request.Context())
}
func (q *Quota) GetInputRatio() float64 {
return q.inputRatio
}

View File

@@ -7,22 +7,23 @@ var ModelOwnedBy map[int]string
func init() {
ModelOwnedBy = map[int]string{
common.ChannelTypeOpenAI: "OpenAI",
common.ChannelTypeAnthropic: "Anthropic",
common.ChannelTypeBaidu: "Baidu",
common.ChannelTypePaLM: "Google PaLM",
common.ChannelTypeGemini: "Google Gemini",
common.ChannelTypeZhipu: "Zhipu",
common.ChannelTypeAli: "Ali",
common.ChannelTypeXunfei: "Xunfei",
common.ChannelType360: "360",
common.ChannelTypeTencent: "Tencent",
common.ChannelTypeBaichuan: "Baichuan",
common.ChannelTypeMiniMax: "MiniMax",
common.ChannelTypeDeepseek: "Deepseek",
common.ChannelTypeMoonshot: "Moonshot",
common.ChannelTypeMistral: "Mistral",
common.ChannelTypeGroq: "Groq",
common.ChannelTypeLingyi: "Lingyiwanwu",
common.ChannelTypeOpenAI: "OpenAI",
common.ChannelTypeAnthropic: "Anthropic",
common.ChannelTypeBaidu: "Baidu",
common.ChannelTypePaLM: "Google PaLM",
common.ChannelTypeGemini: "Google Gemini",
common.ChannelTypeZhipu: "Zhipu",
common.ChannelTypeAli: "Ali",
common.ChannelTypeXunfei: "Xunfei",
common.ChannelType360: "360",
common.ChannelTypeTencent: "Tencent",
common.ChannelTypeBaichuan: "Baichuan",
common.ChannelTypeMiniMax: "MiniMax",
common.ChannelTypeDeepseek: "Deepseek",
common.ChannelTypeMoonshot: "Moonshot",
common.ChannelTypeMistral: "Mistral",
common.ChannelTypeGroq: "Groq",
common.ChannelTypeLingyi: "Lingyiwanwu",
common.ChannelTypeMidjourney: "Midjourney",
}
}